TorchDynamo and C++ extensions #197
RaulPPelaez started this conversation in General.
We have a PyTorch C++ autograd extension (called get_neighbor_pairs) that provides CPU and CUDA backends. When I try to torch.compile a model whose forward function calls this extension, I get the following error:
The extension is defined in torchmd-net/torchmdnet/neighbors/neighbors.cpp (lines 1 to 5 at commit a116847).
The CUDA implementation is in torchmd-net/torchmdnet/neighbors/neighbors_cuda.cu (lines 74 to 89 at commit a116847).
I JIT-compile this extension at import time in torchmd-net/torchmdnet/neighbors/__init__.py (lines 2 to 16 at commit a116847).
Eventually, I call this extension in the model's forward function, in torchmd-net/torchmdnet/models/utils.py (line 110 at commit a116847).
The class containing that call stores the extension function as self.kernel (torchmd-net/torchmdnet/models/utils.py, lines 234 to 237 at commit a116847).
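Because the linked snippets are not reproduced here, the following is a minimal sketch of the same pattern in pure PyTorch, which could help check how torch.compile behaves without the compiled extension. PairDistances is a hypothetical stand-in written as a Python torch.autograd.Function: it returns the distance for every pair (i, j) with i < j and does not reproduce the real get_neighbor_pairs signature or outputs. PairModel is likewise hypothetical and only mirrors storing the function as self.kernel and calling it in forward.

```python
import torch


class PairDistances(torch.autograd.Function):
    """Hypothetical stand-in for the C++ get_neighbor_pairs extension.

    Returns the Euclidean distance for every pair (i, j) with i < j.
    The real extension's signature and outputs are NOT reproduced here.
    """

    @staticmethod
    def forward(ctx, positions):
        n = positions.shape[0]
        # Upper-triangle pair indices (i < j), same order as torch.pdist.
        i, j = torch.triu_indices(n, n, offset=1, device=positions.device)
        deltas = positions[i] - positions[j]  # (num_pairs, 3)
        distances = deltas.norm(dim=1)        # (num_pairs,)
        ctx.save_for_backward(deltas, distances, i, j)
        ctx.num_atoms = n
        return distances

    @staticmethod
    def backward(ctx, grad_distances):
        deltas, distances, i, j = ctx.saved_tensors
        # d|r_i - r_j| / d r_i = (r_i - r_j) / |r_i - r_j|; r_j gets the negative.
        contrib = grad_distances.unsqueeze(1) * deltas / distances.unsqueeze(1)
        grad_positions = torch.zeros(
            ctx.num_atoms, deltas.shape[1], dtype=deltas.dtype, device=deltas.device
        )
        # Scatter each pair's contribution back onto both atoms.
        grad_positions.index_add_(0, i, contrib)
        grad_positions.index_add_(0, j, -contrib)
        return grad_positions


class PairModel(torch.nn.Module):
    """Hypothetical model that keeps the autograd function as self.kernel."""

    def __init__(self):
        super().__init__()
        # Plain attribute holding the function's entry point.
        self.kernel = PairDistances.apply

    def forward(self, positions):
        return self.kernel(positions).sum()
```

Passing PairModel() to torch.compile and comparing the result with the real extension may help narrow down whether the error comes from the C++ extension itself rather than from the self.kernel pattern.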