Skip to content

Commit

Permalink
Integrate LongTermMemory into ConsciousnessModel; add tests for long-…
Browse files Browse the repository at this point in the history
…term memory retrieval and integration; update existing tests to validate retrieved_memory metrics
  • Loading branch information
kasinadhsarma committed Dec 26, 2024
1 parent 3d82857 commit ac5d8d2
Show file tree
Hide file tree
Showing 16 changed files with 339 additions and 26 deletions.
Binary file modified models/__pycache__/consciousness.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/consciousness_model.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/consciousness_state.cpython-310.pyc
Binary file not shown.
Binary file not shown.
100 changes: 89 additions & 11 deletions models/consciousness.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .information_integration import InformationIntegration
from .self_awareness import SelfAwareness # Add this import
from .dynamic_attention import DynamicAttention
from .long_term_memory import LongTermMemory

class MultiHeadAttention(nn.Module):
"""Custom MultiHeadAttention implementation"""
Expand Down Expand Up @@ -71,6 +72,15 @@ def forward(self, inputs: torch.Tensor, memory_state: Optional[torch.Tensor] = N
output = self.output_layer(integrated)
output = self.layer_norm(output)

# Add assertion to ensure output has correct hidden_dim
assert output.shape[-1] == self.hidden_dim, (
f"GlobalWorkspace output has hidden_dim {output.shape[-1]}, expected {self.hidden_dim}"
)

# Add logging for debugging
print(f"GlobalWorkspace output shape: {output.shape}")
print(f"memory_state shape: {memory_state.shape}")

return output, memory_state

class ConsciousnessModel(nn.Module):
Expand Down Expand Up @@ -121,6 +131,14 @@ def __init__(self, hidden_dim: int, num_heads: int, num_layers: int, num_states:
dropout_rate=dropout_rate
)

# Long-term memory
self.long_term_memory = LongTermMemory(
input_dim=self.input_dim,
hidden_dim=hidden_dim,
memory_size=1000,
dropout_rate=dropout_rate
)

# State tracking
self.previous_state = None

Expand Down Expand Up @@ -170,6 +188,8 @@ def update_goals(self, current_state: torch.Tensor):
def forward(self, inputs: Dict[str, torch.Tensor],
state: Optional[torch.Tensor] = None,
deterministic: bool = True) -> Tuple[torch.Tensor, Dict]:
# Initialize metrics dictionary at the start
metrics = {}

# Get device from inputs
device = next(iter(inputs.values())).device
Expand Down Expand Up @@ -199,12 +219,25 @@ def forward(self, inputs: Dict[str, torch.Tensor],
# Process through global workspace with reshaped state
conscious_out, memory_state = self.global_workspace(attn_out, state, deterministic)

# Add assertion to ensure conscious_out has correct hidden_dim
assert conscious_out.shape[-1] == self.hidden_dim, (
f"conscious_out has hidden_dim {conscious_out.shape[-1]}, expected {self.hidden_dim}"
)

# Add logging to verify conscious_out dimensions
print(f"conscious_out shape: {conscious_out.shape}")

# Process through self-awareness
aware_state, awareness_metrics = self.self_awareness(
conscious_out,
previous_state=self.previous_state
)

# Add assertion to ensure aware_state has correct hidden_dim
assert aware_state.shape[-1] == self.hidden_dim, (
f"aware_state has hidden_dim {aware_state.shape[-1]}, expected {self.hidden_dim}"
)

# Update previous state
self.previous_state = aware_state.detach()

Expand All @@ -213,17 +246,62 @@ def forward(self, inputs: Dict[str, torch.Tensor],

# Update goals based on conscious output
self.update_goals(conscious_out)

# Update metrics
metrics = {
'attention_weights': attention_metrics,
'memory_state': memory_state,
'phi': phi,
'attention_maps': attention_metrics,
'goal_state': self.goal_state,
'context_state': self.context_state,
**awareness_metrics
}

# Store memory with correct dimensions
memory_to_store = conscious_out.detach() # Remove mean reduction

# Use long_term_memory instead of memory
try:
# Ensure memory_to_store has correct shape [batch_size, hidden_dim]
memory_to_store = conscious_out.mean(dim=1) if len(conscious_out.shape) == 3 else conscious_out

# Store memory
self.long_term_memory.store_memory(memory_to_store)

# Retrieve memory using current state as query
retrieved_memory = self.long_term_memory.retrieve_memory(memory_to_store)

# Ensure retrieved memory has correct shape
if retrieved_memory.shape != (memory_to_store.shape[0], self.hidden_dim):
retrieved_memory = retrieved_memory.view(memory_to_store.shape[0], self.hidden_dim)

metrics['retrieved_memory'] = retrieved_memory

except Exception as e:
print(f"Memory operation error: {e}")
# Create zero tensor with correct shape
metrics['retrieved_memory'] = torch.zeros(
inputs['attention'].shape[0],
self.hidden_dim,
device=inputs['attention'].device
)

# Average over sequence length to get [batch_size, hidden_dim]
query = conscious_out.mean(dim=1) if len(conscious_out.shape) > 2 else conscious_out
print(f"query shape: {query.shape}")

# Ensure query has correct shape before memory retrieval
if query.dim() == 1:
query = query.unsqueeze(0)

# Retrieve memory and ensure it's in metrics
try:
retrieved_memory = self.long_term_memory.retrieve_memory(query)
print(f"retrieved_memory shape: {retrieved_memory.shape}")
metrics['retrieved_memory'] = retrieved_memory
except Exception as e:
print(f"Memory retrieval error: {e}")
metrics['retrieved_memory'] = torch.zeros(
query.size(0),
self.hidden_dim,
device=query.device
)

# Update remaining metrics
metrics.update(attention_metrics)
metrics['goal_state'] = self.goal_state
metrics['context_state'] = self.context_state
metrics['phi'] = phi

return aware_state, metrics

Expand Down
31 changes: 30 additions & 1 deletion models/consciousness_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .memory import WorkingMemory, InformationIntegration
from .consciousness_state import CognitiveProcessIntegration, ConsciousnessStateManager
from .error_handling import ErrorHandler, ErrorCorrection, validate_state
from .long_term_memory import LongTermMemory # Add import

torch.set_default_dtype(torch.float32)
torch.set_default_device('cpu') # or 'cuda' if using GPU
Expand Down Expand Up @@ -131,6 +132,13 @@ def __init__(self, hidden_dim: int, num_heads: int, num_layers: int, num_states:

self.target_cognition_percentage = 90.0 # Target cognition percentage

# Add long-term memory component
self.long_term_memory = LongTermMemory(
input_dim=hidden_dim,
hidden_dim=hidden_dim,
dropout_rate=dropout_rate
)

def add_meta_learning_layer(self):
"""Add meta-learning capabilities"""
self.meta_learner = nn.ModuleDict({
Expand Down Expand Up @@ -669,6 +677,27 @@ def forward(self, inputs, state=None, initial_state=None, deterministic=True, co
metrics['cognition_progress'] = cognition_progress
self.logger.debug(f"Cognition Progress: {cognition_progress}%")

# Add long-term memory retrieval with proper shape handling
query = consciousness_state.mean(dim=1) # [batch_size, hidden_dim]
if query.dim() == 1:
query = query.unsqueeze(0) # Add batch dimension if missing

# Ensure query has correct batch size
batch_size = next(iter(inputs.values())).size(0)
if query.size(0) == 1 and batch_size > 1:
query = query.expand(batch_size, -1)

# Retrieve and store memory
retrieved_memory = self.long_term_memory.retrieve_memory(query) # Should return [batch_size, hidden_dim]
self.long_term_memory.store_memory(new_state.detach())

# Double check retrieved memory shape
if retrieved_memory.size(0) != batch_size:
retrieved_memory = retrieved_memory.expand(batch_size, -1)

# Include retrieved_memory in metrics
metrics['retrieved_memory'] = retrieved_memory

end_time = time.time() # End profiling
self.logger.debug(f"forward pass took {end_time - start_time:.6f} seconds")

Expand Down Expand Up @@ -699,7 +728,7 @@ def get_config(self) -> Dict[str, Any]:
}

@classmethod
def create_default_config(cls) -> Dict[str, Any]:
def create_default_config(cls) -> Dict[str, Any]: # Fixed syntax error here
"""Create default model configuration."""
return {
'hidden_dim': 512,
Expand Down
13 changes: 13 additions & 0 deletions models/consciousness_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,16 @@ def get_rl_loss(self, state_value, reward, next_state_value, gamma=0.99):
# Value loss (MSE)
value_loss = torch.mean(td_error ** 2)
return value_loss, td_error

def test_long_term_memory_integration(self, model, sample_input):
# ...existing code...
output, metrics = model(sample_input)

# ...existing code...

# Add assertion to check if 'retrieved_memory' is in metrics and has correct shape
assert 'retrieved_memory' in metrics, "retrieved_memory not found in metrics"
assert metrics['retrieved_memory'].shape == (sample_input['query'].size(0), 128), (
f"retrieved_memory has shape {metrics['retrieved_memory'].shape}, expected ({sample_input['query'].size(0)}, 128)"
)
# ...existing code...
111 changes: 111 additions & 0 deletions models/long_term_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class LongTermMemory(nn.Module):
"""Long-term memory component for maintaining and recalling episodic information."""

def __init__(self, input_dim: int, hidden_dim: int, memory_size: int = 1000, dropout_rate: float = 0.1):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.memory_size = memory_size

# Memory cells
self.memory_rnn = nn.LSTM(
input_size=input_dim,
hidden_size=hidden_dim,
num_layers=2,
dropout=dropout_rate,
batch_first=True
)

# Initialize memory storage with correct device and shape
self.register_buffer('memory_storage', torch.zeros(memory_size, hidden_dim))
self.memory_index = 0

# Output projection
self.output_projection = nn.Linear(hidden_dim, hidden_dim)

# Layer normalization
self.layer_norm = nn.LayerNorm(hidden_dim)

def store_memory(self, memory: torch.Tensor):
"""Store memory in the long-term memory storage."""
# Add assertion to ensure memory has shape [batch_size, hidden_dim]
assert memory.dim() == 2 and memory.size(1) == self.hidden_dim, (
f"Memory has shape {memory.shape}, expected [batch_size, {self.hidden_dim}]"
)

batch_size = memory.size(0)

# Store exact memory values without normalization
if self.memory_index + batch_size > self.memory_size:
overflow = (self.memory_index + batch_size) - self.memory_size
self.memory_storage[self.memory_index:] = memory[:batch_size - overflow].detach()
self.memory_storage[:overflow] = memory[batch_size - overflow:].detach()
self.memory_index = overflow
else:
self.memory_storage[self.memory_index:self.memory_index + batch_size] = memory.detach()
self.memory_index += batch_size

def retrieve_memory(self, query):
"""Retrieve relevant memories based on query."""
try:
# Ensure query has correct shape [batch_size, hidden_dim]
if query.dim() == 1:
query = query.unsqueeze(0)
elif query.dim() > 2:
query = query.view(-1, self.hidden_dim)

batch_size = query.size(0)

# Handle empty memory case
if self.memory_index == 0:
return query # Return query itself if no memories stored

# Get valid memories
if self.memory_index < self.memory_size:
valid_memories = self.memory_storage[:self.memory_index]
else:
valid_memories = self.memory_storage

# Ensure we have at least one memory
if valid_memories.size(0) == 0:
return query

# Normalize for similarity computation only
query_norm = F.normalize(query, p=2, dim=1)
memories_norm = F.normalize(valid_memories, p=2, dim=1)

# Compute cosine similarity
similarity = torch.matmul(query_norm, memories_norm.t())

# Get attention weights through softmax
attention = F.softmax(similarity / 0.1, dim=1) # Temperature scaling

# Use original memories for weighted sum
retrieved = torch.matmul(attention, valid_memories)

return retrieved

except Exception as e:
print(f"Memory retrieval error: {e}")
return query.clone() # Return query itself in case of error

def forward(self, x):
# Run LSTM
output, (h_n, c_n) = self.memory_rnn(x)

# Project output
output = self.output_projection(output)
output = self.layer_norm(output)

# Retrieve memory ensuring batch size consistency
retrieved_memory = self.retrieve_memory(x)
# Ensure retrieved_memory has the same batch size as input
retrieved_memory = retrieved_memory.view(x.size(0), -1)
metrics = {'retrieved_memory': retrieved_memory}

return output, (h_n, c_n)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/__pycache__/test_self_awareness.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
19 changes: 19 additions & 0 deletions tests/test_consciousness.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,5 +504,24 @@ def test_logging_of_cognition_progress(self, model, sample_input, deterministic,
# Check that the cognition progress log is present
assert any("Cognition Progress" in message for message in caplog.text.splitlines())

def test_long_term_memory_integration(self, model, sample_input):
"""Test integration of long-term memory in the consciousness model."""
model.eval()
state = torch.zeros(sample_input['attention'].shape[0], model.hidden_dim)

# Run forward pass
with torch.no_grad():
output, metrics = model(sample_input, initial_state=state, deterministic=True)
print(f"metrics keys: {metrics.keys()}")

# Add assertion to check if 'retrieved_memory' is in metrics
assert 'retrieved_memory' in metrics, "retrieved_memory not found in metrics"

# Additional assertions can be added here to verify the correctness of retrieved_memory
retrieved_memory = metrics['retrieved_memory']
assert retrieved_memory.shape == (sample_input['attention'].shape[0], model.hidden_dim), (
f"retrieved_memory has shape {retrieved_memory.shape}, expected ({sample_input['attention'].shape[0]}, {model.hidden_dim})"
)

if __name__ == '__main__':
pytest.main([__file__])
Loading

0 comments on commit ac5d8d2

Please sign in to comment.