Commit
* Add somehow workable version
* Fix generation
* Fixes
* Choose right attn
* style
* fix bloom
* remove unnes
* Update src/petals/models/mixtral/model.py

  Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* fix order of init

---------

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
1 parent 2ad0b2b · commit d2fcbbc

Showing 7 changed files with 344 additions and 2 deletions.
src/petals/models/__init__.py

@@ -1,3 +1,4 @@
from petals.models.bloom import *
from petals.models.falcon import *
from petals.models.llama import *
from petals.models.mixtral import *
src/petals/models/mixtral/__init__.py

@@ -0,0 +1,15 @@
from petals.models.mixtral.block import WrappedMixtralBlock
from petals.models.mixtral.config import DistributedMixtralConfig
from petals.models.mixtral.model import (
    DistributedMixtralForCausalLM,
    DistributedMixtralForSequenceClassification,
    DistributedMixtralModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedMixtralConfig,
    model=DistributedMixtralModel,
    model_for_causal_lm=DistributedMixtralForCausalLM,
    model_for_sequence_classification=DistributedMixtralForSequenceClassification,
)
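After this registration, Petals' auto classes can resolve Mixtral checkpoints to the distributed classes above. A minimal sketch, assuming `AutoDistributedConfig` is imported from `petals` and using a hypothetical Mixtral repo name purely for illustration:

from petals import AutoDistributedConfig

# Hypothetical repo name, used only for illustration
config = AutoDistributedConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Expected to resolve to the classes registered above
print(type(config).__name__)  # DistributedMixtralConfig
print(config.block_class)     # WrappedMixtralBlock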
src/petals/models/mixtral/block.py

@@ -0,0 +1,114 @@
from typing import Optional, Tuple

import torch
from transformers import MixtralConfig
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel


class WrappedMixtralBlock(MixtralDecoderLayer):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        self._attn_implementation = config._attn_implementation
        self.sliding_window = config.sliding_window
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs
    ):
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
            past_key_value = DynamicCache()
            for idx in range(self.layer_idx):
                past_key_value.update(
                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
                )
            past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa":
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=self.sliding_window,
            )

        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        # TODO: Move to mixin
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        # TODO: Move to mixin
        key_states, value_states = key_value
        value_states = value_states.view(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)
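For reference, a self-contained sketch of the cache layout round trip that `_reorder_cache_from_bloom` / `_reorder_cache_to_bloom` implement. The tensor sizes are made up for illustration; the flat "Bloom-style" layout is what the surrounding Petals code appears to use for stored caches, while the 4-D layout is what `MixtralDecoderLayer` and `DynamicCache` expect:

import torch

# Made-up sizes, for illustration only
batch_size, num_kv_heads, head_dim, seq_len = 2, 8, 128, 5

# Flat "Bloom-style" layout:
#   keys:   (batch * kv_heads, head_dim, seq_len)
#   values: (batch * kv_heads, seq_len, head_dim)
flat_keys = torch.randn(batch_size * num_kv_heads, head_dim, seq_len)
flat_values = torch.randn(batch_size * num_kv_heads, seq_len, head_dim)

# _reorder_cache_from_bloom: flat layout -> (batch, kv_heads, seq_len, head_dim)
hf_keys = flat_keys.permute(0, 2, 1).view(batch_size, num_kv_heads, seq_len, head_dim)
hf_values = flat_values.view(*hf_keys.shape)

# _reorder_cache_to_bloom: the inverse transformation
back_values = hf_values.view(batch_size * num_kv_heads, seq_len, head_dim)
back_keys = hf_keys.view(*back_values.shape).permute(0, 2, 1)

# The round trip reproduces the original flat tensors exactly
assert torch.equal(back_keys, flat_keys) and torch.equal(back_values, flat_values)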
src/petals/models/mixtral/config.py

@@ -0,0 +1,36 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.mixtral.block import WrappedMixtralBlock

logger = get_logger(__name__)


class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedMixtralBlock
    attn_class = MixtralAttention
    block_prefix = "model.layers"

    num_key_value_groups = 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")
        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
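A quick illustration of the `dht_prefix` derivation above (the repo name is a hypothetical example; any Hub path is handled the same way):

# Hypothetical repo name, shown only to illustrate the prefix derivation
model_name_or_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"

dht_prefix = str(model_name_or_path).replace(".", "-")
print(dht_prefix)  # mistralai/Mixtral-8x7B-Instruct-v0-1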
src/petals/models/mixtral/model.py

@@ -0,0 +1,169 @@
from typing import Optional

import torch
import torch.nn as nn
from hivemind import DHT
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral import (
    MixtralForCausalLM,
    MixtralForSequenceClassification,
    MixtralModel,
    MixtralPreTrainedModel,
)

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.mixtral.config import DistributedMixtralConfig
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):
    """MixtralModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]

    config_class = DistributedMixtralConfig

    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.layers) == 0
        config.num_hidden_layers = n_layer

        self.layers = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The causal mask will be added on the server-side
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
        assert (
            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"
        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
        assert not output_router_logits, f"{output_router_logits=} is not supported"

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
        if use_prompts:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        else:
            prompts = intermediate_prompts = None

        hidden_states = inputs_embeds
        output_shape = input_shape + (hidden_states.size(-1),)

        if past_key_values is None:
            past_key_values = RemotePastKeyValues()
        past_key_values.update_seen(hidden_states.size(1))

        hidden_states = self.layers(
            hidden_states,
            prompts=intermediate_prompts,
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

        # Remove prefix
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers


class DistributedMixtralForCausalLM(
    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
):
    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedMixtralConfig

    def __init__(self, config: DistributedMixtralConfig):
        MixtralPreTrainedModel.__init__(self, config)
        self.model = DistributedMixtralModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    @property
    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
        return self.model


class DistributedMixtralForSequenceClassification(
    DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
):
    def __init__(self, config: DistributedMixtralConfig):
        MixtralPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.model = DistributedMixtralModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
        return self.model
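Once a Mixtral checkpoint is served by a swarm, the causal-LM class is typically driven through Petals' auto classes. A minimal client-side generation sketch, assuming a reachable swarm hosting the (illustrative) mistralai/Mixtral-8x7B-Instruct-v0.1 repo:

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

# Hypothetical checkpoint; it must be hosted by a Petals swarm you can reach
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)  # expected to resolve to DistributedMixtralForCausalLM

# Embeddings and the LM head run locally; the wrapped Mixtral blocks run on remote servers
inputs = tokenizer("A distributed Mixture-of-Experts model is", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))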