Merge pull request #10 from Neuro-Flex/kasinadhsarma/fix-edge-cases
Fix edge cases and deprecated usages in consciousness and memory models
kasinadhsarma authored Dec 25, 2024
2 parents 5575ade + 06fdbc6 commit f9ca348
Showing 19 changed files with 835 additions and 136 deletions.
101 changes: 44 additions & 57 deletions models/attention.py
@@ -8,78 +8,58 @@ class ConsciousnessAttention(nn.Module):
Multi-head attention mechanism for consciousness modeling based on Global Workspace Theory.
Implements scaled dot-product attention with consciousness-aware broadcasting.
"""
def __init__(self, num_heads: int, head_dim: int, dropout_rate: float = 0.1, attention_dropout_rate: float = 0.1):
def __init__(self, num_heads: int, head_dim: int, dropout_rate: float = 0.1):
super().__init__()
self.hidden_dim = num_heads * head_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.dropout_rate = dropout_rate
self.attention_dropout_rate = attention_dropout_rate
self.depth = num_heads * head_dim
self.scale = head_dim ** -0.5

# Linear projections
self.query = nn.Linear(self.depth, self.depth)
self.key = nn.Linear(self.depth, self.depth)
self.value = nn.Linear(self.depth, self.depth)
self.output_projection = nn.Linear(self.depth, self.depth)

# Dropouts
self.attn_dropout = nn.Dropout(attention_dropout_rate)
self.query = nn.Linear(self.hidden_dim, self.hidden_dim)
self.key = nn.Linear(self.hidden_dim, self.hidden_dim)
self.value = nn.Linear(self.hidden_dim, self.hidden_dim)

# Dropout layers
self.attn_dropout = nn.Dropout(dropout_rate)
self.output_dropout = nn.Dropout(dropout_rate)

def forward(self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor,
mask: Optional[torch.Tensor] = None,
training: bool = True,
deterministic: Optional[bool] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of consciousness attention.
Args:
inputs_q: Query inputs
inputs_kv: Key-value inputs
mask: Optional attention mask
training: Whether in training mode (controls dropout)
deterministic: Optional override for training mode
"""
batch_size = inputs_q.size(0)

# Use deterministic to override training mode if provided
is_training = training if deterministic is None else not deterministic
def forward(self, query, key_value, mask=None, training=None):
"""Forward pass of consciousness attention mechanism."""
# Input validation
if query.size(0) == 0 or query.size(1) == 0 or query.size(2) == 0:
raise ValueError("Query tensor cannot be empty")
if key_value.size(0) == 0 or key_value.size(1) == 0 or key_value.size(2) == 0:
raise ValueError("Key/Value tensor cannot be empty")

# Validate input dimensions
if query.size(-1) != self.hidden_dim or key_value.size(-1) != self.hidden_dim:
raise ValueError(f"Expected input dimension {self.hidden_dim}, got query: {query.size(-1)}, key/value: {key_value.size(-1)}")

batch_size = query.size(0)

# Linear projections
query = self.query(inputs_q)
key = self.key(inputs_kv)
value = self.value(inputs_kv)

# Reshape for multi-head attention
query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Linear projections and reshape for multi-head attention
q = self.query(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.key(key_value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.value(key_value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

# Scaled dot-product attention
depth_scaling = float(self.head_dim) ** -0.5
attention_logits = torch.matmul(query, key.transpose(-2, -1)) * depth_scaling
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

if mask is not None:
mask = mask.unsqueeze(1).unsqueeze(2)
attention_logits = attention_logits.masked_fill(~mask, float('-inf'))
# Expand mask for multiple heads
expanded_mask = mask.unsqueeze(1).unsqueeze(2)
scores = scores.masked_fill(~expanded_mask, float('-inf'))

attention_weights = F.softmax(attention_logits, dim=-1)

if is_training:
attention_weights = self.attn_dropout(attention_weights)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.attn_dropout(attention_weights)

# Compute attention output
attention_output = torch.matmul(attention_weights, value)
# Apply attention weights to values
output = torch.matmul(attention_weights, v)

# Reshape and project output
attention_output = attention_output.transpose(1, 2).contiguous()
attention_output = attention_output.view(batch_size, -1, self.depth)
output = self.output_projection(attention_output)

if is_training:
output = self.output_dropout(output)

# Residual connection
output = output + inputs_q
# Reshape back
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
output = self.output_dropout(output)

return output, attention_weights

@@ -118,6 +98,13 @@ def forward(self, inputs: torch.Tensor,
memory_state: Optional[torch.Tensor] = None,
deterministic: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with optional deterministic mode."""
# Input validation
if inputs.size(0) == 0 or inputs.size(1) == 0 or inputs.size(2) == 0:
raise ValueError("Input tensor cannot be empty")

if inputs.size(-1) != self.hidden_dim:
raise ValueError(f"Expected input dimension {self.hidden_dim}, got {inputs.size(-1)}")

# Layer normalization and attention
x = self.layer_norm1(inputs)
attended_output, attention_weights = self.attention(
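A minimal usage sketch of the refactored ConsciousnessAttention, assuming the constructor and forward signatures shown in the diff above; the import path, sizes, and mask shape are illustrative assumptions, not part of the commit.

import torch
from models.attention import ConsciousnessAttention  # assumed import path

num_heads, head_dim = 4, 32
attn = ConsciousnessAttention(num_heads=num_heads, head_dim=head_dim, dropout_rate=0.1)

batch, seq_len = 2, 8
hidden_dim = num_heads * head_dim                      # inputs must match this last dimension
query = torch.randn(batch, seq_len, hidden_dim)
key_value = torch.randn(batch, seq_len, hidden_dim)
mask = torch.ones(batch, seq_len, dtype=torch.bool)    # True = attend; broadcast over heads

output, attention_weights = attn(query, key_value, mask=mask)
print(output.shape)             # torch.Size([2, 8, 128])
print(attention_weights.shape)  # torch.Size([2, 4, 8, 8]), one map per head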
19 changes: 19 additions & 0 deletions models/attention/attention_mechanisms.py
@@ -0,0 +1,19 @@
class ConsciousnessAttention(nn.Module):
def forward(self, query, key=None, value=None, mask=None):
# Validate inputs
if query.size(0) == 0 or query.size(1) == 0:
raise ValueError("Empty input tensor")
if torch.isnan(query).any():
raise ValueError("Input contains NaN values")

# ...existing code...

class GlobalWorkspace(nn.Module):
def forward(self, x):
# Validate input
if x.size(0) == 0 or x.size(1) == 0:
raise ValueError("Empty input tensor")
if torch.isnan(x).any():
raise ValueError("Input contains NaN values")

# ...existing code...
19 changes: 19 additions & 0 deletions models/attention_mechanisms.py
@@ -0,0 +1,19 @@
class ConsciousnessAttention(nn.Module):
def forward(self, x, mask=None):
# Input validation
if x.size(0) == 0 or x.size(1) == 0:
raise ValueError("Empty input tensor")
if torch.isnan(x).any():
raise ValueError("Input contains NaN values")

# ...existing code...

class GlobalWorkspace(nn.Module):
def forward(self, inputs):
# Input validation
if inputs.size(0) == 0 or inputs.size(1) == 0:
raise ValueError("Empty input tensor")
if torch.isnan(inputs).any():
raise ValueError("Input contains NaN values")

# ...existing code...
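Both files above add the same guard clauses before any attention math runs; below is a standalone sketch of that pattern. The validate_input helper is hypothetical, written only to illustrate the added checks.

import torch

def validate_input(x: torch.Tensor) -> None:
    # Mirrors the guards added in this commit: reject empty or NaN tensors early.
    if x.size(0) == 0 or x.size(1) == 0:
        raise ValueError("Empty input tensor")
    if torch.isnan(x).any():
        raise ValueError("Input contains NaN values")

x = torch.randn(2, 4, 64)
x[0, 0, 0] = float('nan')   # poison one element
try:
    validate_input(x)
except ValueError as err:
    print(err)              # Input contains NaN values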
103 changes: 79 additions & 24 deletions models/consciousness_model.py
@@ -82,10 +82,38 @@ def forward(self, inputs, state=None, initial_state=None, deterministic=True, co
"""
Process inputs through consciousness architecture.
"""
# Initialize attention maps dictionary
attention_maps = {}

# Validate and process inputs
if not inputs:
raise ValueError("Inputs cannot be empty.")

# Allow for more flexible input combinations
required_modalities = {'visual', 'textual'} # Required modalities
missing_modalities = required_modalities - inputs.keys()
if missing_modalities:
# Auto-populate missing modalities with zero tensors
batch_size = next(iter(inputs.values())).size(0)
seq_len = next(iter(inputs.values())).size(1)
for modality in missing_modalities:
inputs[modality] = torch.zeros(batch_size, seq_len, self.hidden_dim, device=inputs[next(iter(inputs.keys()))].device)

# Check input dimensions
expected_dims = {
'attention': (None, 8, self.hidden_dim),
'memory': (None, 10, self.hidden_dim),
'visual': (None, None, self.hidden_dim),
'textual': (None, None, self.hidden_dim)
}

# Project inputs to correct dimension if needed
for modality, tensor in inputs.items():
if modality in expected_dims:
# Project if dimensions don't match
if tensor.size(-1) != self.hidden_dim:
inputs[modality] = self.input_projection(tensor)

batch_size = next(iter(inputs.values())).shape[0]
inputs = {k: torch.tensor(v, dtype=torch.float32) for k, v in inputs.items()}

@@ -250,51 +278,78 @@ def __init__(self, hidden_dim: int, num_heads: int, dropout_rate: float):

def forward(self, inputs: Dict[str, torch.Tensor], deterministic: bool = True):
"""Process multiple modalities and generate cross-modal attention maps."""
batch_size = next(iter(inputs.values())).size(0)
if not inputs:
raise ValueError("Empty input dictionary")

# Get dimensions from first input tensor
first_tensor = next(iter(inputs.values()))
batch_size = first_tensor.size(0)
hidden_dim = first_tensor.size(-1)

# Validate all inputs have same sequence length
seq_length = next(iter(inputs.values())).size(1)
for name, tensor in inputs.items():
if tensor.size(1) != seq_length:
raise ValueError(f"Sequence length mismatch for {name}: expected {seq_length}, got {tensor.size(1)}")

# Initialize combined state with correct dimensions
combined_state = torch.zeros(
batch_size, seq_length, hidden_dim,
device=first_tensor.device
)

attention_maps = {}
processed_states = {}

# First pass: Project all inputs
# Input validation
if not inputs:
raise ValueError("Empty input dictionary")

# Ensure all inputs have same dimensions
first_tensor = next(iter(inputs.values()))
expected_shape = first_tensor.shape[-1]
for name, tensor in inputs.items():
if tensor.shape[-1] != expected_shape:
raise ValueError(f"Mismatched dimensions for {name}: expected {expected_shape}, got {tensor.shape[-1]}")

# Project and reshape inputs
for modality, tensor in inputs.items():
processed = self.input_projection(tensor) # Use input_projection
# Ensure 3D shape for attention
if tensor.dim() == 2:
tensor = tensor.unsqueeze(1)
processed = self.input_projection(tensor)
processed_states[modality] = processed

# Initialize combined state with zeros matching the maximum sequence length
max_seq_length = max(tensor.size(1) for tensor in processed_states.values())
# Generate attention maps between all pairs
combined_state = torch.zeros(
batch_size, max_seq_length, self.hidden_dim,
batch_size, seq_length, self.hidden_dim,
device=next(iter(inputs.values())).device
)

# Generate attention maps between all modality pairs
for source in inputs.keys():
for target in inputs.keys():
for source in processed_states.keys():
for target in processed_states.keys():
if source != target:
query = processed_states[target]
key = processed_states[source]
value = processed_states[source]

# Ensure 3D shape for attention
if query.dim() == 2:
query = query.unsqueeze(1)
if key.dim() == 2:
key = key.unsqueeze(1)
if value.dim() == 2:
value = value.unsqueeze(1)

attn_output, attn_weights = self.attention(
query=query,
key=key,
value=value
)

# Store attention map
map_key = f"{target}-{source}"
attention_maps[map_key] = attn_weights

# Pad attn_output if necessary to match combined_state's sequence length
if attn_output.size(1) < max_seq_length:
pad_size = max_seq_length - attn_output.size(1)
attn_output = torch.nn.functional.pad(attn_output, (0, 0, 0, pad_size))
elif attn_output.size(1) > max_seq_length:
attn_output = attn_output[:, :max_seq_length, :]


attention_maps[f"{target}-{source}"] = attn_weights
combined_state = combined_state + attn_output

# ...existing code...
return combined_state, attention_maps

class InformationIntegration(nn.Module):
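A compact sketch of the cross-modal loop added above: every modality attends to every other, each pair's weights are stored under a "target-source" key, and the attended outputs are summed into a combined state. nn.MultiheadAttention with batch_first=True stands in for the model's attention module, so the exact layer and sizes are assumptions.

import torch
import torch.nn as nn

hidden_dim, num_heads, batch, seq = 64, 4, 2, 8
attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)  # stand-in layer

inputs = {
    'visual': torch.randn(batch, seq, hidden_dim),
    'textual': torch.randn(batch, seq, hidden_dim),
}

combined_state = torch.zeros(batch, seq, hidden_dim)
attention_maps = {}
for source in inputs:
    for target in inputs:
        if source != target:
            attn_output, attn_weights = attention(
                query=inputs[target], key=inputs[source], value=inputs[source]
            )
            attention_maps[f"{target}-{source}"] = attn_weights  # (batch, seq, seq), head-averaged
            combined_state = combined_state + attn_output

print(sorted(attention_maps))   # ['textual-visual', 'visual-textual']
print(combined_state.shape)     # torch.Size([2, 8, 64])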
35 changes: 27 additions & 8 deletions models/consciousness_state.py
@@ -31,17 +31,36 @@ def __init__(self, hidden_dim: int, num_heads: int, dropout_rate: float = 0.1):
# Add modality combination layer
self.modality_combination = nn.Linear(hidden_dim, hidden_dim)

def forward(self, inputs: Dict[str, torch.Tensor], deterministic: bool = True):
"""Process multiple modalities and generate cross-modal attention maps."""
batch_size = next(iter(inputs.values())).size(0)
seq_length = next(iter(inputs.values())).size(1)
def forward(self, inputs: Dict[str, torch.Tensor], deterministic: bool = True) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
# Input validation
if not inputs:
raise ValueError("Empty input dictionary")

# Get expected input dimension
first_tensor = next(iter(inputs.values()))
expected_shape = first_tensor.shape[-1]

# Define batch_size and seq_length
batch_size, seq_length, _ = first_tensor.size()

# Initialize attention_maps dictionary
attention_maps = {}

# Validate all inputs
for name, tensor in inputs.items():
if tensor.size(-1) != expected_shape:
raise ValueError(f"Mismatched input dimension for {name}: expected {expected_shape}, got {tensor.size(-1)}")
if tensor.dim() not in [2, 3]:
raise ValueError(f"Input {name} must be 2D or 3D tensor, got shape {tensor.shape}")
if torch.isnan(tensor).any():
raise ValueError(f"Input {name} contains NaN values")

# Process inputs
processed_states = {}

# First pass: Project all inputs
for modality, tensor in inputs.items():
processed = self.input_projection(tensor)
processed_states[modality] = processed
if tensor.dim() == 2:
tensor = tensor.unsqueeze(1) # Add sequence dimension
processed_states[modality] = self.input_projection(tensor)

# Initialize combined state with zeros
combined_state = torch.zeros(
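The 2D-to-3D promotion added above in consciousness_state.py can be shown in isolation: a (batch, hidden) tensor gains a singleton sequence axis before projection, so every modality reaches the attention stack as (batch, seq, hidden). Shapes below are illustrative.

import torch

tensor = torch.randn(4, 64)        # (batch, hidden_dim) with no sequence axis
if tensor.dim() == 2:
    tensor = tensor.unsqueeze(1)   # -> (4, 1, 64): add the sequence dimension
print(tensor.shape)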
23 changes: 23 additions & 0 deletions models/integration.py
@@ -0,0 +1,23 @@
class InformationIntegration(nn.Module):
def forward(self, inputs, deterministic=True):
"""Process inputs with enhanced validation."""
# Input tensor validation
if isinstance(inputs, torch.Tensor):
if inputs.size(0) == 0 or inputs.size(1) == 0:
raise ValueError("Empty input dimensions")
if torch.isnan(inputs).any():
raise ValueError("Input contains NaN values")
if inputs.size(-1) != self.input_dim:
raise ValueError(f"Expected input dimension {self.input_dim}, got {inputs.size(-1)}")

# Process input after validation
processed = self.input_projection(inputs)
normed = self.layer_norm(processed)

if not deterministic:
normed = self.dropout(normed)

# Calculate integration metric (phi)
phi = torch.mean(torch.abs(normed), dim=(-2, -1))

return normed, phi
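A minimal sketch of the integration metric added in integration.py: after projection and layer normalization, phi is the mean absolute activation per sample. The layer sizes below are assumptions for illustration, not the repo's configuration.

import torch
import torch.nn as nn

input_dim, hidden_dim, batch, seq = 32, 64, 2, 10
input_projection = nn.Linear(input_dim, hidden_dim)   # stands in for self.input_projection
layer_norm = nn.LayerNorm(hidden_dim)                 # stands in for self.layer_norm

inputs = torch.randn(batch, seq, input_dim)
normed = layer_norm(input_projection(inputs))
phi = torch.mean(torch.abs(normed), dim=(-2, -1))     # one scalar per batch element
print(phi.shape)                                      # torch.Size([2])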