Graph Pass: scaled_dot_product_attention_sliced_q #2418
base: main
Conversation
Great! Could you also add some concrete numbers in the PR description? For example, the memory and execution time of a model before/after using the sliced-Q algorithm?
The code LGTM. For the CI, could you pin `peft==0.13.2` to see if you could get a green CI?
coremltools/converters/mil/mil/passes/defs/scaled_dot_product_attention_sliced_q.py
Force-pushed from 62daa78 to 05c5fc3
Great work! The perf improvement (especially the memory save) looks really promising!
Nice graph pass!
(Previously I thought the parallelization was done in the kernel 😂 looks like we need to give it a hint haha)
@@ -44,5 +44,6 @@
     optimize_state,
     optimize_tensor_operation,
     preprocess,
+    scaled_dot_product_attention_sliced_q,
Shall we name the source file `transformer.py`? The other names here are "category" names, so a broader categorical name might sound more suitable than a concrete graph-pass name. Wdyt?
For longer Q sequence lengths (typically > 1024), it is beneficial to compute the attention with an algorithm (inspired by Lazy Softmax) that processes Q in chunks. The overall memory usage and execution time should improve (given the chunks are executed concurrently, e.g. on the ANE), and in certain cases where models hit OOM for longer sequence lengths, models using this algorithm still work.
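To make the chunked computation concrete, here is a minimal NumPy sketch of the sliced-Q idea. It only illustrates the math (the function name and signature are made up for illustration; the graph pass itself rewrites MIL operations rather than running Python). Because the softmax is taken along the K axis, each chunk of Q can be processed independently and the per-chunk results concatenated:

```python
import numpy as np

def sdpa_sliced_q(q, k, v, seq_length_divider=16):
    """Sliced-Q scaled dot-product attention (illustrative NumPy version).

    q: (seq_q, d), k: (seq_k, d), v: (seq_k, d_v)
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    chunk_size = max(1, q.shape[0] // seq_length_divider)
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        q_chunk = q[start:start + chunk_size]            # one slice of Q
        scores = (q_chunk @ k.T) * scale                  # (chunk, seq_k)
        scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)    # softmax over the K axis
        outputs.append(weights @ v)                       # (chunk, d_v)
    return np.concatenate(outputs, axis=0)                # (seq_q, d_v)
```

The intermediate score/weight tensor shrinks from (seq_q, seq_k) to (chunk, seq_k) per step, which is where the memory saving comes from; executing the independent chunks concurrently is what recovers (or improves) the execution time.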
This PR implements a new graph pass that can optionally transform the MIL operation `ios18.scaled_dot_product_attention` into a set of operations that compute the attention over chunks of Q.

Parameters of the new graph pass:
- `min_seq_length` (default: 1280) - the original MIL operation is only transformed if the sequence length of Q is greater than or equal to this value.
- `seq_length_divider` (default: 16) - defines the size of the chunks, based on `chunk_size = sequence_length / seq_length_divider`.
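If the pass ships as an optional (non-default) pass, enabling it during conversion might look roughly like the sketch below. The registered pass name (`common::scaled_dot_product_attention_sliced_q`) and the option plumbing through `PassPipeline` are assumptions, not confirmed by this PR; only the `min_seq_length` / `seq_length_divider` names and defaults come from the description above:

```python
import coremltools as ct

# Hypothetical wiring: append the optional pass to the default pipeline and
# configure it. The registered pass name below is an assumption; the option
# names match the parameters described in this PR.
pass_name = "common::scaled_dot_product_attention_sliced_q"

pipeline = ct.PassPipeline.DEFAULT
pipeline.append_pass(pass_name)
pipeline.set_options(pass_name, {"min_seq_length": 1280, "seq_length_divider": 16})

mlmodel = ct.convert(
    traced_torch_model,                         # a model already traced for conversion
    pass_pipeline=pipeline,
    minimum_deployment_target=ct.target.iOS18,  # ios18.scaled_dot_product_attention
)
```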
Example of the performance of the Depth-Anything model running on the ANE:
| | Execution time | Memory usage |
| --- | --- | --- |
| Original `scaled_dot_product_attention` | 131.55 ms | 169.67 MB |
| With sliced Q | 86.84 ms | 93.34 MB |
CI pipeline run: https://gitlab.com/coremltools1/coremltools/-/pipelines/1600785656