Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyO3: Add equal and __richcmp__ to candle.Tensor #1099

Merged
Merged
36 changes: 36 additions & 0 deletions candle-pyo3/_additional_typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,39 @@ def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Te
Return a slice of a tensor.
"""
pass

def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
41 changes: 41 additions & 0 deletions candle-pyo3/py_src/candle/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,46 @@ class Tensor:
Add a scalar to a tensor or two tensors together.
"""
pass
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
Expand All @@ -159,6 +189,11 @@ class Tensor:
Divide a tensor by a scalar or one tensor by another.
"""
pass
def abs(self) -> Tensor:
"""
Performs the `abs` operation on the tensor.
"""
pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
Expand Down Expand Up @@ -308,6 +343,12 @@ class Tensor:
ranges from `start` to `start + len`.
"""
pass
@property
def nelements(self) -> int:
"""
Gets the tensor's element count.
"""
pass
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.
Expand Down
66 changes: 66 additions & 0 deletions candle-pyo3/py_src/candle/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import candle
from candle import Tensor


def _assert_tensor_metadata(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
if check_device:
assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"

if check_dtype:
assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"

if check_layout:
assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"

if check_stride:
assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"


def assert_equal(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts that two tensors are exact equals.
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"


def assert_almost_equal(
actual: Tensor,
expected: Tensor,
rtol=1e-05,
atol=1e-08,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.

Computes: |actual - expected| ≤ atol + rtol x |expected|
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

# Secure against overflow of u32 and u8 tensors
diff = (
(actual - expected).abs()
if actual.sum_all().values() > expected.sum_all().values()
else (expected - actual).abs()
)
threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

assert (diff <= threshold).sum_all().values() == actual.nelements, f"Difference between tensors was to great"
72 changes: 72 additions & 0 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;

Expand All @@ -16,6 +19,10 @@ extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

mod utils;

use utils::broadcast_shapes;

pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
Expand Down Expand Up @@ -330,6 +337,13 @@ impl PyTensor {
PyTuple::new(py, self.0.dims()).to_object(py)
}

#[getter]
/// Gets the tensor's element count.
/// &RETURNS&: int
fn nelements(&self) -> usize {
LLukas22 marked this conversation as resolved.
Show resolved Hide resolved
self.0.elem_count()
}

#[getter]
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
Expand Down Expand Up @@ -366,6 +380,16 @@ impl PyTensor {
self.__repr__()
}

/// Performs the `abs` operation on the tensor.
/// &RETURNS&: Tensor
fn abs(&self) -> PyResult<Self> {
match self.0.dtype() {
DType::U8 => Ok(PyTensor(self.0.clone())),
DType::U32 => Ok(PyTensor(self.0.clone())),
_ => Ok(PyTensor(self.0.abs().map_err(wrap_err)?)),
}
}

/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
Expand Down Expand Up @@ -683,6 +707,54 @@ impl PyTensor {
};
Ok(Self(tensor))
}
/// Rich-compare two tensors.
/// &RETURNS&: Tensor
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
let compare = |lhs: &Tensor, rhs: &Tensor| {
let t = match op {
CompareOp::Eq => lhs.eq(rhs),
CompareOp::Ne => lhs.ne(rhs),
CompareOp::Lt => lhs.lt(rhs),
CompareOp::Le => lhs.le(rhs),
CompareOp::Gt => lhs.gt(rhs),
CompareOp::Ge => lhs.ge(rhs),
};
Ok(PyTensor(t.map_err(wrap_err)?))
};
if let Ok(rhs) = rhs.extract::<PyTensor>() {
if self.0.shape() == rhs.0.shape() {
compare(&self.0, &rhs.0)
} else {
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
let broadcast_shape = broadcast_shapes(&self.0, &rhs.0).map_err(wrap_err)?;
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;

compare(&broadcasted_lhs, &broadcasted_rhs)
}
} else if let Ok(rhs) = rhs.extract::<f64>() {
let scalar_tensor = Tensor::new(rhs, self.0.device())
.map_err(wrap_err)?
.to_dtype(self.0.dtype())
.map_err(wrap_err)?
.broadcast_as(self.0.shape())
.map_err(wrap_err)?;

compare(&self.0, &scalar_tensor)
} else {
return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
}
}

fn __hash__(&self) -> u64 {
// we have overridden __richcmp__ => py03 wants us to also override __hash__
// we simply hash the address of the tensor
let mut hasher = DefaultHasher::new();
let pointer = &self.0 as *const Tensor;
let address = pointer as usize;
address.hash(&mut hasher);
hasher.finish()
}

#[pyo3(text_signature = "(self, shape:Sequence[int])")]
/// Reshapes the tensor to the given shape.
Expand Down
40 changes: 40 additions & 0 deletions candle-pyo3/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use ::candle::{Error as CandleError, Result as CandleResult};
use candle::Shape;

/// Tries to broadcast the `rhs` shape into the `lhs` shape.
pub fn broadcast_shapes(lhs: &::candle::Tensor, rhs: &::candle::Tensor) -> CandleResult<Shape> {
LLukas22 marked this conversation as resolved.
Show resolved Hide resolved
// see `Shape.broadcast_shape_binary_op`
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
let lhs_ndims = lhs_dims.len();
let rhs_ndims = rhs_dims.len();
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
let mut bcast_dims = vec![0; bcast_ndims];
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
let rev_idx = bcast_ndims - idx;
let l_value = if lhs_ndims < rev_idx {
1
} else {
lhs_dims[lhs_ndims - rev_idx]
};
let r_value = if rhs_ndims < rev_idx {
1
} else {
rhs_dims[rhs_ndims - rev_idx]
};
*bcast_value = if l_value == r_value {
l_value
} else if l_value == 1 {
r_value
} else if r_value == 1 {
l_value
} else {
return Err(CandleError::BroadcastIncompatibleShapes {
src_shape: lhs.shape().clone(),
dst_shape: rhs.shape().clone(),
}
.bt());
}
}
Ok(Shape::from(bcast_dims))
}
33 changes: 33 additions & 0 deletions candle-pyo3/tests/bindings/test_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_equal(a, b)

with pytest.raises(AssertionError):
assert_equal(a, b + 1)


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_almost_equal(a, b)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1)

assert_almost_equal(a, b + 1, atol=20)
assert_almost_equal(a, b + 1, rtol=20)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, atol=0.9)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, rtol=0.1)
Loading
Loading