Skip to content

Commit

Permalink
Add BasicReasoning module with logical, pattern, and causal reasoning…
Browse files Browse the repository at this point in the history
… components; implement confidence estimation and reasoning score calculation; add unit tests for reasoning metrics
  • Loading branch information
kasinadhsarma committed Dec 26, 2024
1 parent 7c08a63 commit 91a1f36
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 1 deletion.
Binary file added core/__pycache__/config.cpython-310.pyc
Binary file not shown.
Binary file added models/__pycache__/base_model.cpython-310.pyc
Binary file not shown.
Binary file added models/__pycache__/reasoning.cpython-310.pyc
Binary file not shown.
88 changes: 87 additions & 1 deletion models/reasoning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict

from models.base_model import BaseModel
from core.config import ModelConfig
Expand Down Expand Up @@ -102,3 +102,89 @@ def forward(
attention_mask=attention_mask,
context=context
)

class BasicReasoning(nn.Module):
def __init__(self, hidden_dim: int, num_heads: int = 4):
super().__init__()
self.hidden_dim = hidden_dim

# Logical reasoning components
self.logical_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 2),
nn.LayerNorm(hidden_dim * 2),
nn.ReLU(),
nn.Linear(hidden_dim * 2, hidden_dim)
)

# Pattern recognition
self.pattern_recognition = nn.MultiheadAttention(
hidden_dim, num_heads=num_heads, batch_first=True
)

# Causal inference
self.causal_inference = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU()
)

# Reasoning confidence estimation
self.confidence_estimator = nn.Sequential(
nn.Linear(hidden_dim * 3, 1),
nn.Sigmoid()
)

def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
# Logical reasoning
logical_out = self.logical_layer(x)

# Pattern recognition through attention
pattern_out, pattern_weights = self.pattern_recognition(x, x, x)

# Causal inference by combining logical and pattern outputs
causal_input = torch.cat([logical_out, pattern_out], dim=-1)
causal_out = self.causal_inference(causal_input)

# Calculate reasoning confidence
confidence_input = torch.cat([logical_out, pattern_out, causal_out], dim=-1)
confidence = self.confidence_estimator(confidence_input)

# Calculate normalized reasoning scores
with torch.no_grad():
# Normalize using softmax for pattern weights
pattern_weights_norm = torch.softmax(pattern_weights, dim=-1)
pattern_score = torch.mean(pattern_weights_norm)

# Normalize cosine similarities to [0,1] range
logical_sim = torch.cosine_similarity(logical_out, x, dim=-1)
logical_score = torch.clamp((logical_sim + 1) / 2, 0, 1).mean()

causal_sim = torch.cosine_similarity(causal_out, x, dim=-1)
causal_score = torch.clamp((causal_sim + 1) / 2, 0, 1).mean()

return {
'output': causal_out,
'confidence': confidence,
'metrics': {
'logical_score': logical_score.item(),
'pattern_score': pattern_score.item(),
'causal_score': causal_score.item(),
'reasoning_weights': pattern_weights
}
}

def calculate_reasoning_score(self, metrics: Dict[str, float]) -> float:
"""Calculate overall reasoning capability score"""
weights = {
'logical': 0.4,
'pattern': 0.3,
'causal': 0.3
}

score = (
weights['logical'] * metrics['logical_score'] +
weights['pattern'] * metrics['pattern_score'] +
weights['causal'] * metrics['causal_score']
)

return score * 100 # Convert to percentage
Binary file not shown.
34 changes: 34 additions & 0 deletions tests/test_reasoning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
import torch
from models.reasoning import BasicReasoning

@pytest.fixture
def model():
return BasicReasoning(hidden_dim=128)

@pytest.fixture
def sample_input():
return torch.randn(2, 5, 128) # [batch_size, seq_len, hidden_dim]

class TestBasicReasoning:
def test_reasoning_scores(self, model, sample_input):
output = model(sample_input)

# Check all components exist
assert 'logical_score' in output['metrics']
assert 'pattern_score' in output['metrics']
assert 'causal_score' in output['metrics']

# Verify score ranges
for key in ['logical_score', 'pattern_score', 'causal_score']:
assert 0 <= output['metrics'][key] <= 1

# Calculate overall score
score = model.calculate_reasoning_score(output['metrics'])
assert 0 <= score <= 100

def test_confidence_estimation(self, model, sample_input):
output = model(sample_input)
assert 'confidence' in output
assert output['confidence'].shape == (2, 5, 1) # [batch_size, seq_len, 1]
assert torch.all(output['confidence'] >= 0) and torch.all(output['confidence'] <= 1)

0 comments on commit 91a1f36

Please sign in to comment.