
Graph Pass: scaled_dot_product_attention_sliced_q #2418

Open

jirmasek wants to merge 5 commits into main from jim/graph-pass-sdpa-sliced-q

Conversation

jirmasek (Contributor) commented Dec 16, 2024

For longer Q sequence lengths (typically >1024), it is beneficial to compute the attention with an algorithm (inspired by Lazy Softmax) that processes Q in chunks. Overall memory usage and execution time should improve (given the chunks are executed concurrently, e.g. on ANE), and in cases where a model would otherwise hit an OOM at longer sequence lengths, a model using this algorithm can still run.
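
Conceptually the rewrite is exact: softmax normalizes each query row over all keys independently, so slicing along the Q sequence dimension changes nothing numerically. A minimal NumPy sketch of the idea (an illustration only, not the MIL implementation in this PR):

```python
import numpy as np

def sdpa(q, k, v):
    # Reference scaled dot-product attention: softmax(Q K^T / sqrt(d)) @ V.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def sdpa_sliced_q(q, k, v, seq_length_divider=16):
    # Each Q chunk attends to the full K/V, so the chunks are independent
    # and can run concurrently; the result is identical to unsliced SDPA.
    chunk_size = q.shape[0] // seq_length_divider
    return np.concatenate(
        [sdpa(q[i:i + chunk_size], k, v) for i in range(0, q.shape[0], chunk_size)],
        axis=0,
    )

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1280, 64)) for _ in range(3))
assert np.allclose(sdpa(q, k, v), sdpa_sliced_q(q, k, v))
```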

This PR implements a new graph pass that can optionally transform the MIL operation ios18.scaled_dot_product_attention into a set of operations that compute the attention over chunks of Q.

Parameters of the new graph pass (a hypothetical configuration sketch follows the list):

  • min_seq_length (default: 1280) - the original MIL operation is transformed only if the sequence length of Q is greater than or equal to this value.
  • seq_length_divider (default: 16) - defines the chunk size via chunk_size = sequence_length / seq_length_divider (with the defaults, a Q sequence length of 1280 yields chunks of 80).
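
For reference, here is a hypothetical sketch of wiring these options through coremltools' pass-pipeline API. The pass name "common::scaled_dot_product_attention_sliced_q" and the conversion arguments are assumptions for illustration, not taken from this PR:

```python
import coremltools as ct

pipeline = ct.PassPipeline.DEFAULT
# Assumed pass name, derived from the source file added in this PR.
pipeline.append_pass("common::scaled_dot_product_attention_sliced_q")
pipeline.set_options(
    "common::scaled_dot_product_attention_sliced_q",
    {"min_seq_length": "1280", "seq_length_divider": "16"},
)

# `traced_model` and the input shape are placeholders for an actual
# traced PyTorch model and its input.
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, 3, 518, 518))],
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=pipeline,
)
```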

Example of the performance of the Depth-Anything model running on ANE:

  • original:
    execution time: 131.55 ms
    memory usage: 169.67 MB
  • with transformations applied by this graph pass:
    execution time: 86.84 ms
    memory usage: 93.34 MB

CI pipeline run: https://gitlab.com/coremltools1/coremltools/-/pipelines/1600785656

junpeiz (Collaborator) commented Dec 17, 2024

Great! Could you also add some concrete numbers to the PR description? For example, the memory usage and execution time of a model before/after using the sliced-Q algorithm?

junpeiz (Collaborator) left a comment

The code LGTM. For the CI, could you pin peft==0.13.2 to see if you could get a green CI?

jirmasek force-pushed the jim/graph-pass-sdpa-sliced-q branch from 62daa78 to 05c5fc3 on December 23, 2024 at 18:56
junpeiz self-requested a review on January 2, 2025 at 22:27
junpeiz (Collaborator) left a comment

Great work! The perf improvement (especially the memory savings) looks really promising!

YifanShenSZ (Collaborator) left a comment

Nice graph pass!

(Previously I thought the parallelization was done in the kernel 😂 looks like we need to give it a hint haha)

@@ -44,5 +44,6 @@
     optimize_state,
     optimize_tensor_operation,
     preprocess,
+    scaled_dot_product_attention_sliced_q,
A collaborator commented on the diff:

Shall we name the source file transformer.py?

The other names here are "categories", so a broader categorical name might sound more suitable than a concrete graph-pass name. Wdyt?
