Skip to content

Commit

Permalink
fix: some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Oct 22, 2024
1 parent a56a671 commit 123f053
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 311 deletions.
41 changes: 0 additions & 41 deletions src/qiboml/models/#ansatze.py#

This file was deleted.

24 changes: 0 additions & 24 deletions src/qiboml/models/_pytorch.py

This file was deleted.

22 changes: 0 additions & 22 deletions src/qiboml/models/_utils.py

This file was deleted.

63 changes: 0 additions & 63 deletions src/qiboml/models/abstract.py

This file was deleted.

6 changes: 5 additions & 1 deletion src/qiboml/models/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def circuit(
def set_backend(self, backend):
self.backend = backend

@property
def output_shape(self):
raise_error(NotImplementedError)


@dataclass
class Probabilities(QuantumDecoding):
Expand Down Expand Up @@ -102,7 +106,7 @@ def output_shape(self):
class Samples(QuantumDecoding):

def forward(self, x: Circuit) -> ndarray:
return self.backend.cast(super().__call__(x).samples(), dtype=np.float64)
return self.backend.cast(super().__call__(x).samples(), self.backend.precision)

@property
def output_shape(self):
Expand Down
158 changes: 0 additions & 158 deletions src/qiboml/models/encoding_decoding.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/qiboml/operations/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def evaluate(self, x: ndarray, encoding, training, decoding, backend, *parameter
if binary:
gradients = (
self._jax.numpy.zeros((decoding.output_shape[-1], x.shape[-1])),
self._jacobian_without_inputs(*parameters),
self._jacobian_without_inputs(*parameters), # pylint: disable=no-member
)
else:
gradients = self._jacobian(x, *parameters)
gradients = self._jacobian(x, *parameters) # pylint: disable=no-member
decoding.set_backend(backend)
return [
backend.cast(self._jax.to_numpy(grad).tolist(), backend.precision)
Expand Down

0 comments on commit 123f053

Please sign in to comment.