[Feature] qlora support (#5586)
* [feature] qlora support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* qlora follow-up commit

* migrate quantization folder to colossalai/

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and ver217 committed Apr 28, 2024
1 parent 8954a0c commit 91fa553
Showing 14 changed files with 640 additions and 143 deletions.
15 changes: 15 additions & 0 deletions LICENSE
@@ -552,3 +552,18 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
---------------- LICENSE FOR Hugging Face accelerate ----------------

Copyright 2021 The HuggingFace Team

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
16 changes: 10 additions & 6 deletions applications/Colossal-LLaMA/colossal_llama/dataset/loader.py
@@ -80,15 +80,19 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch

# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
(
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
)
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
(
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
)
for instance in instances
]

8 changes: 5 additions & 3 deletions applications/Colossal-LLaMA/train.py
@@ -253,9 +253,11 @@ def main() -> None:
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

optimizer = HybridAdam(
model_params=filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters(),
model_params=(
filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters()
),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
23 changes: 21 additions & 2 deletions colossalai/booster/booster.py
@@ -19,6 +19,7 @@
import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -230,7 +231,12 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
return self.plugin.no_sync(model, optimizer)

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: "peft.LoraConfig" = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
quantize=False,
) -> nn.Module:
"""
Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@@ -259,7 +265,20 @@ def enable_lora(
assert (
pretrained_dir is not None
), "Please provide pretrained directory path if not passing in lora configuration."
return self.plugin.enable_lora(model, pretrained_dir, lora_config)
if quantize is True:
if bnb_quantization_config is not None:
warnings.warn(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
)
else:
bnb_quantization_config = BnbQuantizationConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)

return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.
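
With this change, `Booster.enable_lora` turns QLoRA on with a single `quantize=True` flag. A minimal usage sketch, assuming a causal LM `model` and an `optimizer` built elsewhere (the plugin arguments are illustrative and not taken from this commit):

```python
from peft import LoraConfig
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# `model` and `optimizer` are assumed to exist already (e.g. a Hugging Face causal LM
# and a HybridAdam optimizer); they are placeholders in this sketch.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
booster = Booster(plugin=plugin)

lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
# With quantize=True and no bnb_quantization_config, the NF4 defaults shown in the
# hunk above (load_in_4bit, bf16 compute dtype, double quantization) are applied.
model = booster.enable_lora(model, lora_config=lora_config, quantize=True)
model, optimizer, *_ = booster.boost(model, optimizer)
```

Note that `enable_lora` must run before `booster.boost`, matching the plugin-side asserts that the model is not yet wrapped.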
10 changes: 9 additions & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -28,6 +28,7 @@
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
@@ -338,14 +339,21 @@ def support_lora(self) -> bool:
return True

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model

assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)

if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
10 changes: 9 additions & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -9,6 +9,7 @@
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model

from .dp_plugin_base import DPPluginBase

@@ -237,10 +238,17 @@ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[Non
return model.module.no_sync()

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model

if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)

assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
if pretrained_dir is None:
return get_peft_model(model, lora_config)
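
Both plugins follow the same order of operations: quantize the frozen base weights first, then attach the LoRA adapters, and only boost or wrap the model afterwards. A simplified sketch of that flow (`enable_qlora` is an illustrative helper, not part of this commit; pretrained-adapter loading and error handling are omitted):

```python
from typing import Optional

import torch.nn as nn
from peft import LoraConfig, get_peft_model

from colossalai.quantization import BnbQuantizationConfig, quantize_model


def enable_qlora(
    model: nn.Module,
    lora_config: LoraConfig,
    bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
    # 1. Replace the base model's eligible linear layers with bitsandbytes modules.
    if bnb_quantization_config is not None:
        model = quantize_model(model, bnb_quantization_config)
    # 2. Attach full-precision LoRA adapters on top of the (possibly quantized) base.
    return get_peft_model(model, lora_config)
```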
38 changes: 19 additions & 19 deletions colossalai/inference/README.md
@@ -165,7 +165,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc
##### Llama

| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
|:-----------------------:|:------:|:------:|:------:|
| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
| colossal-inference | 326.4 | 582.72 | 816.64 |

@@ -174,7 +174,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc
#### Bloom

| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
|:-----------------------:|:------:|:------:|:------:|
| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
| colossal-inference | 323.28 | 538.52 | 611.64 |

@@ -187,40 +187,40 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t

#### A10 7b, fp16

| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :-------------------------: | :---: | :---:| :---: | :---: | :---: | :---: |
| Pipeline Inference | 40.35 | 77.10| 139.03| 232.70| 257.81| OOM |
| Hugging Face | 41.43 | 65.30| 91.93 | 114.62| OOM | OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 40.35 | 77.10 | 139.03 | 232.70 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |


![ppllama7b](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama7b.png)

#### A10 13b, fp16

| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) |
|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |

![ppllama13](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama13b.png)


#### A800 7b, fp16

| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |

![ppllama7b_a800](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a800-llama7b.png)

### Quantization LLama

| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| auto-gptq | 199.20 | 232.56 | 253.26 |
| smooth-quant | 142.28 | 222.96 | 300.59 |
| colossal-gptq | 231.98 | 388.87 | 573.03 |
| batch_size | 8 | 16 | 32 |
|:-------------:|:------:|:------:|:------:|
| auto-gptq | 199.20 | 232.56 | 253.26 |
| smooth-quant | 142.28 | 222.96 | 300.59 |
| colossal-gptq | 231.98 | 388.87 | 573.03 |

![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-quant.png)

7 changes: 7 additions & 0 deletions colossalai/quantization/__init__.py
@@ -0,0 +1,7 @@
from .bnb import quantize_model
from .bnb_config import BnbQuantizationConfig

__all__ = [
"BnbQuantizationConfig",
"quantize_model",
]
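
The module exposes `BnbQuantizationConfig` and `quantize_model`, which the new LICENSE section suggests are adapted from Hugging Face accelerate. Conceptually, this kind of utility swaps eligible `nn.Linear` layers for bitsandbytes 4-bit modules; the sketch below illustrates the idea only and is not the actual ColossalAI implementation (weight conversion and skip lists are omitted):

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb


def replace_linear_with_nf4(module: nn.Module, compute_dtype=torch.bfloat16) -> nn.Module:
    """Recursively swap nn.Linear children for bnb NF4 4-bit linears (sketch only)."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            quantized = bnb.nn.Linear4bit(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                compute_dtype=compute_dtype,
                quant_type="nf4",
            )
            # Real code would also convert and copy the original fp16/bf16 weights here.
            setattr(module, name, quantized)
        else:
            replace_linear_with_nf4(child, compute_dtype)
    return module
```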