From 45498bc2ab8045c25a345fc465cd83d4c2ae6551 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:45:06 +0200 Subject: [PATCH 1/9] add `equal` to tensor --- candle-pyo3/py_src/candle/__init__.pyi | 5 +++++ candle-pyo3/src/lib.rs | 21 +++++++++++++++++++++ candle-pyo3/tests/native/test_tensor.py | 12 ++++++++++++ 3 files changed, 38 insertions(+) diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index 414f0bc44c..7c7d240747 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -191,6 +191,11 @@ class Tensor: Gets the tensor's dtype. """ pass + def equal(self, rhs: Tensor) -> bool: + """ + True if two tensors have the same size and elements, False otherwise. + """ + pass def exp(self) -> Tensor: """ Performs the `exp` operation on the tensor. diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 02db05e568..37b4c31698 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -592,6 +592,27 @@ impl PyTensor { Ok(Self(tensor)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// True if two tensors have the same size and elements, False otherwise. + /// &RETURNS&: bool + fn equal(&self, rhs: &Self) -> PyResult { + if self.0.shape() != rhs.0.shape() { + return Ok(false); + } + let result = self + .0 + .eq(&rhs.0) + .map_err(wrap_err)? + .to_dtype(DType::I64) + .map_err(wrap_err)?; + Ok(result + .sum_all() + .map_err(wrap_err)? + .to_scalar::() + .map_err(wrap_err)? + == self.0.elem_count() as i64) + } + #[pyo3(text_signature = "(self, shape:Sequence[int])")] /// Reshapes the tensor to the given shape. /// &RETURNS&: Tensor diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 1f5b74f677..309f6027fc 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -72,3 +72,15 @@ def test_tensor_can_be_scliced_3d(): assert t[:, 0, 0].values() == [1, 9] assert t[..., 0].values() == [[1, 5], [9, 13]] assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] + + +def test_tensors_can_be_compared_with_equal(): + t = Tensor(42.0) + other = Tensor(42.0) + assert t.equal(other) + t = Tensor([42.0, 42.1]) + other = Tensor([42.0, 42.0]) + assert not t.equal(other) + t = Tensor(42.0) + other = Tensor([42.0, 42.0]) + assert not t.equal(other) From 08a7fe5123ac9f89ef1bdffd8460d8c880efc156 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:58:45 +0200 Subject: [PATCH 2/9] add `__richcmp__` support for tensors and scalars --- candle-pyo3/src/lib.rs | 55 +++++++++++++++++++++ candle-pyo3/src/utils.rs | 40 +++++++++++++++ candle-pyo3/tests/native/test_tensor.py | 65 +++++++++++++++++++++++++ 3 files changed, 160 insertions(+) create mode 100644 candle-pyo3/src/utils.rs diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 37b4c31698..d69d752d9c 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,8 +1,11 @@ #![allow(clippy::redundant_closure_call)] use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::pyclass::CompareOp; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::os::raw::c_long; use std::sync::Arc; @@ -10,6 +13,10 @@ use half::{bf16, f16}; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +mod utils; + +use utils::broadcast_shapes; + pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) } @@ -591,6 +598,54 @@ impl PyTensor { }; Ok(Self(tensor)) } + /// Rich-compare two tensors. + /// &RETURNS&: Tensor + fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult { + let compare = |lhs: &Tensor, rhs: &Tensor| { + let t = match op { + CompareOp::Eq => lhs.eq(rhs), + CompareOp::Ne => lhs.ne(rhs), + CompareOp::Lt => lhs.lt(rhs), + CompareOp::Le => lhs.le(rhs), + CompareOp::Gt => lhs.gt(rhs), + CompareOp::Ge => lhs.ge(rhs), + }; + Ok(PyTensor(t.map_err(wrap_err)?)) + }; + if let Ok(rhs) = rhs.extract::() { + if self.0.shape() == rhs.0.shape() { + compare(&self.0, &rhs.0) + } else { + // We broadcast manually here because as candle.cmp does not support broadcasting + let broadcast_shape = broadcast_shapes(&self.0, &rhs.0).map_err(wrap_err)?; + let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; + let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; + + compare(&broadcasted_lhs, &broadcasted_rhs) + } + } else if let Ok(rhs) = rhs.extract::() { + let scalar_tensor = Tensor::new(rhs, self.0.device()) + .map_err(wrap_err)? + .to_dtype(self.0.dtype()) + .map_err(wrap_err)? + .broadcast_as(self.0.shape()) + .map_err(wrap_err)?; + + compare(&self.0, &scalar_tensor) + } else { + return Err(PyTypeError::new_err("unsupported rhs for __richcmp__")); + } + } + + fn __hash__(&self) -> u64 { + // we havbe overridden __richcmp__ => py03 wants us to also override __hash__ + // we simply hash the address of the tensor + let mut hasher = DefaultHasher::new(); + let pointer = &self.0 as *const Tensor; + let address = pointer as usize; + address.hash(&mut hasher); + hasher.finish() + } #[pyo3(text_signature = "(self, rhs:Tensor)")] /// True if two tensors have the same size and elements, False otherwise. diff --git a/candle-pyo3/src/utils.rs b/candle-pyo3/src/utils.rs new file mode 100644 index 0000000000..5fd1df3137 --- /dev/null +++ b/candle-pyo3/src/utils.rs @@ -0,0 +1,40 @@ +use ::candle::{Error as CandleError, Result as CandleResult}; +use candle::Shape; + +/// Tries to broadcast the `rhs` shape into the `lhs` shape. +pub fn broadcast_shapes(lhs: &::candle::Tensor, rhs: &::candle::Tensor) -> CandleResult { + // see `Shape.broadcast_shape_binary_op` + let lhs_dims = lhs.dims(); + let rhs_dims = rhs.dims(); + let lhs_ndims = lhs_dims.len(); + let rhs_ndims = rhs_dims.len(); + let bcast_ndims = usize::max(lhs_ndims, rhs_ndims); + let mut bcast_dims = vec![0; bcast_ndims]; + for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() { + let rev_idx = bcast_ndims - idx; + let l_value = if lhs_ndims < rev_idx { + 1 + } else { + lhs_dims[lhs_ndims - rev_idx] + }; + let r_value = if rhs_ndims < rev_idx { + 1 + } else { + rhs_dims[rhs_ndims - rev_idx] + }; + *bcast_value = if l_value == r_value { + l_value + } else if l_value == 1 { + r_value + } else if r_value == 1 { + l_value + } else { + return Err(CandleError::BroadcastIncompatibleShapes { + src_shape: lhs.shape().clone(), + dst_shape: rhs.shape().clone(), + } + .bt()); + } + } + Ok(Shape::from(bcast_dims)) +} diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 309f6027fc..ef6e87c6c5 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -84,3 +84,68 @@ def test_tensors_can_be_compared_with_equal(): t = Tensor(42.0) other = Tensor([42.0, 42.0]) assert not t.equal(other) + + +def test_tensor_supports_equality_opperations_with_scalars(): + t = Tensor(42.0) + assert (t == 42.0).equal(Tensor(1).to_dtype(candle.u8)) + assert (t == 43.0).equal(Tensor(0).to_dtype(candle.u8)) + + assert (t != 42.0).equal(Tensor(0).to_dtype(candle.u8)) + assert (t != 43.0).equal(Tensor(1).to_dtype(candle.u8)) + + assert (t > 41.0).equal(Tensor(1).to_dtype(candle.u8)) + assert (t > 42.0).equal(Tensor(0).to_dtype(candle.u8)) + + assert (t >= 42.0).equal(Tensor(1).to_dtype(candle.u8)) + assert (t >= 43.0).equal(Tensor(0).to_dtype(candle.u8)) + + assert (t < 43.0).equal(Tensor(1).to_dtype(candle.u8)) + assert (t < 42.0).equal(Tensor(0).to_dtype(candle.u8)) + + assert (t <= 42.0).equal(Tensor(1).to_dtype(candle.u8)) + assert (t <= 41.0).equal(Tensor(0).to_dtype(candle.u8)) + + +def test_tensor_supports_equality_opperations_with_tensors(): + t = Tensor(42.0) + same = Tensor(42.0) + other = Tensor(43.0) + + assert (t == same).equal(Tensor(1).to_dtype(candle.u8)) + assert (t == other).equal(Tensor(0).to_dtype(candle.u8)) + + assert (t != same).equal(Tensor(0).to_dtype(candle.u8)) + assert (t != other).equal(Tensor(1).to_dtype(candle.u8)) + + assert (t > same).equal(Tensor(0).to_dtype(candle.u8)) + assert (t > other).equal(Tensor(0).to_dtype(candle.u8)) + + assert (t >= same).equal(Tensor(1).to_dtype(candle.u8)) + assert (t >= other).equal(Tensor(0).to_dtype(candle.u8)) + + assert (t < same).equal(Tensor(0).to_dtype(candle.u8)) + assert (t < other).equal(Tensor(1).to_dtype(candle.u8)) + + assert (t <= same).equal(Tensor(1).to_dtype(candle.u8)) + assert (t <= other).equal(Tensor(1).to_dtype(candle.u8)) + + +def test_tensor_equality_opperations_can_broadcast(): + # Create a decoder attention mask as a test case + # e.g. + # [[1,0,0] + # [1,1,0] + # [1,1,1]] + mask_cond = candle.Tensor([0, 1, 2]) + mask = mask_cond < (mask_cond + 1).reshape((3, 1)) + assert mask.shape == (3, 3) + assert mask.equal(Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8)) + + +def test_tensor_can_be_hashed(): + t = Tensor(42.0) + other = Tensor(42.0) + # Hash should represent the a unique tensor + assert hash(t) != hash(other) + assert hash(t) == hash(t) From b1d64d6a7371429ae371b21ce2b54d6d3d756870 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 15 Oct 2023 14:01:50 +0200 Subject: [PATCH 3/9] typo --- candle-pyo3/tests/native/test_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index ef6e87c6c5..ef7ead0803 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -146,6 +146,6 @@ def test_tensor_equality_opperations_can_broadcast(): def test_tensor_can_be_hashed(): t = Tensor(42.0) other = Tensor(42.0) - # Hash should represent the a unique tensor + # Hash should represent a unique tensor assert hash(t) != hash(other) assert hash(t) == hash(t) From 4fb19fa905a94414ffd5443f8289eb88df2c4c92 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Fri, 20 Oct 2023 11:15:10 +0200 Subject: [PATCH 4/9] more typos --- candle-pyo3/src/lib.rs | 4 ++-- candle-pyo3/tests/native/test_tensor.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 188b646924..55c3882a51 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -678,7 +678,7 @@ impl PyTensor { if self.0.shape() == rhs.0.shape() { compare(&self.0, &rhs.0) } else { - // We broadcast manually here because as candle.cmp does not support broadcasting + // We broadcast manually here because `candle.cmp` does not support automatic broadcasting let broadcast_shape = broadcast_shapes(&self.0, &rhs.0).map_err(wrap_err)?; let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; @@ -700,7 +700,7 @@ impl PyTensor { } fn __hash__(&self) -> u64 { - // we havbe overridden __richcmp__ => py03 wants us to also override __hash__ + // we have overridden __richcmp__ => py03 wants us to also override __hash__ // we simply hash the address of the tensor let mut hasher = DefaultHasher::new(); let pointer = &self.0 as *const Tensor; diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 2cef58f430..06f7e01ddf 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -152,6 +152,8 @@ def test_tensor_can_be_hashed(): # Hash should represent a unique tensor assert hash(t) != hash(other) assert hash(t) == hash(t) + + def test_tensor_can_be_expanded_with_none(): t = candle.rand((12, 12)) From dd06ff85a51d3466dfa09409bc9ccfb2d357e455 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Fri, 27 Oct 2023 21:32:16 +0200 Subject: [PATCH 5/9] Add `abs` + `candle.testing` --- candle-pyo3/_additional_typing/__init__.py | 36 ++++++++++ candle-pyo3/py_src/candle/__init__.pyi | 46 +++++++++++-- candle-pyo3/py_src/candle/testing/__init__.py | 66 +++++++++++++++++++ candle-pyo3/src/lib.rs | 38 +++++------ candle-pyo3/tests/bindings/test_testing.py | 33 ++++++++++ candle-pyo3/tests/native/test_tensor.py | 66 +++++++++---------- 6 files changed, 224 insertions(+), 61 deletions(-) create mode 100644 candle-pyo3/py_src/candle/testing/__init__.py create mode 100644 candle-pyo3/tests/bindings/test_testing.py diff --git a/candle-pyo3/_additional_typing/__init__.py b/candle-pyo3/_additional_typing/__init__.py index 0d0eec90eb..7bc17ee154 100644 --- a/candle-pyo3/_additional_typing/__init__.py +++ b/candle-pyo3/_additional_typing/__init__.py @@ -53,3 +53,39 @@ def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Te Return a slice of a tensor. """ pass + + def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index 45ebca847a..db07dfb315 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -124,16 +124,46 @@ class Tensor: Add a scalar to a tensor or two tensors together. """ pass + def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor": """ Return a slice of a tensor. """ pass + def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": """ Multiply a tensor by a scalar or one tensor by another. """ pass + def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": """ Add a scalar to a tensor or two tensors together. @@ -159,6 +189,11 @@ class Tensor: Divide a tensor by a scalar or one tensor by another. """ pass + def abs(self) -> Tensor: + """ + Performs the `abs` operation on the tensor. + """ + pass def argmax_keepdim(self, dim: int) -> Tensor: """ Returns the indices of the maximum value(s) across the selected dimension. @@ -231,11 +266,6 @@ class Tensor: Gets the tensor's dtype. """ pass - def equal(self, rhs: Tensor) -> bool: - """ - True if two tensors have the same size and elements, False otherwise. - """ - pass def exp(self) -> Tensor: """ Performs the `exp` operation on the tensor. @@ -313,6 +343,12 @@ class Tensor: ranges from `start` to `start + len`. """ pass + @property + def nelements(self) -> int: + """ + Gets the tensor's element count. + """ + pass def powf(self, p: float) -> Tensor: """ Performs the `pow` operation on the tensor with the given exponent. diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py new file mode 100644 index 0000000000..a11b56e252 --- /dev/null +++ b/candle-pyo3/py_src/candle/testing/__init__.py @@ -0,0 +1,66 @@ +import candle +from candle import Tensor + + +def _assert_tensor_metadata( + actual: Tensor, + expected: Tensor, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + if check_device: + assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}" + + if check_dtype: + assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}" + + if check_layout: + assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}" + + if check_stride: + assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}" + + +def assert_equal( + actual: Tensor, + expected: Tensor, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + """ + Asserts that two tensors are exact equals. + """ + _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) + assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}" + + +def assert_almost_equal( + actual: Tensor, + expected: Tensor, + rtol=1e-05, + atol=1e-08, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + """ + Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance. + + Computes: |actual - expected| ≤ atol + rtol x |expected| + """ + _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) + + # Secure against overflow of u32 and u8 tensors + diff = ( + (actual - expected).abs() + if actual.sum_all().values() > expected.sum_all().values() + else (expected - actual).abs() + ) + threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected) + + assert (diff <= threshold).sum_all().values() == actual.nelements, f"Difference between tensors was to great" diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 9a09142071..40c908806e 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -337,6 +337,13 @@ impl PyTensor { PyTuple::new(py, self.0.dims()).to_object(py) } + #[getter] + /// Gets the tensor's element count. + /// &RETURNS&: int + fn nelements(&self) -> usize { + self.0.elem_count() + } + #[getter] /// Gets the tensor's strides. /// &RETURNS&: Tuple[int] @@ -373,6 +380,16 @@ impl PyTensor { self.__repr__() } + /// Performs the `abs` operation on the tensor. + /// &RETURNS&: Tensor + fn abs(&self) -> PyResult { + match self.0.dtype() { + DType::U8 => Ok(PyTensor(self.0.clone())), + DType::U32 => Ok(PyTensor(self.0.clone())), + _ => Ok(PyTensor(self.0.abs().map_err(wrap_err)?)), + } + } + /// Performs the `sin` operation on the tensor. /// &RETURNS&: Tensor fn sin(&self) -> PyResult { @@ -739,27 +756,6 @@ impl PyTensor { hasher.finish() } - #[pyo3(text_signature = "(self, rhs:Tensor)")] - /// True if two tensors have the same size and elements, False otherwise. - /// &RETURNS&: bool - fn equal(&self, rhs: &Self) -> PyResult { - if self.0.shape() != rhs.0.shape() { - return Ok(false); - } - let result = self - .0 - .eq(&rhs.0) - .map_err(wrap_err)? - .to_dtype(DType::I64) - .map_err(wrap_err)?; - Ok(result - .sum_all() - .map_err(wrap_err)? - .to_scalar::() - .map_err(wrap_err)? - == self.0.elem_count() as i64) - } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] /// Reshapes the tensor to the given shape. /// &RETURNS&: Tensor diff --git a/candle-pyo3/tests/bindings/test_testing.py b/candle-pyo3/tests/bindings/test_testing.py new file mode 100644 index 0000000000..58a2ed1f51 --- /dev/null +++ b/candle-pyo3/tests/bindings/test_testing.py @@ -0,0 +1,33 @@ +import candle +from candle import Tensor +from candle.testing import assert_equal, assert_almost_equal +import pytest + + +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8]) +def test_assert_equal_asserts_correctly(dtype: candle.DType): + a = Tensor([1, 2, 3]).to(dtype) + b = Tensor([1, 2, 3]).to(dtype) + assert_equal(a, b) + + with pytest.raises(AssertionError): + assert_equal(a, b + 1) + + +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8]) +def test_assert_almost_equal_asserts_correctly(dtype: candle.DType): + a = Tensor([1, 2, 3]).to(dtype) + b = Tensor([1, 2, 3]).to(dtype) + assert_almost_equal(a, b) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1) + + assert_almost_equal(a, b + 1, atol=20) + assert_almost_equal(a, b + 1, rtol=20) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1, atol=0.9) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1, rtol=0.1) diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 06f7e01ddf..ef44fc4c93 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -1,6 +1,7 @@ import candle from candle import Tensor from candle.utils import cuda_is_available +from candle.testing import assert_equal import pytest @@ -77,37 +78,32 @@ def test_tensor_can_be_scliced_3d(): assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] -def test_tensors_can_be_compared_with_equal(): - t = Tensor(42.0) - other = Tensor(42.0) - assert t.equal(other) - t = Tensor([42.0, 42.1]) - other = Tensor([42.0, 42.0]) - assert not t.equal(other) - t = Tensor(42.0) - other = Tensor([42.0, 42.0]) - assert not t.equal(other) +def assert_bool(t: Tensor, expected: bool): + assert t.shape == () + assert str(t.dtype) == str(candle.u8) + assert bool(t.values()) == expected def test_tensor_supports_equality_opperations_with_scalars(): t = Tensor(42.0) - assert (t == 42.0).equal(Tensor(1).to_dtype(candle.u8)) - assert (t == 43.0).equal(Tensor(0).to_dtype(candle.u8)) - assert (t != 42.0).equal(Tensor(0).to_dtype(candle.u8)) - assert (t != 43.0).equal(Tensor(1).to_dtype(candle.u8)) + assert_bool(t == 42.0, True) + assert_bool(t == 43.0, False) + + assert_bool(t != 42.0, False) + assert_bool(t != 43.0, True) - assert (t > 41.0).equal(Tensor(1).to_dtype(candle.u8)) - assert (t > 42.0).equal(Tensor(0).to_dtype(candle.u8)) + assert_bool(t > 41.0, True) + assert_bool(t > 42.0, False) - assert (t >= 42.0).equal(Tensor(1).to_dtype(candle.u8)) - assert (t >= 43.0).equal(Tensor(0).to_dtype(candle.u8)) + assert_bool(t >= 41.0, True) + assert_bool(t >= 42.0, True) - assert (t < 43.0).equal(Tensor(1).to_dtype(candle.u8)) - assert (t < 42.0).equal(Tensor(0).to_dtype(candle.u8)) + assert_bool(t < 43.0, True) + assert_bool(t < 42.0, False) - assert (t <= 42.0).equal(Tensor(1).to_dtype(candle.u8)) - assert (t <= 41.0).equal(Tensor(0).to_dtype(candle.u8)) + assert_bool(t <= 43.0, True) + assert_bool(t <= 42.0, True) def test_tensor_supports_equality_opperations_with_tensors(): @@ -115,23 +111,23 @@ def test_tensor_supports_equality_opperations_with_tensors(): same = Tensor(42.0) other = Tensor(43.0) - assert (t == same).equal(Tensor(1).to_dtype(candle.u8)) - assert (t == other).equal(Tensor(0).to_dtype(candle.u8)) + assert_bool(t == same, True) + assert_bool(t == other, False) - assert (t != same).equal(Tensor(0).to_dtype(candle.u8)) - assert (t != other).equal(Tensor(1).to_dtype(candle.u8)) + assert_bool(t != same, False) + assert_bool(t != other, True) - assert (t > same).equal(Tensor(0).to_dtype(candle.u8)) - assert (t > other).equal(Tensor(0).to_dtype(candle.u8)) + assert_bool(t > same, False) + assert_bool(t > other, False) - assert (t >= same).equal(Tensor(1).to_dtype(candle.u8)) - assert (t >= other).equal(Tensor(0).to_dtype(candle.u8)) + assert_bool(t >= same, True) + assert_bool(t >= other, False) - assert (t < same).equal(Tensor(0).to_dtype(candle.u8)) - assert (t < other).equal(Tensor(1).to_dtype(candle.u8)) + assert_bool(t < same, False) + assert_bool(t < other, True) - assert (t <= same).equal(Tensor(1).to_dtype(candle.u8)) - assert (t <= other).equal(Tensor(1).to_dtype(candle.u8)) + assert_bool(t <= same, True) + assert_bool(t <= other, True) def test_tensor_equality_opperations_can_broadcast(): @@ -143,7 +139,7 @@ def test_tensor_equality_opperations_can_broadcast(): mask_cond = candle.Tensor([0, 1, 2]) mask = mask_cond < (mask_cond + 1).reshape((3, 1)) assert mask.shape == (3, 3) - assert mask.equal(Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8)) + assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8)) def test_tensor_can_be_hashed(): From b58056f5edbfb7609d066e63218d271b15d74b12 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sat, 28 Oct 2023 13:44:26 +0200 Subject: [PATCH 6/9] remove duplicated `broadcast_shape_binary_op` --- candle-core/src/shape.rs | 2 +- candle-pyo3/src/lib.rs | 10 +++++----- candle-pyo3/src/utils.rs | 40 ---------------------------------------- 3 files changed, 6 insertions(+), 46 deletions(-) delete mode 100644 candle-pyo3/src/utils.rs diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index ac00a97997..beaa945534 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -203,7 +203,7 @@ impl Shape { /// Check whether the two shapes are compatible for broadcast, and if it is the case return the /// broadcasted shape. This is to be used for binary pointwise ops. - pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result { + pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result { let lhs = self; let lhs_dims = lhs.dims(); let rhs_dims = rhs.dims(); diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 40c908806e..4b75a6871e 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -19,10 +19,6 @@ extern crate accelerate_src; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; -mod utils; - -use utils::broadcast_shapes; - pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) } @@ -726,7 +722,11 @@ impl PyTensor { compare(&self.0, &rhs.0) } else { // We broadcast manually here because `candle.cmp` does not support automatic broadcasting - let broadcast_shape = broadcast_shapes(&self.0, &rhs.0).map_err(wrap_err)?; + let broadcast_shape = self + .0 + .shape() + .broadcast_shape_binary_op(rhs.0.shape(), "cmp") + .map_err(wrap_err)?; let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; diff --git a/candle-pyo3/src/utils.rs b/candle-pyo3/src/utils.rs deleted file mode 100644 index 5fd1df3137..0000000000 --- a/candle-pyo3/src/utils.rs +++ /dev/null @@ -1,40 +0,0 @@ -use ::candle::{Error as CandleError, Result as CandleResult}; -use candle::Shape; - -/// Tries to broadcast the `rhs` shape into the `lhs` shape. -pub fn broadcast_shapes(lhs: &::candle::Tensor, rhs: &::candle::Tensor) -> CandleResult { - // see `Shape.broadcast_shape_binary_op` - let lhs_dims = lhs.dims(); - let rhs_dims = rhs.dims(); - let lhs_ndims = lhs_dims.len(); - let rhs_ndims = rhs_dims.len(); - let bcast_ndims = usize::max(lhs_ndims, rhs_ndims); - let mut bcast_dims = vec![0; bcast_ndims]; - for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() { - let rev_idx = bcast_ndims - idx; - let l_value = if lhs_ndims < rev_idx { - 1 - } else { - lhs_dims[lhs_ndims - rev_idx] - }; - let r_value = if rhs_ndims < rev_idx { - 1 - } else { - rhs_dims[rhs_ndims - rev_idx] - }; - *bcast_value = if l_value == r_value { - l_value - } else if l_value == 1 { - r_value - } else if r_value == 1 { - l_value - } else { - return Err(CandleError::BroadcastIncompatibleShapes { - src_shape: lhs.shape().clone(), - dst_shape: rhs.shape().clone(), - } - .bt()); - } - } - Ok(Shape::from(bcast_dims)) -} From 43485693b115777cacb70599ade6277ccac85f74 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sat, 28 Oct 2023 14:31:40 +0200 Subject: [PATCH 7/9] `candle.i16` => `candle.i64` --- candle-pyo3/py_src/candle/testing/__init__.py | 14 +++++++++----- candle-pyo3/src/lib.rs | 5 +++-- candle-pyo3/tests/bindings/test_testing.py | 4 ++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py index a11b56e252..7b2dec9ec3 100644 --- a/candle-pyo3/py_src/candle/testing/__init__.py +++ b/candle-pyo3/py_src/candle/testing/__init__.py @@ -2,6 +2,9 @@ from candle import Tensor +_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)]) + + def _assert_tensor_metadata( actual: Tensor, expected: Tensor, @@ -56,11 +59,12 @@ def assert_almost_equal( _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) # Secure against overflow of u32 and u8 tensors - diff = ( - (actual - expected).abs() - if actual.sum_all().values() > expected.sum_all().values() - else (expected - actual).abs() - ) + if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES: + actual = actual.to(candle.i64) + expected = expected.to(candle.i64) + + diff = (actual - expected).abs() + threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected) assert (diff <= threshold).sum_all().values() == actual.nelements, f"Difference between tensors was to great" diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 4b75a6871e..d1045b5bca 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -148,9 +148,10 @@ macro_rules! pydtype { } }; } + +pydtype!(i64, |v| v); pydtype!(u8, |v| v); pydtype!(u32, |v| v); -pydtype!(i64, |v| v); pydtype!(f16, f32::from); pydtype!(bf16, f32::from); pydtype!(f32, |v| v); @@ -1576,7 +1577,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add("u8", PyDType(DType::U8))?; m.add("u32", PyDType(DType::U32))?; - m.add("i16", PyDType(DType::I64))?; + m.add("i64", PyDType(DType::I64))?; m.add("bf16", PyDType(DType::BF16))?; m.add("f16", PyDType(DType::F16))?; m.add("f32", PyDType(DType::F32))?; diff --git a/candle-pyo3/tests/bindings/test_testing.py b/candle-pyo3/tests/bindings/test_testing.py index 58a2ed1f51..db2fd3f7fa 100644 --- a/candle-pyo3/tests/bindings/test_testing.py +++ b/candle-pyo3/tests/bindings/test_testing.py @@ -4,7 +4,7 @@ import pytest -@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8]) +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) def test_assert_equal_asserts_correctly(dtype: candle.DType): a = Tensor([1, 2, 3]).to(dtype) b = Tensor([1, 2, 3]).to(dtype) @@ -14,7 +14,7 @@ def test_assert_equal_asserts_correctly(dtype: candle.DType): assert_equal(a, b + 1) -@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8]) +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) def test_assert_almost_equal_asserts_correctly(dtype: candle.DType): a = Tensor([1, 2, 3]).to(dtype) b = Tensor([1, 2, 3]).to(dtype) From 86718144e352aab493a0196f464a567a9268f54e Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 29 Oct 2023 13:55:13 +0100 Subject: [PATCH 8/9] `tensor.nelements` -> `tensor.nelement` --- candle-pyo3/py_src/candle/__init__.pyi | 2 +- candle-pyo3/py_src/candle/testing/__init__.py | 2 +- candle-pyo3/src/lib.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index db07dfb315..48b1786c08 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -344,7 +344,7 @@ class Tensor: """ pass @property - def nelements(self) -> int: + def nelement(self) -> int: """ Gets the tensor's element count. """ diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py index 7b2dec9ec3..240b635f28 100644 --- a/candle-pyo3/py_src/candle/testing/__init__.py +++ b/candle-pyo3/py_src/candle/testing/__init__.py @@ -67,4 +67,4 @@ def assert_almost_equal( threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected) - assert (diff <= threshold).sum_all().values() == actual.nelements, f"Difference between tensors was to great" + assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great" diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index d1045b5bca..614a2c0ff7 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -337,7 +337,7 @@ impl PyTensor { #[getter] /// Gets the tensor's element count. /// &RETURNS&: int - fn nelements(&self) -> usize { + fn nelement(&self) -> usize { self.0.elem_count() } From 09440ce900e5abca2d3ac67840f3bcdc681d50b7 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:53:55 +0100 Subject: [PATCH 9/9] Cleanup `abs` --- candle-pyo3/src/lib.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 614da9dd1b..ddd58fbe16 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -367,11 +367,7 @@ impl PyTensor { /// Performs the `abs` operation on the tensor. /// &RETURNS&: Tensor fn abs(&self) -> PyResult { - match self.0.dtype() { - DType::U8 => Ok(PyTensor(self.0.clone())), - DType::U32 => Ok(PyTensor(self.0.clone())), - _ => Ok(PyTensor(self.0.abs().map_err(wrap_err)?)), - } + Ok(PyTensor(self.0.abs().map_err(wrap_err)?)) } /// Performs the `sin` operation on the tensor.