fix: fixed grad shape in QuantumModelAutograd.backward
BrunoLiegiBastonLiegi committed Oct 22, 2024
1 parent 123f053 commit e4e62c1
Showing 4 changed files with 11 additions and 99 deletions.
4 changes: 4 additions & 0 deletions src/qiboml/models/decoding.py
```diff
@@ -105,6 +105,10 @@ def output_shape(self):
 @dataclass
 class Samples(QuantumDecoding):
 
+    def __post_init__(self):
+        super().__post_init__()
+        self.analytic = False
+
     def forward(self, x: Circuit) -> ndarray:
         return self.backend.cast(super().__call__(x).samples(), self.backend.precision)
```
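The added `__post_init__` flags the `Samples` decoder as non-analytic, so downstream differentiation code can tell shot-based outputs apart from exact expectation values. For context, a minimal, self-contained sketch of the dataclass `__post_init__` pattern used above — the `QuantumDecoding` and `Samples` stand-ins below are illustrative, not the qiboml classes:

```python
from dataclasses import dataclass, field

@dataclass
class QuantumDecoding:
    nqubits: int
    # parent defaults to analytic (exact expectation values)
    analytic: bool = field(default=True, init=False)

    def __post_init__(self):
        # parent setup would go here (backend selection, etc.)
        pass

@dataclass
class Samples(QuantumDecoding):
    def __post_init__(self):
        super().__post_init__()  # keep the parent's initialization
        # samples are shot results, so analytic gradients don't apply
        self.analytic = False

assert QuantumDecoding(3).analytic is True
assert Samples(3).analytic is False
```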
2 changes: 1 addition & 1 deletion src/qiboml/models/pytorch.py
```diff
@@ -139,5 +139,5 @@ def backward(ctx, grad_output: torch.Tensor):
             None,
             None,
             None,
-            *(torch.vstack(gradients) @ grad_output),
+            *(torch.vstack(gradients).view((-1,) + grad_output.shape) @ grad_output.T),
         )
```
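Why the reshape: `backward` receives one gradient tensor per circuit parameter, each shaped like the model output, and must return a vector-Jacobian product per parameter. Stacking the per-parameter gradients and viewing them as `(n_params, *grad_output.shape)` makes the contraction with `grad_output` dimensionally consistent. Below is a toy `torch.autograd.Function` sketching that bookkeeping; it is a hypothetical stand-in for `QuantumModelAutograd`, and it contracts with an elementwise multiply-and-sum rather than the matmul used in the fix:

```python
import torch

class ToyAutograd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params):
        # toy "circuit": output_i = (i + 1) * sum_j sin(params_j)
        ctx.save_for_backward(params)
        return torch.arange(1.0, 4.0) * torch.sin(params).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (params,) = ctx.saved_tensors
        # one gradient tensor per parameter, each shaped like the output
        gradients = [torch.cos(p) * torch.arange(1.0, 4.0) for p in params]
        # stack and view as (n_params, *output_shape) -- the reshape the fix adds
        jac = torch.vstack(gradients).view((-1,) + grad_output.shape)
        # vector-Jacobian product: contract over all output dimensions
        return (jac * grad_output).sum(dim=tuple(range(1, jac.ndim)))

params = torch.tensor([0.1, 0.2], requires_grad=True)
ToyAutograd.apply(params).sum().backward()
assert torch.allclose(params.grad, 6 * torch.cos(params.detach()))
```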
93 changes: 0 additions & 93 deletions tests/test_backprop.py

This file was deleted.

11 changes: 6 additions & 5 deletions tests/test_models_interfaces.py
```diff
@@ -127,7 +127,7 @@ def random_parameters(frontend, model):
     if frontend.__name__ == "qiboml.models.pytorch":
         new_params = {}
         for k, v in model.state_dict().items():
-            new_params.update({k: v + frontend.torch.randn(v.shape)})
+            new_params.update({k: v + frontend.torch.randn(v.shape) / 10})
     elif frontend.__name__ == "qiboml.models.keras":
         new_params = [frontend.tf.random.uniform(model.get_weights()[0].shape)]
     return new_params
```
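Dividing by 10 shrinks the perturbation's standard deviation from 1.0 to 0.1, keeping the randomized parameters close to the current ones so the subsequent backprop test starts from a reasonable point. A standalone illustration of the scale:

```python
import torch

torch.manual_seed(0)
v = torch.zeros(10_000)
print((v + torch.randn(v.shape)).std())       # ~1.0: large jump away from v
print((v + torch.randn(v.shape) / 10).std())  # ~0.1: small, local perturbation
```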
```diff
@@ -212,8 +212,8 @@ def test_encoding(backend, frontend, layer):
 def test_decoding(backend, frontend, layer, analytic):
     if frontend.__name__ == "qiboml.models.keras":
         pytest.skip("keras interface not ready.")
-    if backend.name != "pytorch":
-        pytest.skip("Non pytorch differentiatio is not working yet.")
+    if backend.name not in ("pytorch", "jax"):
+        pytest.skip("Non pytorch/jax differentiation is not working yet.")
     if analytic and not layer is dec.Expectation:
         pytest.skip("Unused analytic argument.")
     nqubits = 3
```
```diff
@@ -238,13 +238,15 @@ def test_decoding(backend, frontend, layer, analytic):
         kwargs["analytic"] = analytic
     decoding_layer = layer(nqubits, decoding_qubits, **kwargs)
 
+    if not decoding_layer.analytic:
+        pytest.skip("PSR differentiation is not working yet.")
+
     q_model = frontend.QuantumModel(
         encoding_layer, training_layer, decoding_layer, differentiation="Jax"
     )
 
     data = random_tensor(frontend, (100, dim))
     target = prepare_targets(frontend, q_model, data)
-    print("> Training the pure quantum model...")
     backprop_test(frontend, q_model, data, target)
 
     model = build_sequential_model(
```
```diff
@@ … @@
     data = random_tensor(frontend, (100, 32))
     target = prepare_targets(frontend, model, data)
-    print("> Training the hybrid classical-quantum model...")
     backprop_test(frontend, model, data, target)
```
