PyO3: Add equal and __richcmp__ to candle.Tensor #1099

Merged

Conversation

LLukas22
Contributor

This PR adds support for all comparison operations (equality and ordering) between tensors and scalars or other tensors. It follows the same broadcasting behavior as PyTorch/NumPy: the lhs and rhs are automatically broadcast to a common shape before the comparison.

@LaurentMazare
Collaborator

Thanks, I have mixed feelings about this one. Providing an equality comparison is really tricky, e.g. what should it do for tensors located on different devices, for tensors with the same values but different types, ...? I don't think there are good answers to these. Do you know of specific use cases where these would be useful?

@LLukas22
Contributor Author

I based the equal method on torch.equal; its documentation isn't very explicit about how it should work. I would mainly use it for unit tests, since it currently isn't possible to compare tensors with 4+ dimensions: we can only export tensors as Python lists up to 3 dimensions. If you don't like it, we can just discard it.

@LaurentMazare
Collaborator

Yeah, I think it indeed makes sense for tests, but I would rather have some functions that live only on the test side and can be pure Python, e.g. test_equal and test_almost_equal (not sure tests would need more comparisons than these, and the second one is not covered here).

@LLukas22
Contributor Author

Alright, I will create a testing module tomorrow that contains these two methods. I will probably use the __richcmp__ methods to implement them; I guess that should be fine?

@LaurentMazare
Collaborator

I think that we could go for something simpler than __richcmp__ as we only use it for a single operation rather than all of them. Like assert_equal and assert_almost_equal could both be two lines of python, just doing the difference, then abs, then reduce_mean and that's about it. Better not to introduce too much code that we would have to maintain over time if that sounds reasonable.
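The "difference, then abs, then reduce_mean" recipe can be sketched in plain Python as follows. This is only an illustration on flat lists of numbers, not the candle.testing code that was eventually merged; the helper name mean_abs_diff and the atol parameter are assumptions made for the example.

```python
from typing import Sequence


def mean_abs_diff(lhs: Sequence[float], rhs: Sequence[float]) -> float:
    """Hypothetical helper: mean of |lhs - rhs| over all elements."""
    if len(lhs) != len(rhs):
        raise ValueError(f"length mismatch: {len(lhs)} vs {len(rhs)}")
    if not lhs:
        return 0.0
    # difference -> abs -> mean, as described in the comment above
    return sum(abs(a - b) for a, b in zip(lhs, rhs)) / len(lhs)


def assert_equal(lhs: Sequence[float], rhs: Sequence[float]) -> None:
    # Exact equality: the mean absolute difference must be exactly zero.
    assert mean_abs_diff(lhs, rhs) == 0.0, "values differ"


def assert_almost_equal(
    lhs: Sequence[float], rhs: Sequence[float], atol: float = 1e-5
) -> None:
    # Approximate equality: the mean absolute difference must stay within atol.
    diff = mean_abs_diff(lhs, rhs)
    assert diff <= atol, f"mean absolute difference {diff} exceeds {atol}"
```

One thing to keep in mind with a mean-based check: a single large deviation in a big tensor can be averaged away, so a max-based reduction may be the safer choice when outliers matter.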

@LLukas22
Contributor Author

Like assert_equal and assert_almost_equal could both be two lines of python, just doing the difference, then abs, then reduce_mean and that's about it.

You are probably right about this. Just so we don't mix things up here: I first added __richcmp__ to, well, enable tensors to be rich-compared. I then added equal and implemented it via __richcmp__, because that was already implemented.

So I guess what we want to do is implement some simple assert_equal and assert_almost_equal functions and let __richcmp__ just do the rich comparing (e.g. a == b, a > b, etc.)?

@LaurentMazare
Collaborator

Oh I think we just want assert_equal and assert_almost_equal and don't do any __richcmp__ until we find really good use cases for having it (I would think that it's almost always a bad idea as it means one is doing branching code so the computation graph is not generated properly, gradients are weird etc, and I wouldn't know of any model where this is actually used).

@LLukas22
Contributor Author

I mainly added __richcmp__ support for this line in modeling_llama.py:

mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

They basically create their causal mask via a rich compare, and I agree it's not what I would call "easily readable code", but I added the __richcmp__ feature to keep the "torch compatibility". I don't know if we want to keep it or remove it; that's your decision to make.
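To make the quoted line easier to follow, here is a rough pure-Python sketch of what the broadcast comparison computes. It assumes mask starts out filled with a large negative value and that mask_cond is arange(size); neither is shown in the quoted line, and the function name causal_mask is made up for the example.

```python
from typing import List


def causal_mask(size: int, fill: float = float("-inf")) -> List[List[float]]:
    """Hypothetical helper mirroring the masked_fill_ line with nested lists."""
    mask_cond = list(range(size))  # assumed to play the role of arange(size)
    # Broadcasting a row vector against a column vector gives a size x size
    # grid: cond[i][j] = mask_cond[j] < mask_cond[i] + 1, i.e. j <= i.
    cond = [
        [mask_cond[j] < mask_cond[i] + 1 for j in range(size)]
        for i in range(size)
    ]
    # Positions where the condition holds are set to 0; the rest keep `fill`.
    return [
        [0.0 if cond[i][j] else fill for j in range(size)]
        for i in range(size)
    ]
```

Row i ends up with zeros in columns 0..i and the fill value to the right of the diagonal, so each position can only attend to itself and earlier positions.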

@LaurentMazare
Collaborator

Ah yeah, that looks like a potential use case that people coming from PyTorch may want to replicate. Maybe it's good to have in that case - I'm a bit on the fence. How do you feel about it? Actually I'm happy either way; the added code is not too large either and could easily be tested.

@LLukas22
Contributor Author

Alright, here is what I would do:

  1. Keep __richcmp__ in its current form.
  2. Remove tensor.equal.
  3. Create a candle.testing module with assert_equal and assert_almost_equal methods, implemented in pure Python.

I see no reason to remove __richcmp__, as it basically only maps the comparison methods to the matching tensor.cmp calls.
Is that acceptable to you?
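The idea of mapping each comparison operator onto a single element-wise compare routine can be illustrated with a toy Python class. ToyTensor, its _cmp method, and the 0/1 list result are all hypothetical stand-ins chosen for this sketch; they are not the candle API or the PyO3 implementation from this PR.

```python
import operator
from typing import Callable, List, Sequence, Union

Number = Union[int, float]


class ToyTensor:
    """Hypothetical 1-D stand-in for a tensor (not the candle API)."""

    def __init__(self, values: Sequence[Number]) -> None:
        self.values = list(values)

    def _cmp(self, other, op: Callable[[Number, Number], bool]) -> List[int]:
        # Scalars are broadcast to the tensor's length; tensors must match it.
        if isinstance(other, ToyTensor):
            rhs = other.values
        else:
            rhs = [other] * len(self.values)
        if len(rhs) != len(self.values):
            raise ValueError("shape mismatch")
        # Element-wise comparison, returned as a 0/1 mask.
        return [int(op(a, b)) for a, b in zip(self.values, rhs)]

    # Each rich-comparison method only forwards to _cmp with the matching op.
    def __eq__(self, other): return self._cmp(other, operator.eq)
    def __ne__(self, other): return self._cmp(other, operator.ne)
    def __lt__(self, other): return self._cmp(other, operator.lt)
    def __le__(self, other): return self._cmp(other, operator.le)
    def __gt__(self, other): return self._cmp(other, operator.gt)
    def __ge__(self, other): return self._cmp(other, operator.ge)
```

With this, ToyTensor([1, 2, 3]) > 2 yields an element-wise mask rather than a single bool, which is the behavior the causal-mask line relies on.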

@LaurentMazare
Collaborator

Yeah that sounds reasonable then. Do you want to make a new PR for __richcmp__ or use this one?

@LLukas22
Contributor Author

I'm going to use this one for __richcmp__ and also add the candle.testing module in it, as I need it to unit-test the __richcmp__ operations properly.

@LaurentMazare
Collaborator

Great, let me know when it's good to look at.

@LLukas22
Contributor Author

Alright, should be ready, but I'm not happy with how assert_almost_equal turned out. I also noticed that u32 tensors silently wrap around when a result would be negative, without raising an error.

@LLukas22
Contributor Author

On second thought, u32 comparisons via assert_almost_equal are probably still broken: with tensors a=[1, 100] and b=[100, 1], subtracting one from the other overflows in one element regardless of the order.

@LaurentMazare
Collaborator

On second thought, u32 comparisons via assert_almost_equal are probably still broken: with tensors a=[1, 100] and b=[100, 1], subtracting one from the other overflows in one element regardless of the order.

A bit hacky, but you can take the difference of max(lhs, rhs) and min(lhs, rhs), which should always be non-negative. Or I wouldn't mind having codepaths that differ based on the type and, say, cast all integer types to i64.
Fwiw, the silent wrong results on overflow are Rust's behavior in release mode; debug mode should properly report an error, but obviously this has a runtime cost.
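The overflow and the max/min workaround can be reproduced in plain Python by emulating 32-bit unsigned wraparound. This is a sketch under that assumption only: Python integers do not overflow, so the masking below stands in for what a wrapping u32 subtraction does, and the helper names are invented for the example.

```python
U32_MASK = 0xFFFF_FFFF  # keep only the low 32 bits, emulating u32 arithmetic


def wrapping_sub_u32(a: int, b: int) -> int:
    """Emulated u32 subtraction that wraps around instead of going negative."""
    return (a - b) & U32_MASK


def abs_diff_u32(a: int, b: int) -> int:
    """max - min never goes below zero, so it cannot wrap."""
    return wrapping_sub_u32(max(a, b), min(a, b))


lhs = [1, 100]
rhs = [100, 1]

# Naive element-wise lhs - rhs: the first element wraps to a huge value.
naive = [wrapping_sub_u32(a, b) for a, b in zip(lhs, rhs)]

# Element-wise max(lhs, rhs) - min(lhs, rhs): the true absolute difference.
safe = [abs_diff_u32(a, b) for a, b in zip(lhs, rhs)]

print(naive)  # [4294967197, 99]
print(safe)   # [99, 99]
```

The alternative mentioned above, casting to a wider signed type such as i64 before subtracting, avoids the wrap for the same reason: the intermediate negative value becomes representable, and abs can then be applied.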

@LLukas22
Contributor Author

A bit hacky, but you can take the difference of max(lhs, rhs) and min(lhs, rhs), which should always be non-negative. Or I wouldn't mind having codepaths that differ based on the type and, say, cast all integer types to i64.

I tried to implement it by casting my unsigned types to i64, but i64 tensors don't implement abs, which means I can't easily compute the absolute difference between them. I also noticed that i64 was mapped to the candle.i16 member and fixed that.

@LaurentMazare
Collaborator

The pyo3 CI seems to fail, mind taking a look?

@LLukas22
Contributor Author

Yeah, it currently fails because abs is not implemented for i64. I don't know if it should be implemented in candle-core or how else I could compute the absolute difference between two tensors.

@LaurentMazare
Collaborator

Ah right, just merged #1216 that adds the abs function for integer types, let me know if this doesn't work on your side.

@LLukas22
Contributor Author

Ah right, just merged #1216 that adds the abs function for integer types, let me know if this doesn't work on your side.

Ah, nice. Yeah that should fix the problem. I'm gonna finish this PR tomorrow at work 👍
You are making incredibly fast progress with your implementations, but please also ensure to take care of yourself and avoid overdoing it❤️.

@LLukas22
Contributor Author

Alright, should be good to go 👍

@LaurentMazare
Collaborator

Great, thanks!

@LaurentMazare LaurentMazare merged commit c05c0a8 into huggingface:main Oct 30, 2023
22 of 24 checks passed