Prefix caching #57
Conversation
Benchmark Speed and Reproduction Results

These results confirm the initial results above, but on the full set of classification evaluation tasks in MetaICL. Some formatting changes were required in the MetaICLTask (accounting for the one large reproduction error). With those fixes we now reproduce metrics closely matching those from the MetaICL repo itself. Also of note: caching actually produces a slight slowdown when the only overlap in inputs comes from multiple-choice examples sharing the same context sequence over several continuation sequences. This setting is expected to be slower than the repeated-ICL-examples setting, but ideally should still give a speedup proportional to the number of choices per example. Future work can address the bottlenecks that dominate the caching time when there is relatively low overlap.

Truncation and cacheable prefixes

The following table further explains the better speedup on repeated ICL examples (in the 16-shot setting) vs. the multiple-choice-only overlap (in the 0-shot setting). Of note, the ratio of cacheable prefixes to the total number of inputs is much higher in the 0-shot setting. There is less overlap there, so more distinct prefixes must actually be computed at some point, giving less of a speedup.
Improved performance with batched caching

Commit 4e98c01 improves the caching code with batched computation of prefixes when there are multiple non-shared prefixes within a single batch. The following results show that this gives caching a speedup even in the zero-shot setting, when there are multiple-choice questions. The results also compare against the original MetaICL repo code. Those results are for raw gpt2-large while our results add IA3 adapters, so if we reran our code base without the IA3 adapters we would expect speed to improve or stay the same.
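The batched prefix computation described above can be illustrated with a small sketch. This is not the commit's actual code; `encode_prefix` is a hypothetical stand-in for the expensive forward pass that produces a prefix's cached key/values. The idea is simply to encode each distinct prefix in a batch once, then fan the results back out to every example that shares it.

```python
def compute_prefixes_batched(batch, encode_prefix):
    """Compute each distinct prefix in a batch only once.

    `batch` is a list of (prefix_tokens, continuation_tokens) pairs.
    `encode_prefix` (hypothetical) takes a list of distinct prefixes and
    returns one cache entry per prefix, in order.
    """
    # Gather the distinct prefixes, preserving first-seen order
    # (plain dicts preserve insertion order in Python 3.7+).
    distinct = {}
    for prefix, _ in batch:
        distinct.setdefault(tuple(prefix), None)

    # One batched call over the unique prefixes instead of one per example.
    encoded = encode_prefix(list(distinct.keys()))
    for key, value in zip(distinct.keys(), encoded):
        distinct[key] = value

    # Each example reuses the cache computed for its prefix.
    return [(distinct[tuple(prefix)], cont) for prefix, cont in batch]
```

With three examples where two share a prefix, `encode_prefix` is called with only two distinct prefixes, and the two sharing examples receive the same cache entry.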
You were saying you're still hacking on this and I should not review it?
I may change the way that the prefixes are cached at the batch level and how those are used for the final inference to reduce the memory footprint. But the part where I use the trie to organize examples by common prefix won't change from that. So to save you time, best not to review anything other than the trie and sorting stuff yet. But if you want to push out the functionality, this is a working implementation and I'm happy to add my optimizations in a new PR.
@dirkgr, Qinyuan's work will not be needing any more caching improvements soon, so this branch won't be moving forward any more. I've changed the top post in the PR to reflect the current state of affairs. In particular, note the updated limitations section, which mentions the issues we talked about in person. I've also added a demo script.
What is this?
Prefix caching for DecoderOnlyRCModel that reuses overlapping prefixes between instances rather than recomputing them.
This is useful in at least two settings:

- when many instances share the same repeated ICL examples as a common prefix, and
- when multiple-choice examples share the same context sequence across several continuation sequences.
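The conversation above mentions using a trie to organize examples by common prefix. A minimal sketch of that idea (not the PR's actual implementation) is a token trie whose depth-first traversal emits examples sharing a prefix adjacently, so each shared prefix's cache can be computed once and reused across the run of examples that follow it.

```python
class PrefixTrie:
    """Minimal token trie for grouping examples by shared prefix.

    A hypothetical sketch: examples inserted under their token sequence
    come back out of a depth-first traversal with shared-prefix examples
    adjacent to one another.
    """

    def __init__(self):
        self.children = {}     # token -> PrefixTrie
        self.example_ids = []  # examples whose sequence ends at this node

    def insert(self, tokens, example_id):
        node = self
        for tok in tokens:
            node = node.children.setdefault(tok, PrefixTrie())
        node.example_ids.append(example_id)

    def dfs_order(self):
        """Yield example ids depth-first, so shared prefixes are adjacent."""
        yield from self.example_ids
        for child in self.children.values():
            yield from child.dfs_order()
```

Sorting examples this way is what makes batch-level caching effective: a batch drawn from consecutive positions in the traversal tends to share prefixes.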
How do I use it?
Check out experiments/prefix_cache_demo.py for example usage.

Limitations
Issues with max sequence length
One limitation is that caching with batching does not work well when inputs are close to the max input length. This is because both the cached past_key_values and the continuation input tensor are padded to their largest length within the batch. The sum of the sequence lengths of these two tensors must be less than the model's maximum input length. So the effective longest input is determined by the largest prefix and the largest continuation in a batch taken together, even if they belong to different examples.
Overcoming this issue would involve breaking batches into sub-batches or reordering the examples to avoid problematic pairings of long prefixes and continuations. Right now, if this issue is encountered, an assertion fails with the message, "Presently batches with wide range of prefix and input lengths are not supported due overrun of max model size."
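The sub-batching workaround described above can be sketched as follows. This is a hypothetical illustration, not the PR's code: each sub-batch is valid only if its longest prefix plus its longest continuation fit in the model's max input length, since padding is applied per sub-batch.

```python
def split_into_subbatches(examples, max_model_len, max_batch_size=8):
    """Split (prefix_len, continuation_len) pairs into valid sub-batches.

    Hypothetical sketch: because prefixes and continuations are each
    padded to the sub-batch maximum, a sub-batch is valid only if
    max(prefix lens) + max(continuation lens) <= max_model_len.
    Sorting by total length keeps similarly sized examples together.
    Returns lists of example indices.
    """
    for i, (p, c) in enumerate(examples):
        if p + c > max_model_len:
            raise ValueError(f"example {i} alone exceeds max_model_len")

    order = sorted(range(len(examples)), key=lambda i: sum(examples[i]))
    subbatches, current = [], []
    max_p = max_c = 0
    for i in order:
        p, c = examples[i]
        new_p, new_c = max(max_p, p), max(max_c, c)
        if current and (new_p + new_c > max_model_len
                        or len(current) >= max_batch_size):
            subbatches.append(current)
            current, new_p, new_c = [], p, c
        current.append(i)
        max_p, max_c = new_p, new_c
    if current:
        subbatches.append(current)
    return subbatches
```

This avoids the assertion failure by never padding a short continuation up against an unrelated long prefix.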
Cached Prefix Memory Usage
When batches contain more than one cached prefix, there is no memory-efficient way to build the past_key_values input tensor. That is, the tensor must be a concatenation of different prefixes and thus cannot use .expand() to save memory. Instead, caching has a larger memory footprint than non-caching because prefixes are computed and then copied.

A first level of improvement would be to do the copying in place somehow, so the memory footprint is at least the same as non-caching. Improving beyond this by making use of .expand() would require splitting the continuation computation into sub-batches that all share the same prefix.

Initial Result (outdated)
These results were made with the first iteration of this design on July 15. See below for more up-to-date results.
Speed up
The screenshot below gives initial results on the speedup from using caching in a few settings. Key takeaways are:
Reproducibility
F1 and accuracy metrics are exactly reproduced, and logits all agree within torch.allclose() tolerance, with and without caching, and when using the code from main. The following are results from running python -m catwalk --model metaicl::gpt2 --task metaicl::boolq --num_shots 1 --fewshot_seed 100:

with caching
metaicl::boolq acc 0.4051987826824188
metaicl::boolq f1 tensor([0.5296, 0.1913])
metaicl::boolq precision tensor([0.3778, 0.6183])
metaicl::boolq recall tensor([0.8852, 0.1131])
without caching
metaicl::boolq acc 0.4051987826824188
metaicl::boolq f1 tensor([0.5296, 0.1913])
metaicl::boolq precision tensor([0.3778, 0.6183])
metaicl::boolq recall tensor([0.8852, 0.1131])
code from main
metaicl::boolq acc 0.4051987826824188
metaicl::boolq f1 tensor([0.5296, 0.1913])
metaicl::boolq precision tensor([0.3778, 0.6183])
metaicl::boolq recall tensor([0.8852, 0.1131])