
Commit 39f5467
Enhance input validation in attention mechanisms and consciousness model; add padding for variable sequence lengths
kasinadhsarma committed Dec 25, 2024
1 parent f9ca348 commit 39f5467
Showing 22 changed files with 96 additions and 41 deletions.
Binary file modified models/__pycache__/attention.cpython-310.pyc
Binary file modified models/__pycache__/consciousness_model.cpython-310.pyc
Binary file modified models/__pycache__/consciousness_state.cpython-310.pyc
Binary file modified models/__pycache__/memory.cpython-310.pyc
68 changes: 46 additions & 22 deletions models/attention.py
@@ -28,9 +28,15 @@ def forward(self, query, key_value, mask=None, training=None):
         """Forward pass of consciousness attention mechanism."""
         # Input validation
         if query.size(0) == 0 or query.size(1) == 0 or query.size(2) == 0:
-            raise ValueError("Query tensor cannot be empty")
+            raise ValueError("Empty input tensor")
         if key_value.size(0) == 0 or key_value.size(1) == 0 or key_value.size(2) == 0:
-            raise ValueError("Key/Value tensor cannot be empty")
+            raise ValueError("Empty input tensor")
+        if query.size(0) != key_value.size(0):
+            raise ValueError("Batch size mismatch between query and key_value")
+        if query.size(1) != key_value.size(1):
+            raise ValueError("Sequence length mismatch between query and key_value")
+        if query.nelement() == 0 or key_value.nelement() == 0:
+            raise ValueError("Empty input tensor")
 
         # Validate input dimensions
         if query.size(-1) != self.hidden_dim or key_value.size(-1) != self.hidden_dim:
@@ -94,30 +94,48 @@ def __init__(self, hidden_dim: int, num_heads: int, head_dim: int, dropout_rate:
             nn.Dropout(dropout_rate)
         )
 
-    def forward(self, inputs: torch.Tensor,
-                memory_state: Optional[torch.Tensor] = None,
+    def _process_attention(self, inputs: torch.Tensor,
+                           memory_state: Optional[torch.Tensor] = None,
                 deterministic: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Process inputs through attention mechanism with residual connection."""
+        # Use attention mechanism
+        attention_output, attention_weights = self.attention(
+            query=inputs,
+            key_value=memory_state if memory_state is not None else inputs
+        )
+
+        # First residual connection and layer norm
+        attention_output = inputs + attention_output
+        normalized = self.layer_norm2(attention_output)
+
+        # Feed-forward network with residual
+        ff_output = self.ff_network(normalized)
+        output = attention_output + ff_output
+
+        return output, attention_weights
+
+    def forward(self, inputs: torch.Tensor,
+                memory_state: Optional[torch.Tensor] = None,
+                deterministic: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass with optional deterministic mode."""
         # Input validation
-        if inputs.size(0) == 0 or inputs.size(1) == 0 or inputs.size(2) == 0:
-            raise ValueError("Input tensor cannot be empty")
+        if inputs.numel() == 0:
+            raise ValueError("Empty input tensor")
 
         if inputs.size(-1) != self.hidden_dim:
             raise ValueError(f"Expected input dimension {self.hidden_dim}, got {inputs.size(-1)}")
 
-        # Layer normalization and attention
-        x = self.layer_norm1(inputs)
-        attended_output, attention_weights = self.attention(
-            x, x, mask=None,
-            training=not deterministic  # Convert deterministic to training mode
-        )
+        # Get input dimensions
+        batch_size, *dims = inputs.size()
 
-        # First residual connection
-        x = inputs + attended_output
+        # Reshape input if needed to match expected 3D shape [batch, seq, features]
+        if len(dims) == 1:
+            inputs = inputs.unsqueeze(1)  # Add sequence dimension
+        elif len(dims) > 2:
+            # Flatten all dimensions after batch into sequence dimension
+            inputs = inputs.view(batch_size, -1, dims[-1])
+
+        # Apply layer normalization
+        normalized = self.layer_norm1(inputs)
+
+        # Process through attention layers
+        output, attention_weights = self._process_attention(normalized, memory_state, deterministic)
 
-        # Feed-forward network with residual connection
-        y = self.layer_norm2(x)
-        y = self.ff_network(y) if not deterministic else self.ff_network.eval()(y)
-        output = x + y
 
         return output, attention_weights
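
Note on the reshape logic above: the new forward() coerces any input to the 3D shape [batch, seq, features] before normalization and attention. A minimal, self-contained sketch of just that rule (the helper name to_batch_seq_features is illustrative, not part of the repository):

import torch

def to_batch_seq_features(inputs: torch.Tensor) -> torch.Tensor:
    # Mirrors the reshape added in forward(): 2D inputs gain a sequence axis,
    # higher-rank inputs are flattened into a single sequence axis.
    batch_size, *dims = inputs.size()
    if len(dims) == 1:
        return inputs.unsqueeze(1)                    # [B, F] -> [B, 1, F]
    if len(dims) > 2:
        return inputs.view(batch_size, -1, dims[-1])  # [B, d1, ..., F] -> [B, d1*..., F]
    return inputs                                     # already [B, S, F]

print(to_batch_seq_features(torch.randn(2, 128)).shape)        # torch.Size([2, 1, 128])
print(to_batch_seq_features(torch.randn(2, 4, 8, 128)).shape)  # torch.Size([2, 32, 128])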
19 changes: 18 additions & 1 deletion models/consciousness_model.py
@@ -280,7 +280,24 @@ def forward(self, inputs: Dict[str, torch.Tensor], deterministic: bool = True):
         """Process multiple modalities and generate cross-modal attention maps."""
         if not inputs:
             raise ValueError("Empty input dictionary")
 
+        # Get dimensions from largest input tensor
+        seq_lengths = {name: tensor.size(1) if tensor.dim() > 1 else 1
+                       for name, tensor in inputs.items()}
+        max_seq_len = max(seq_lengths.values())
+
+        # Pad all inputs to match max sequence length
+        processed_inputs = {}
+        for name, tensor in inputs.items():
+            if tensor.dim() == 2:  # [batch, features]
+                tensor = tensor.unsqueeze(1)  # Add sequence dimension
+            if tensor.size(1) < max_seq_len:
+                # Pad sequence dimension to match max length
+                pad_size = max_seq_len - tensor.size(1)
+                tensor = torch.nn.functional.pad(tensor, (0, 0, 0, pad_size))
+            processed_inputs[name] = tensor
+
+        # Continue with regular processing using padded inputs
         # Get dimensions from first input tensor
         first_tensor = next(iter(inputs.values()))
         batch_size = first_tensor.size(0)
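
The padding added above can be checked in isolation: for a [batch, seq, features] tensor, torch.nn.functional.pad with (0, 0, 0, pad_size) pads only the sequence dimension. A hedged sketch with made-up modality names and sizes (2D inputs would first gain a sequence axis via unsqueeze(1), exactly as in the diff):

import torch
import torch.nn.functional as F

inputs = {
    "vision": torch.randn(2, 5, 64),  # [batch, seq=5, features]
    "audio": torch.randn(2, 3, 64),   # [batch, seq=3, features]
}

max_seq_len = max(t.size(1) for t in inputs.values())

padded = {}
for name, tensor in inputs.items():
    if tensor.size(1) < max_seq_len:
        pad_size = max_seq_len - tensor.size(1)
        # pad spec pairs apply from the last dim backward: (feat_left, feat_right, seq_left, seq_right)
        tensor = F.pad(tensor, (0, 0, 0, pad_size))  # zero-pad the end of the sequence
    padded[name] = tensor

print({k: tuple(v.shape) for k, v in padded.items()})
# {'vision': (2, 5, 64), 'audio': (2, 5, 64)}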
Binary file modified tests/__pycache__/conftest.cpython-310-pytest-8.3.4.pyc
Binary file modified tests/__pycache__/test_environment.cpython-310-pytest-8.3.4.pyc
9 changes: 6 additions & 3 deletions tests/benchmarks/test_bigbench_reasoning.py
@@ -114,9 +114,11 @@ def test_meta_learning(self, device, consciousness_model):
             torch.manual_seed(i)  # Ensure different random patterns
             input_embedding = torch.randn(1, 64, 512, device=device) * (i + 1)  # Vary input scale
             if last_state is not None:
-                state = last_state.unsqueeze(0)
+                # Expand state to match sequence length
+                state = last_state.unsqueeze(0).expand(-1, 64, -1)
             else:
-                state = torch.zeros(1, 1, consciousness_model.hidden_dim, device=device)
+                # Initialize state with correct sequence length
+                state = torch.zeros(1, 64, consciousness_model.hidden_dim, device=device)
 
             output, metrics = consciousness_model(
                 {
@@ -156,7 +158,8 @@ def test_consciousness_emergence(self, device, consciousness_model):
             torch.manual_seed(i)
             # Use consistent sequence length and vary input scale
             task_embedding = torch.randn(1, 64, 512, device=device) * (i + 1)  # Vary input scale
-            base_state = torch.randn(1, 1, consciousness_model.hidden_dim, device=device)
+            # Initialize base_state with correct sequence length
+            base_state = torch.randn(1, 64, consciousness_model.hidden_dim, device=device)
 
             output, metrics = consciousness_model(
                 {
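
One detail behind the state changes above: Tensor.expand returns a broadcast view rather than a copy, so stretching a singleton sequence dimension to length 64 is essentially free. A small sketch, assuming the saved state has shape [1, hidden_dim] as the unsqueeze(0).expand(-1, 64, -1) pattern implies (hidden_dim = 512 here is illustrative):

import torch

hidden_dim, seq_len = 512, 64
last_state = torch.randn(1, hidden_dim)              # assumed shape [1, hidden_dim]

state = last_state.unsqueeze(0).expand(-1, seq_len, -1)
print(state.shape)                                   # torch.Size([1, 64, 512])

# expand() is a view: all 64 sequence positions alias the same underlying row.
print(torch.equal(state[0, 0], state[0, 63]))        # True
# If positions must be written independently, use .repeat() or .contiguous() instead.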
Binary file modified tests/unit/__pycache__/test_base.cpython-310-pytest-8.3.4.pyc
41 changes: 26 additions & 15 deletions tests/unit/attention/test_attention_mechanisms.py
@@ -103,16 +103,27 @@ def test_attention_edge_cases(self, attention_module):
         seq_length = 8
         input_dim = 128
 
-        # Test with empty input
-        empty_input = torch.empty(batch_size, seq_length, input_dim)
-        with pytest.raises(ValueError):
+        # Test with empty input (zero sized tensor)
+        empty_input = torch.zeros((batch_size, 0, input_dim))
+        with pytest.raises(ValueError, match="Empty input tensor"):
             attention_module(empty_input, empty_input)
 
-        # Test with mismatched input dimensions
-        mismatched_input_q = torch.randn(batch_size, seq_length, input_dim)
-        mismatched_input_kv = torch.randn(batch_size, seq_length, input_dim // 2)
-        with pytest.raises(ValueError):
-            attention_module(mismatched_input_q, mismatched_input_kv)
+        # Test with zero dimension tensor
+        zero_dim = torch.tensor([])
+        with pytest.raises(ValueError, match="Empty input tensor"):
+            attention_module(zero_dim, zero_dim)
+
+        # Test with mismatched batch sizes
+        query = torch.randn(2, seq_length, input_dim)
+        key_value = torch.randn(3, seq_length, input_dim)
+        with pytest.raises(ValueError, match="Batch size mismatch"):
+            attention_module(query, key_value)
+
+        # Test with mismatched sequence lengths
+        query = torch.randn(batch_size, seq_length, input_dim)
+        key_value = torch.randn(batch_size, seq_length + 1, input_dim)
+        with pytest.raises(ValueError, match="Sequence length mismatch"):
+            attention_module(query, key_value)
 
 class TestGlobalWorkspace:
     @pytest.fixture
@@ -167,12 +178,12 @@ def test_global_workspace_edge_cases(self, workspace_module):
         seq_length = 8
         input_dim = 128
 
-        # Test with empty input
-        empty_input = torch.empty(batch_size, seq_length, input_dim)
-        with pytest.raises(ValueError):
+        # Test with empty input (zero sized tensor)
+        empty_input = torch.zeros((batch_size, 0, input_dim))
+        with pytest.raises(ValueError, match="Empty input tensor"):
             workspace_module(empty_input)
 
-        # Test with mismatched input dimensions
-        mismatched_input = torch.randn(batch_size, seq_length, input_dim // 2)
-        with pytest.raises(ValueError):
-            workspace_module(mismatched_input)
+        # Test with zero dimension tensor
+        zero_dim = torch.tensor([])
+        with pytest.raises(ValueError, match="Empty input tensor"):
+            workspace_module(zero_dim)
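
The new match= patterns pair naturally with pytest.mark.parametrize; purely as a hedged sketch, the three attention edge cases above could also be written as one parametrized test (fixture name and error messages taken from the diff, everything else illustrative):

import pytest
import torch

@pytest.mark.parametrize(
    "query,key_value,message",
    [
        (torch.zeros(2, 0, 128), torch.zeros(2, 0, 128), "Empty input tensor"),
        (torch.randn(2, 8, 128), torch.randn(3, 8, 128), "Batch size mismatch"),
        (torch.randn(2, 8, 128), torch.randn(2, 9, 128), "Sequence length mismatch"),
    ],
)
def test_attention_invalid_inputs(attention_module, query, key_value, message):
    # Each case should raise ValueError with a message matching the diff above.
    with pytest.raises(ValueError, match=message):
        attention_module(query, key_value)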
