Skip to content

Commit

Permalink
remove unnes
Browse files Browse the repository at this point in the history
  • Loading branch information
artek0chumak committed Mar 12, 2024
1 parent 0cbc38c commit 5cab602
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions src/petals/models/mixtral/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind import DHT
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral import (
Expand Down Expand Up @@ -31,7 +31,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi

config_class = DistributedMixtralConfig

def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[hivemind.DHT] = None):
def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
super().__init__(config)
assert len(self.layers) == 0
Expand Down Expand Up @@ -122,18 +122,10 @@ def forward(
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
return self.embed_tokens

@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
return nn.Identity()

@property
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
return self.layers

@property
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
return self.norm


class DistributedMixtralForCausalLM(
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
Expand Down

0 comments on commit 5cab602

Please sign in to comment.