
Commit

style
artek0chumak committed Feb 27, 2024
1 parent 574f8eb commit 5155d13
Showing 2 changed files with 9 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/petals/models/mixtral/block.py
@@ -2,12 +2,12 @@
 
 import torch
 from transformers import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
-from transformers.cache_utils import DynamicCache
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
 
 
 class WrappedMixtralBlock(MixtralDecoderLayer):
@@ -38,7 +38,9 @@ def forward(
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
-                past_key_value.update(torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx)
+                past_key_value.update(
+                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
+                )
             past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
 
         if self._attn_implementation == "flash_attention_2":
@@ -81,9 +83,7 @@ def forward(
         if use_cache:
             present_key_value = outputs[-1]
             present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
-            present_key_value = self._reorder_cache_to_bloom(
-                present_key_value, batch_size, seq_length_with_past
-            )
+            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
             outputs = outputs[:-1] + (present_key_value,)
 
         return outputs
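
For context on the hunks above (not part of this commit): DynamicCache stores one key/value pair per layer index, so a block that serves only self.layer_idx first fills the earlier slots with placeholder tensors before registering its own entry, which is what the reformatted loop does. A minimal standalone sketch of that pattern, with hypothetical tensor shapes:

# Standalone sketch of the cache-padding pattern (hypothetical shapes, not petals code)
import torch
from transformers.cache_utils import DynamicCache

layer_idx = 2                    # pretend this block serves layer 2
key = torch.zeros(1, 8, 4, 64)   # (batch, num_kv_heads, seq_len, head_dim)
value = torch.zeros(1, 8, 4, 64)

cache = DynamicCache()
for idx in range(layer_idx):
    # placeholder entries so the real pair lands at index layer_idx
    cache.update(torch.empty(key.size()), torch.empty(value.size()), idx)
cache.update(key, value, layer_idx)

# to_legacy_cache() yields one (key, value) tuple per layer index
assert cache.to_legacy_cache()[layer_idx][0].shape == key.shape
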
5 changes: 3 additions & 2 deletions src/petals/models/mixtral/model.py
@@ -135,7 +135,9 @@ def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
         return self.norm
 
 
-class DistributedMixtralForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
+class DistributedMixtralForCausalLM(
+    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
+):
     _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
 
@@ -160,7 +162,6 @@ def transformer(self) -> DistributedMixtralModel: # For compatibility with Remo
 class DistributedMixtralForSequenceClassification(
     DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
 ):
-
     def __init__(self, config: DistributedMixtralConfig):
         MixtralPreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels
