Fix server warnings, update license links and readme #602

Merged · 12 commits · Jul 24, 2024
2 changes: 1 addition & 1 deletion src/petals/models/bloom/block.py
@@ -7,7 +7,7 @@
 
 import torch
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
+from transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor
 
 from petals.utils.misc import is_dummy
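Note on the import change: after this PR only `build_alibi_tensor` is needed from `modeling_bloom`, and it is a standalone helper rather than a `BloomModel` method. A minimal usage sketch (illustrative only, assuming the current transformers API; not part of this diff):

```python
# Minimal sketch: build_alibi_tensor works without instantiating BloomModel.
# Per the transformers docs, the output shape is (batch * num_heads, 1, seq_len).
import torch
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

attention_mask = torch.ones(1, 8)  # (batch, seq_len)
alibi = build_alibi_tensor(attention_mask, num_heads=16, dtype=torch.float32)
print(alibi.shape)  # torch.Size([16, 1, 8])
```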
2 changes: 1 addition & 1 deletion src/petals/models/bloom/config.py
@@ -24,7 +24,7 @@ class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
-        logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")
+        logger.info("Make sure you follow the BLOOM terms of use: https://bit.ly/bloom-license")
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
         if loading_from_repo and dht_prefix is None:
4 changes: 2 additions & 2 deletions src/petals/models/llama/block.py
@@ -15,7 +15,6 @@
     LlamaConfig,
     LlamaDecoderLayer,
     LlamaMLP,
-    LlamaModel,
     LlamaRMSNorm,
     repeat_kv,
     rotate_half,
@@ -132,7 +131,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: LlamaConfig):
         nn.Module.__init__(self)
         self.hidden_size = config.hidden_size
-        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.self_attn = OptimizedLlamaAttention(config=config, layer_idx=0)
+        # layer_idx only matters for KV caching, and we re-implement it in Petals
         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
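Context for the added `layer_idx=0`: in recent transformers releases, `LlamaAttention` expects a `layer_idx` so that a shared `Cache` object can route key/value updates to the right layer slot; since Petals re-implements KV caching itself (as the added comment notes), a fixed index is enough. A rough illustration of the upstream mechanism (assumes transformers >= 4.36; not Petals code):

```python
# Illustration: transformers' DynamicCache stores K/V tensors per layer_idx,
# which is the main reason attention layers need to know their own index.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 8, 4, 16)  # (batch, num_heads, seq_len, head_dim)
v = torch.randn(1, 8, 4, 16)
cache.update(k, v, layer_idx=0)           # stored under layer slot 0
print(cache.get_seq_length(layer_idx=0))  # 4
```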
4 changes: 2 additions & 2 deletions src/petals/models/llama/config.py
@@ -27,8 +27,8 @@ def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
         logger.info(
-            "Make sure you follow the LLaMA's terms of use: "
-            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
+            "Make sure you follow the Llama terms of use: "
+            "https://llama.meta.com/llama3/license, https://llama.meta.com/llama2/license"
         )
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
3 changes: 1 addition & 2 deletions src/petals/models/mixtral/block.py
@@ -1,4 +1,3 @@
-import json
 from typing import Optional, Tuple
 
 import torch
@@ -8,7 +7,7 @@
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
 
 
 class WrappedMixtralBlock(MixtralDecoderLayer):
5 changes: 0 additions & 5 deletions src/petals/server/from_pretrained.py
@@ -64,10 +64,6 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
 
-    # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=False)
-    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
-
     for param_name, _ in block.named_parameters():
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
@@ -76,7 +72,6 @@
         set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {model_name} block {block_index}")
-    logger.debug(f"Details: {report}")
     return block
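With the `strict=False` dry run removed, missing weights are still caught by the per-parameter loop that follows: every named parameter must appear in the state dict before it is materialized via accelerate. A condensed sketch of that surviving path (the names `block` and `state_dict` are taken from the surrounding function):

```python
# Condensed sketch of the remaining loading path: verify each parameter is
# present, then replace the (possibly meta) tensor with real weights on CPU.
import torch
from accelerate.utils import set_module_tensor_to_device

def materialize_block(block: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
    return block
```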
2 changes: 1 addition & 1 deletion src/petals/utils/peft.py
@@ -267,7 +267,7 @@ def estimate_adapter_memory_per_block(
     **load_peft_kwargs,
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
-    with init_empty_weights(include_buffers=True):
+    with init_empty_weights(include_buffers=False):
         block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
         create_lora_adapter(block)
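On the `include_buffers` flip: `init_empty_weights` creates parameters on the meta device so a block can be sized without allocating real memory, and `numel()` still works there; with `include_buffers=False`, buffers are left as ordinary tensors. A minimal sketch of the context manager's behavior (assumes accelerate's documented API; the exact warning this silences is not visible in the diff):

```python
# Minimal sketch: parameters created under init_empty_weights live on the meta
# device, so counting elements allocates no real memory. With
# include_buffers=False, any buffers would remain ordinary tensors.
import torch
from accelerate import init_empty_weights

with init_empty_weights(include_buffers=False):
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.device)                         # meta
print(sum(p.numel() for p in layer.parameters()))  # 16781312
```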