From 5155d13a17c766065396a34f740ce8621164b41b Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
Date: Tue, 27 Feb 2024 15:10:33 +0400
Subject: [PATCH] style

---
 src/petals/models/mixtral/block.py | 12 ++++++------
 src/petals/models/mixtral/model.py |  5 +++--
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/petals/models/mixtral/block.py b/src/petals/models/mixtral/block.py
index 8dee322cc..c9fcd8d78 100644
--- a/src/petals/models/mixtral/block.py
+++ b/src/petals/models/mixtral/block.py
@@ -2,12 +2,12 @@
 
 import torch
 from transformers import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
-from transformers.cache_utils import DynamicCache
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
 
 
 class WrappedMixtralBlock(MixtralDecoderLayer):
@@ -38,7 +38,9 @@ def forward(
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
-                past_key_value.update(torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx)
+                past_key_value.update(
+                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
+                )
             past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
 
         if self._attn_implementation == "flash_attention_2":
@@ -81,9 +83,7 @@ def forward(
         if use_cache:
             present_key_value = outputs[-1]
             present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
-            present_key_value = self._reorder_cache_to_bloom(
-                present_key_value, batch_size, seq_length_with_past
-            )
+            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
             outputs = outputs[:-1] + (present_key_value,)
 
         return outputs
diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py
index 13a8b32e7..1b52dcf6f 100644
--- a/src/petals/models/mixtral/model.py
+++ b/src/petals/models/mixtral/model.py
@@ -135,7 +135,9 @@ def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
         return self.norm
 
 
-class DistributedMixtralForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
+class DistributedMixtralForCausalLM(
+    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
+):
     _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
 
@@ -160,7 +162,6 @@ def transformer(self) -> DistributedMixtralModel:  # For compatibility with Remo
 class DistributedMixtralForSequenceClassification(
     DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
 ):
-
     def __init__(self, config: DistributedMixtralConfig):
         MixtralPreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels