[PROPOSAL]: Addition of a SIMD function wrapper like jax.pmap #3698
BigBalloon8
started this conversation in
Development | Core
Replies: 1 comment
-
Hi @mayfieldmobster Thanks so much for your contribution! We will review it later. Thanks.
-
Proposal:
see issue #3651
Motivation
Colossal AI provides many ways to easily parallelize the training and deployment of a model (data parallelism, pipeline parallelism, etc.), but training large models is not the only task in ML and data science that could benefit from parallelization.
jax.pmap provides an easy way to parallelize your code across multiple devices. PyTorch 2.0 introduced torch.vmap, which, like jax.vmap, adds single-device vectorization for SIMD (single instruction, multiple data) functions, but PyTorch has nothing like jax.pmap.
Design
I propose cmap (colossal ai map).
cmap would act as a Python wrapper around your function. It would accept the same args and kwargs as torch.vmap, plus a few more, such as the output destination and the process group.
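To make the wrapper idea concrete, here is a rough, standard-library-only sketch of the split / apply / gather behaviour described above. It is an illustration under assumptions, not the proposed implementation and not a Colossal AI API: `cmap_sketch`, `split_evenly`, and `num_workers` are hypothetical names, threads stand in for devices or ranks, and the output-destination and process-group arguments from the proposal are not modelled.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Callable, List, Sequence


def split_evenly(batch: Sequence, num_shards: int) -> List[Sequence]:
    """Split `batch` into `num_shards` contiguous shards of near-equal size."""
    base, extra = divmod(len(batch), num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        # The first `extra` shards take one additional element each.
        stop = start + base + (1 if i < extra else 0)
        shards.append(batch[start:stop])
        start = stop
    return shards


def cmap_sketch(func: Callable, num_workers: int = 2) -> Callable:
    """Hypothetical stand-in for the proposed cmap (assumption, not a real API).

    Each worker thread plays the role of one device/process in a group.
    """
    @wraps(func)
    def wrapped(batch: Sequence) -> list:
        shards = split_evenly(batch, num_workers)
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            # pool.map preserves shard order, so outputs line up with inputs.
            per_shard = pool.map(lambda shard: [func(x) for x in shard], shards)
            # Gather: flatten the per-shard results back into one list.
            return [y for shard_out in per_shard for y in shard_out]
    return wrapped


def square_plus_one(x: float) -> float:
    """Per-example function, written as if it only ever sees one input."""
    return x * x + 1.0


parallel_fn = cmap_sketch(square_plus_one, num_workers=3)
result = parallel_fn([0.0, 1.0, 2.0, 3.0, 4.0])
print(result)
```

In a real version, each shard would presumably be a slice of a tensor along the mapped dimension, the per-shard loop would be replaced by torch.vmap on that rank's device, and the gather step would be a collective over the chosen process group.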
Possible sample code
Possible cmap code (pseudo-code)
Self-Service