diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py
index 1b52dcf6f..798f91f45 100644
--- a/src/petals/models/mixtral/model.py
+++ b/src/petals/models/mixtral/model.py
@@ -1,8 +1,8 @@
 from typing import Optional
 
-import hivemind
 import torch
 import torch.nn as nn
+from hivemind import DHT
 from hivemind.utils.logging import get_logger
 from transformers.modeling_outputs import MoeModelOutputWithPast
 from transformers.models.mixtral import (
@@ -31,7 +31,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
 
     config_class = DistributedMixtralConfig
 
-    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[hivemind.DHT] = None):
+    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
         n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
         super().__init__(config)
         assert len(self.layers) == 0
@@ -122,18 +122,10 @@ def forward(
     def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
         return self.embed_tokens
 
-    @property
-    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
-        return nn.Identity()
-
     @property
     def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
         return self.layers
 
-    @property
-    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
-        return self.norm
-
 
 class DistributedMixtralForCausalLM(
     DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM