-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove deprecated error handling module and add new error correction …
…functionality; update tests for consciousness model
- Loading branch information
1 parent
09e7273
commit 7f401b9
Showing
12 changed files
with
680 additions
and
295 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import torch | ||
import torch.nn as nn | ||
import logging | ||
from typing import Dict, Any, Tuple, Optional | ||
|
||
class ErrorHandler: | ||
""" | ||
Handles errors and implements correction mechanisms for the consciousness model. | ||
""" | ||
def __init__(self, logger=None): | ||
self.logger = logger or logging.getLogger(__name__) | ||
self.error_history = [] | ||
self.correction_history = [] | ||
self.max_history = 1000 | ||
|
||
def log_error(self, error_type: str, details: str, metrics: Dict[str, Any]) -> None: | ||
"""Log an error with relevant metrics""" | ||
error_entry = { | ||
'type': error_type, | ||
'details': details, | ||
'metrics': metrics | ||
} | ||
self.error_history.append(error_entry) | ||
if len(self.error_history) > self.max_history: | ||
self.error_history.pop(0) | ||
self.logger.error(f"Error detected: {error_type} - {details}") | ||
|
||
def analyze_errors(self) -> Dict[str, float]: | ||
"""Analyze error patterns""" | ||
if not self.error_history: | ||
return {} | ||
|
||
error_counts = {} | ||
for entry in self.error_history: | ||
error_type = entry['type'] | ||
error_counts[error_type] = error_counts.get(error_type, 0) + 1 | ||
|
||
total_errors = len(self.error_history) | ||
return {k: v/total_errors for k, v in error_counts.items()} | ||
|
||
class ErrorCorrection(nn.Module): | ||
""" | ||
Neural network component for error correction in consciousness model. | ||
""" | ||
def __init__(self, hidden_dim: int): | ||
super().__init__() | ||
self.hidden_dim = hidden_dim | ||
|
||
# Error detection network | ||
self.error_detector = nn.Sequential( | ||
nn.Linear(hidden_dim, hidden_dim // 2), | ||
nn.ReLU(), | ||
nn.Linear(hidden_dim // 2, 1), | ||
nn.Sigmoid() | ||
) | ||
|
||
# Enhanced Error correction network | ||
self.correction_net = nn.Sequential( | ||
nn.Linear(hidden_dim, hidden_dim), | ||
nn.ReLU(), | ||
nn.LayerNorm(hidden_dim), | ||
nn.Linear(hidden_dim, hidden_dim), # Added layer for better correction | ||
nn.Tanh() # Changed activation for bounded output | ||
) | ||
|
||
def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, float]: | ||
""" | ||
Detect and correct errors in the state. | ||
Returns: (corrected_state, error_probability) | ||
""" | ||
# Handle NaN values first | ||
nan_mask = torch.isnan(state) | ||
if nan_mask.any(): | ||
# Replace NaN values with zeros initially | ||
working_state = torch.where(nan_mask, torch.zeros_like(state), state) | ||
error_prob = 1.0 # High error probability for NaN values | ||
else: | ||
working_state = state | ||
# Calculate error probability for non-NaN state | ||
with torch.no_grad(): | ||
error_prob = self.error_detector(state).mean().item() | ||
|
||
# Apply enhanced correction network | ||
corrected_state = self.correction_net(working_state) | ||
|
||
# If there were NaN values, apply additional correction | ||
if nan_mask.any(): | ||
# For positions that had NaN, use neighbor averaging if available | ||
batch_size = corrected_state.size(0) | ||
for b in range(batch_size): | ||
nan_indices = torch.where(nan_mask[b])[0] | ||
if len(nan_indices) > 0: | ||
# Get valid neighbor values | ||
valid_values = corrected_state[b][~nan_mask[b]] | ||
if len(valid_values) > 0: | ||
# Use mean of valid values to fill NaN positions | ||
corrected_state[b][nan_indices] = valid_values.mean() | ||
else: | ||
# If no valid values, initialize with small random values | ||
corrected_state[b][nan_indices] = torch.randn(len(nan_indices), device=state.device) * 0.1 | ||
|
||
# Ensure values are bounded | ||
corrected_state = torch.clamp(corrected_state, -1.0, 1.0) | ||
|
||
# Final normalization | ||
corrected_state = nn.functional.normalize(corrected_state, dim=-1) | ||
|
||
# Ensure no NaN values remain | ||
if torch.isnan(corrected_state).any(): | ||
corrected_state = torch.where( | ||
torch.isnan(corrected_state), | ||
torch.zeros_like(corrected_state), | ||
corrected_state | ||
) | ||
error_prob = 1.0 | ||
|
||
return corrected_state, error_prob | ||
|
||
def validate_state(state: torch.Tensor, expected_shape: Tuple[int, ...]) -> Optional[str]: | ||
"""Validate state tensor""" | ||
if not isinstance(state, torch.Tensor): | ||
return "State must be a tensor" | ||
if state.shape != expected_shape: | ||
return f"Invalid state shape: expected {expected_shape}, got {state.shape}" | ||
if torch.isnan(state).any(): | ||
return "State contains NaN values" | ||
if torch.isinf(state).any(): | ||
return "State contains infinite values" | ||
return None |
Binary file modified
BIN
+0 Bytes
(100%)
tests/__pycache__/test_consciousness.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file added
BIN
+1.7 KB
tests/__pycache__/test_consciousness_model.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import unittest | ||
import torch | ||
import torch.nn as nn | ||
from models.consciousness_model import ConsciousnessModel | ||
from models.error_handling import ErrorHandler, ErrorCorrection | ||
|
||
class TestErrorCorrection(unittest.TestCase): | ||
def setUp(self): | ||
self.hidden_dim = 64 | ||
self.error_correction = ErrorCorrection(hidden_dim=self.hidden_dim) | ||
self.error_handler = ErrorHandler(logger=None) | ||
self.model = ConsciousnessModel( | ||
hidden_dim=self.hidden_dim, | ||
num_heads=4, | ||
num_layers=2, | ||
num_states=3 | ||
) | ||
|
||
def test_error_correction_shape(self): | ||
"""Test if error correction maintains correct tensor shape""" | ||
batch_size = 8 | ||
input_state = torch.randn(batch_size, self.hidden_dim) | ||
corrected_state, error_prob = self.error_correction(input_state) | ||
|
||
self.assertEqual(corrected_state.shape, input_state.shape) | ||
self.assertTrue(isinstance(error_prob, float)) | ||
self.assertTrue(0 <= error_prob <= 1) | ||
|
||
def test_error_detection(self): | ||
"""Test if error detection works for invalid states""" | ||
# Test with valid state | ||
valid_state = torch.randn(4, self.hidden_dim) | ||
valid_state = torch.nn.functional.normalize(valid_state, dim=-1) | ||
_, valid_error = self.error_correction(valid_state) | ||
|
||
# Test with invalid state (NaN values) | ||
invalid_state = torch.full((4, self.hidden_dim), float('nan')) | ||
_, invalid_error = self.error_correction(invalid_state) | ||
|
||
self.assertLess(valid_error, invalid_error) | ||
|
||
def test_error_correction_recovery(self): | ||
"""Test if error correction can recover from corrupted states""" | ||
# Create original state | ||
original_state = torch.randn(4, self.hidden_dim) | ||
original_state = torch.nn.functional.normalize(original_state, dim=-1) | ||
|
||
# Create corrupted state with some NaN values | ||
corrupted_state = original_state.clone() | ||
corrupted_state[:, :10] = float('nan') | ||
|
||
# Apply error correction | ||
corrected_state, error_prob = self.error_correction(corrupted_state) | ||
|
||
# Check if NaN values were fixed | ||
self.assertFalse(torch.isnan(corrected_state).any()) | ||
self.assertTrue(error_prob > 0.5) # Should detect high error probability | ||
|
||
def test_error_handling_integration(self): | ||
"""Test integration of error correction with error handling""" | ||
batch_size = 4 | ||
seq_len = 3 | ||
|
||
# Create input with some invalid values | ||
inputs = { | ||
'visual': torch.randn(batch_size, seq_len, self.hidden_dim), | ||
'textual': torch.randn(batch_size, seq_len, self.hidden_dim) | ||
} | ||
inputs['visual'][0, 0] = float('nan') # Introduce error | ||
|
||
# Process through model | ||
try: | ||
state, metrics = self.model(inputs) | ||
self.assertTrue('error_prob' in metrics) | ||
self.assertFalse(torch.isnan(state).any()) | ||
except Exception as e: | ||
self.fail(f"Error correction should handle NaN values: {str(e)}") | ||
|
||
def test_error_correction_consistency(self): | ||
"""Test if error correction is consistent across multiple runs""" | ||
input_state = torch.randn(4, self.hidden_dim) | ||
|
||
# Run multiple corrections | ||
results = [] | ||
for _ in range(5): | ||
corrected, prob = self.error_correction(input_state) | ||
results.append((corrected.clone(), prob)) | ||
|
||
# Check consistency | ||
for i in range(1, len(results)): | ||
torch.testing.assert_close(results[0][0], results[i][0]) | ||
self.assertAlmostEqual(results[0][1], results[i][1]) | ||
|
||
def test_error_correction_gradients(self): | ||
"""Test if error correction maintains gradient flow""" | ||
input_state = torch.randn(4, self.hidden_dim, requires_grad=True) | ||
corrected_state, _ = self.error_correction(input_state) | ||
|
||
# Check if gradients can flow | ||
loss = corrected_state.sum() | ||
loss.backward() | ||
|
||
self.assertIsNotNone(input_state.grad) | ||
self.assertFalse(torch.isnan(input_state.grad).any()) | ||
|
||
def test_error_correction_bounds(self): | ||
"""Test if error correction maintains value bounds""" | ||
# Test with extreme values | ||
extreme_state = torch.randn(4, self.hidden_dim) * 1000 | ||
corrected_state, _ = self.error_correction(extreme_state) | ||
|
||
# Check if values are normalized | ||
self.assertTrue(torch.all(corrected_state <= 1)) | ||
self.assertTrue(torch.all(corrected_state >= -1)) | ||
|
||
def test_error_logging(self): | ||
"""Test if errors are properly logged""" | ||
# Create invalid state | ||
invalid_state = torch.full((4, self.hidden_dim), float('nan')) | ||
|
||
# Process with error handler | ||
self.error_handler.log_error( | ||
"state_error", | ||
"Invalid state detected", | ||
{"state": invalid_state} | ||
) | ||
|
||
# Check error history | ||
self.assertTrue(len(self.error_handler.error_history) > 0) | ||
latest_error = self.error_handler.error_history[-1] | ||
self.assertEqual(latest_error['type'], "state_error") | ||
|
||
def test_error_correction_with_noise(self): | ||
"""Test error correction with different noise levels""" | ||
base_state = torch.randn(4, self.hidden_dim) | ||
noise_levels = [0.1, 0.5, 1.0] | ||
|
||
for noise in noise_levels: | ||
noisy_state = base_state + torch.randn_like(base_state) * noise | ||
corrected_state, error_prob = self.error_correction(noisy_state) | ||
|
||
# Higher noise should lead to higher error probability | ||
self.assertTrue( | ||
error_prob >= noise * 0.1, | ||
f"Error probability too low for noise level {noise}" | ||
) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Empty file.