Support saving and loading 8-bit block weights #273
base: main
Conversation
logger = get_logger(__name__)

CLIENT_BRANCH = "main"
-BLOCK_BRANCH_PREFIX = "block_"
+BLOCK_BRANCH_PREFIX = "int8_block"
We'll roll that back before merging
if load_in_8bit:
    block = replace_8bit_linear(block)

block = block.to(device)
I moved `replace_8bit_linear` here because it's not possible to correctly load the quantized `Linear8bitLt` checkpoint into the model before it's converted and quantized.
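To make that ordering concrete, here is a minimal sketch of the required load sequence. The `make_block` factory and `checkpoint_path` are placeholders for illustration, not the actual petals API; `replace_8bit_linear` is the helper from `convert_block.py` shown below, and loading the int8 state dict assumes the serialization support from the linked bitsandbytes PR:

```python
import torch

from petals.utils.convert_block import replace_8bit_linear  # assumed import path


def load_8bit_block(make_block, checkpoint_path: str, device: str = "cuda") -> torch.nn.Module:
    """Sketch: convert and quantize the block first, only then load the int8 checkpoint."""
    block = make_block()                # fresh block with regular nn.Linear layers, on CPU
    block = replace_8bit_linear(block)  # swap nn.Linear -> Linear8bitLt while still on CPU
    block = block.to(device)            # moving to GPU quantizes the Linear8bitLt weights
    state_dict = torch.load(checkpoint_path, map_location=device)
    block.load_state_dict(state_dict)   # int8 keys (e.g. weight and SCB) now match the modules
    return block
```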
src/petals/utils/convert_block.py (Outdated)
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)

        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
-           model._modules[n] = CustomLinear8bitLt(
+           model._modules[n] = bnb.nn.Linear8bitLt(
Not strictly necessary, but it'd be good to get rid of all bitsandbytes-related code that got into upstream before merging this
Done in #297.
Gentle reminder: please update BNB before merging. This is not covered by tests
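For readers who don't have the full file open, here is a self-contained sketch of the replacement logic from the `convert_block.py` diff above. The `Linear8bitLt` constructor arguments and the `threshold` default are assumptions based on the usual bitsandbytes LLM.int8() setup, not a verbatim copy of the petals function:

```python
import bitsandbytes as bnb
import torch


def replace_8bit_linear(model: torch.nn.Module, threshold: float = 6.0) -> torch.nn.Module:
    """Recursively swap nn.Linear layers for bnb.nn.Linear8bitLt (sketch, not verbatim)."""
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)  # recurse into child modules

        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
            new_layer = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                module.bias is not None,
                has_fp16_weights=False,  # keep int8 weights after quantization (inference mode)
                threshold=threshold,     # outlier threshold for the mixed int8/fp16 matmul
            )
            # Wrap the existing weight; it is quantized when the module is moved to GPU.
            new_layer.weight = bnb.nn.Int8Params(module.weight.data, requires_grad=False)
            if module.bias is not None:
                new_layer.bias = module.bias
            model._modules[n] = new_layer
    return model
```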
Force-pushed from 56a3bee to a610f4d.
@@ -38,6 +39,8 @@ def load_pretrained_block(
     use_auth_token: Optional[str] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
+    load_in_8bit=False,
Suggested change:
-    load_in_8bit=False,
+    load_in_8bit: bool = False,
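For context, a minimal usage sketch of the new flag. The module path and the parameter names other than `load_in_8bit` are assumptions based on the signature above, so treat this as illustrative rather than exact:

```python
# Sketch only: the import path and parameter names other than `load_in_8bit` are assumptions.
from petals.bloom.from_pretrained import load_pretrained_block

block = load_pretrained_block(
    "bigscience/test-bloomd",  # converted repo used in the test command from the PR description
    block_index=0,             # which transformer block to download
    load_in_8bit=True,         # new flag: download and load the pre-quantized 8-bit block weights
)
```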
Please defer this until #323 is merged, since it changes block loading code. We discussed that we may revive this feature for loading NF4-pre-quantized weights for Llama 2 and Stable Beluga 2.
This PR relies on bitsandbytes-foundation/bitsandbytes#159 and makes it possible to call `convert_model` with the int8 data type, and later download the 8-bit checkpoint instead of the 16-bit one when serving the model with `load_in_8bit=True`. This can save up to 2x bandwidth on starting a server, as shown by a comparison of model sizes for bloom-560m.

The command used for conversion is:

`python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model_int8 --torch_dtype int8 --resize_token_embeddings 50000 --block_branch_prefix int8_block`

To test that the checkpoint loads correctly, install bitsandbytes from the branch in the PR above and run:

`python -m petals.cli.run_server bigscience/test-bloomd --new_swarm --skip_reachability_check --throughput 100 --device cuda`

(Note that I had to change `BLOCK_BRANCH_PREFIX` in this branch for the sake of testing.)
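To make the branch layout concrete: a converted repo stores each transformer block on its own Hub branch, and, judging from the prefix constants shown above, the branch name is `BLOCK_BRANCH_PREFIX` followed by the block index, so a conversion with `--block_branch_prefix int8_block` would publish branches like `int8_block0`, `int8_block1`, and so on. A tiny illustrative helper (hypothetical, not the actual petals code):

```python
# Hypothetical helper: it only illustrates how a per-block revision name can be
# derived from BLOCK_BRANCH_PREFIX as discussed above.
BLOCK_BRANCH_PREFIX = "int8_block"  # value used in this branch for testing


def get_block_revision(block_index: int, prefix: str = BLOCK_BRANCH_PREFIX) -> str:
    """Return the Hub branch that holds the weights of one transformer block."""
    return f"{prefix}{block_index}"


print(get_block_revision(3))  # -> "int8_block3"
```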