Add EthicalSafety model and corresponding tests for safety and ethics evaluations
Commit e9e9d99 (1 parent: 06f6456)
Showing 6 changed files with 164 additions and 12 deletions.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from typing import Dict, Tuple, List


class EthicalSafety(nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Ethical constraint encoder
        self.constraint_encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Safety verification layers
        self.safety_check = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Ethical decision scorer
        self.ethical_scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Define basic ethical constraints
        self.ethical_constraints = [
            "do_no_harm",
            "respect_autonomy",
            "protect_privacy",
            "ensure_fairness",
            "maintain_transparency"
        ]

    def check_safety(self, state: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Verify if the current state meets safety requirements"""
        safety_score = self.safety_check(state)
        is_safe = safety_score > 0.5

        return is_safe, {
            'safety_score': safety_score,
            'safety_threshold': 0.5
        }

    def evaluate_ethics(self, action: torch.Tensor, context: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Evaluate ethical implications of an action"""
        combined = torch.cat([action, context], dim=-1)
        ethics_score = self.ethical_scorer(combined)

        return ethics_score > 0.7, {
            'ethics_score': ethics_score,
            'ethics_threshold': 0.7
        }

    def forward(self, state: torch.Tensor, action: torch.Tensor, context: torch.Tensor) -> Dict:
        """
        Perform ethical and safety evaluation
        Returns dict with safety checks and ethical assessments
        """
        # Encode current state against ethical constraints
        encoded_state = self.constraint_encoder(state)

        # Perform safety checks
        is_safe, safety_metrics = self.check_safety(encoded_state)

        # Evaluate ethical implications
        is_ethical, ethics_metrics = self.evaluate_ethics(action, context)

        return {
            'is_safe': is_safe,
            'is_ethical': is_ethical,
            'safety_metrics': safety_metrics,
            'ethics_metrics': ethics_metrics,
            'constraints_satisfied': torch.all(is_safe & is_ethical)
        }

    def mitigate_risks(self, action: torch.Tensor, safety_metrics: Dict) -> torch.Tensor:
        """Apply safety constraints to modify risky actions"""
        # 'is_safe' may be a plain bool or a per-sample tensor of shape (batch, 1)
        is_safe = safety_metrics.get('is_safe', True)
        if isinstance(is_safe, bool):
            is_safe_tensor = torch.full((action.size(0),), is_safe, dtype=torch.bool, device=action.device)
        else:
            is_safe_tensor = is_safe.squeeze(-1)

        # Select the batch elements that failed the safety check
        unsafe_mask = ~is_safe_tensor

        # Scale unsafe actions by their safety score; safe actions are left unchanged
        scaled_action = action.clone()
        safety_score = safety_metrics.get('safety_score', torch.ones_like(action))
        scaled_action[unsafe_mask] *= safety_score[unsafe_mask]
        return scaled_action
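For orientation, here is a minimal usage sketch (not part of this commit) showing how the module above could be driven end to end. The merging of 'is_safe' into the metrics dict passed to mitigate_risks() is an assumed calling convention, since forward() returns it under a separate key; the shapes mirror the tests below.

import torch
from models.ethical_safety import EthicalSafety

model = EthicalSafety(hidden_dim=64)  # hidden size chosen to match the tests below
state = torch.randn(2, 64)
action = torch.randn(2, 64)
context = torch.randn(2, 64)

results = model(state, action, context)

# forward() returns 'is_safe' at the top level, while mitigate_risks() reads it
# from the metrics dict, so the two are merged here (assumed convention).
metrics = {**results['safety_metrics'], 'is_safe': results['is_safe']}
safe_action = model.mitigate_risks(action, metrics)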
Binary file not shown.
@@ -0,0 +1,36 @@
import torch
import pytest
from models.ethical_safety import EthicalSafety


def test_safety_check():
    ethical_safety = EthicalSafety(hidden_dim=64)
    state = torch.randn(2, 64)

    is_safe, metrics = ethical_safety.check_safety(state)

    assert isinstance(is_safe, torch.Tensor)
    assert 'safety_score' in metrics
    assert metrics['safety_score'].shape == (2, 1)


def test_ethical_evaluation():
    ethical_safety = EthicalSafety(hidden_dim=64)
    action = torch.randn(2, 64)
    context = torch.randn(2, 64)

    is_ethical, metrics = ethical_safety.evaluate_ethics(action, context)

    assert isinstance(is_ethical, torch.Tensor)
    assert 'ethics_score' in metrics
    assert metrics['ethics_score'].shape == (2, 1)


def test_risk_mitigation():
    ethical_safety = EthicalSafety(hidden_dim=64)
    action = torch.ones(2, 64)

    safety_metrics = {
        'is_safe': False,
        'safety_score': torch.tensor([[0.3], [0.6]])
    }

    mitigated_action = ethical_safety.mitigate_risks(action, safety_metrics)
    assert torch.all(mitigated_action < action)
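The committed tests exercise the individual helpers but not forward() itself. A hypothetical additional test (not in this commit, reusing the imports at the top of this file) could check the combined output:

def test_forward_evaluation():
    ethical_safety = EthicalSafety(hidden_dim=64)
    state = torch.randn(2, 64)
    action = torch.randn(2, 64)
    context = torch.randn(2, 64)

    results = ethical_safety(state, action, context)

    # Per-sample boolean decisions plus a single aggregate flag
    assert results['is_safe'].shape == (2, 1)
    assert results['is_ethical'].shape == (2, 1)
    assert results['constraints_satisfied'].dtype == torch.bool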