diff --git a/tests/test_models_interfaces.py b/tests/test_models_interfaces.py index aaa40ab..54e2df2 100644 --- a/tests/test_models_interfaces.py +++ b/tests/test_models_interfaces.py @@ -4,13 +4,14 @@ import numpy as np import pytest import torch -from qibo import construct_backend, hamiltonians +from qibo import hamiltonians from qibo.config import raise_error from qibo.symbols import Z import qiboml.models.ansatze as ans import qiboml.models.decoding as dec import qiboml.models.encoding as enc +from qiboml.operations.differentiation import PSR torch.set_default_dtype(torch.float64) @@ -228,6 +229,10 @@ def test_decoding(backend, frontend, layer, seed, analytic): pytest.skip("Non pytorch/jax differentiation is not working yet.") if analytic and not layer is dec.Expectation: pytest.skip("Unused analytic argument.") + if not analytic and not layer is dec.Expectation: + pytest.skip( + "Expectation layer is the only differentiable decoding when the diffrule is not analytical." + ) set_seed(frontend, seed) @@ -251,13 +256,19 @@ def test_decoding(backend, frontend, layer, seed, analytic): ) kwargs["observable"] = observable kwargs["analytic"] = analytic + if not analytic: + differentiation_rule = PSR() + else: + differentiation_rule = None kwargs["nshots"] = None decoding_layer = layer(nqubits, decoding_qubits, **kwargs) # if not decoding_layer.analytic: # pytest.skip("PSR differentiation is not working yet.") - q_model = frontend.QuantumModel(encoding_layer, training_layer, decoding_layer) + q_model = frontend.QuantumModel( + encoding_layer, training_layer, decoding_layer, differentiation_rule + ) data = random_tensor(frontend, (100, dim)) target = prepare_targets(frontend, q_model, data)