Skip to content

Commit

Permalink
Enhance GlobalWorkspace to support dtype and device in zero padding; …
Browse files Browse the repository at this point in the history
…add setup fixture for default tensor settings
  • Loading branch information
kasinadhsarma committed Dec 27, 2024
1 parent e9e9d99 commit d65db2b
Show file tree
Hide file tree
Showing 52 changed files with 17 additions and 1 deletion.
Binary file modified __pycache__/test_init.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file modified core/__pycache__/config.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/abstract_reasoning.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/attention.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/base_model.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/consciousness.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/consciousness_model.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/consciousness_state.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/dynamic_attention.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/error_handling.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/ethical_safety.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/global_workspace.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/information_integration.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/intentionality.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/long_term_memory.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/memory.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/reasoning.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/self_awareness.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/simulated_emotions.cpython-310.pyc
Binary file not shown.
Binary file modified models/__pycache__/working_memory.cpython-310.pyc
Binary file not shown.
8 changes: 7 additions & 1 deletion models/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def forward(self, modalities: Union[Dict[str, torch.Tensor], None] = None, senso

# Pad remaining slots with zeros if needed
while len(integrated_features) < self.num_modalities:
zero_features = torch.zeros_like(integrated_features[0])
zero_features = torch.zeros_like(
integrated_features[0],
dtype=integrated_features[0].dtype,
device=integrated_features[0].device
)
integrated_features.append(zero_features)

# Stack and reshape for attention
Expand Down Expand Up @@ -121,6 +125,7 @@ def forward(self, modalities: Union[Dict[str, torch.Tensor], None] = None, senso
self.num_modalities - attention_weights.size(1),
seq_len,
attention_weights.size(3),
dtype=attention_weights.dtype,
device=attention_weights.device
)
attention_weights = torch.cat([attention_weights, padding], dim=1)
Expand All @@ -132,6 +137,7 @@ def forward(self, modalities: Union[Dict[str, torch.Tensor], None] = None, senso
attention_weights.size(1),
3 - attention_weights.size(2),
attention_weights.size(3),
dtype=attention_weights.dtype,
device=attention_weights.device
)
attention_weights = torch.cat([attention_weights, padding], dim=2)
Expand Down
Binary file modified tests/__pycache__/conftest.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/__pycache__/test_environment.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/__pycache__/test_lint.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file not shown.
Binary file modified tests/__pycache__/test_reasoning.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,13 @@ def jit_compile():
def decorator(func):
return torch.jit.script(func)
return decorator

@pytest.fixture(autouse=True)
def setup_torch_defaults():
# Use recommended alternatives to set_default_tensor_type
torch.set_default_dtype(torch.float32)
# Only set default device if CUDA is available
if torch.cuda.is_available():
torch.set_default_device('cuda')
else:
torch.set_default_device('cpu')
Binary file modified tests/unit/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file modified tests/unit/__pycache__/test_base.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit d65db2b

Please sign in to comment.