Added primitives for speculative decoding and tests (#598)
This PR adds a DistributedLlamaForSpeculativeGeneration class that implements basic speculative decoding (currently for greedy inference only): a local draft model proposes a chunk of tokens and the distributed model validates them.
xtinkt authored Jul 24, 2024
1 parent a2d4b65 commit 02bbd85
Showing 6 changed files with 192 additions and 17 deletions.
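
For context, a minimal usage sketch of the new class described above, as exposed through the Auto wrapper added in this PR. The checkpoint names are hypothetical; the small_model keyword, use_fast=False, and the greedy-only restriction are taken from the code and tests below.

import torch
import transformers

from petals import AutoDistributedSpeculativeModel

# Hypothetical checkpoints: a small Llama draft model paired with a larger Llama target.
DRAFT_MODEL = "JackFram/llama-68m"
TARGET_MODEL = "meta-llama/Llama-2-70b-hf"

tokenizer = transformers.AutoTokenizer.from_pretrained(TARGET_MODEL, use_fast=False)
small_model = transformers.AutoModelForCausalLM.from_pretrained(DRAFT_MODEL, torch_dtype=torch.float32)

# The distributed model re-validates the draft model's tokens, so the output matches plain greedy decoding.
model = AutoDistributedSpeculativeModel.from_pretrained(TARGET_MODEL, torch_dtype=torch.float32, small_model=small_model)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)  # greedy only for now
print(tokenizer.decode(outputs[0], skip_special_tokens=True))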
36 changes: 21 additions & 15 deletions src/petals/client/inference_session.py
@@ -83,14 +83,24 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"

    @property
    def position(self):
        return self._position

    @position.setter
    def position(self, start_from_position: int):
        assert start_from_position <= self._position
        self._position = start_from_position
        if self.history is not None and self.history.shape[1] >= start_from_position:
            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

def step(
self,
inputs: torch.Tensor,
prompts: torch.Tensor,
hypo_ids: torch.LongTensor,
*,
step_id: str,
start_from_position: int,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -100,12 +110,6 @@ def step(
if self.closed:
raise Exception("Session is closed, cannot perform step")

if start_from_position is not None:
assert start_from_position <= self._position
self._position = start_from_position
if self.history is not None and self.history.shape[1] >= start_from_position:
self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

n_input_tokens = inputs.shape[1]
if self.history is None:
self.history = inputs
@@ -127,8 +131,8 @@ def step(
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
if start_from_position is not None:
request_metadata["start_from_position"] = start_from_position
if self._position is not None:
request_metadata["start_from_position"] = self._position
elif self.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
@@ -235,6 +239,13 @@ def num_blocks(self) -> int:
    def position(self) -> int:
        return self._position

    @position.setter
    def position(self, start_from_position: int) -> None:
        self._position = start_from_position
        for session in self._server_sessions:
            assert isinstance(session, _ServerInferenceSession)
            session.position = start_from_position

def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
server_sessions = []
try:
@@ -275,12 +286,7 @@ def step(
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
start_from_position: Optional[int] = None,
) -> torch.Tensor:

if start_from_position is not None:
self._position = start_from_position

assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -324,12 +330,12 @@ def step(
self._update_sequence(server_idx, block_idx, attempt_no)

server_session = self._server_sessions[server_idx]
assert server_session.position == self.position, f"{server_session.position} and {self.position}"
inputs = server_session.step(
inputs,
prompts[server_session.span.start : server_session.span.end],
hypo_ids,
step_id=step_id,
start_from_position=start_from_position,
)

server_idx += 1
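
Taken together, the inference_session.py changes replace the per-step start_from_position argument with a position setter on both session classes. A minimal client-side sketch of the new rollback flow, modeled on the cache-invalidation test further below; the checkpoint name, block range, and tensor shapes are illustrative assumptions:

import torch
from petals import AutoDistributedConfig, RemoteSequential

# Illustrative setup: one remote block and random hidden states of the model's width.
config = AutoDistributedConfig.from_pretrained("petals-team/StableBeluga2")  # hypothetical checkpoint
remote_block = RemoteSequential(config, start_block=3, end_block=4)

inputs = torch.randn(1, 8, config.hidden_size)
with torch.inference_mode():
    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
        full_outputs = sess.step(inputs)              # position advances to 8
        sess.position = 2                             # roll back: cached history past position 2 is dropped
        redone_outputs = sess.step(inputs[:, 2:, :])  # recompute tokens 2..7 from the rolled-back state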
2 changes: 2 additions & 0 deletions src/petals/models/llama/__init__.py
@@ -5,11 +5,13 @@
    DistributedLlamaForSequenceClassification,
    DistributedLlamaModel,
)
from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedLlamaConfig,
    model=DistributedLlamaModel,
    model_for_causal_lm=DistributedLlamaForCausalLM,
    model_for_speculative=DistributedLlamaForSpeculativeGeneration,
    model_for_sequence_classification=DistributedLlamaForSequenceClassification,
)
111 changes: 111 additions & 0 deletions src/petals/models/llama/speculative_model.py
@@ -0,0 +1,111 @@
from typing import Optional, Union

import torch
from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM

from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import DistributedLlamaForCausalLM


class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
        DistributedLlamaForCausalLM.__init__(self, config)
        self.small_model = small_model

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        speculative_inference_iteration_size: int = 10,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        assert not generation_config.do_sample, "sample is not working for speculative generation now"
        assert not synced_gpus, "synced_gpus is not working for speculative generation now"
        assert (
            not generation_config.return_dict_in_generate
        ), "return_dict_in_generate is not working for speculative generation now"

        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        finished = False
        firsts = True

        while not finished:
            speculative_inference_iteration_size = min(
                speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
            )
            # draft the next chunk of tokens greedily with the local small model
            with torch.no_grad():
                speculative_outputs = self.small_model.generate(
                    input_ids,
                    max_new_tokens=speculative_inference_iteration_size,
                    do_sample=False,
                )
            speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]

            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]

            # validate the draft with the distributed model; on later iterations, roll the remote
            # session back to the last accepted position and only feed the new tail
            input_for_validation = full_sequence
            if not firsts:
                self.active_session.position = input_ids.shape[1] - 1
                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
            else:
                firsts = False
                input_for_validation = input_for_validation[:, :-1]
            with torch.no_grad():
                precise_model_outputs = self(input_for_validation)
                full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()

            # accept draft tokens while the distributed model's greedy choice matches; stop at the first mismatch
            all_valid_tokens = []
            first_token = None
            for i in range(speculative_inference_iteration_size):
                token_logits = full_token_logits[:, i, :]
                token_scores = logits_processor(
                    input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
                )
                valid_token = torch.argmax(token_scores, dim=-1)

                if first_token is None:
                    first_token = valid_token

                if valid_token.item() == speculative_tokens[:, i].item():
                    all_valid_tokens.append(valid_token.unsqueeze(-1))
                else:
                    break

            # if no draft token was accepted, keep the first corrected token so generation still advances
            if not all_valid_tokens and first_token is not None:
                all_valid_tokens.append(first_token.unsqueeze(-1))
            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)

            if streamer is not None:
                streamer.put(all_valid_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
            finished = unfinished_sequences.max() == 0

            del precise_model_outputs

        if streamer is not None:
            streamer.end()

        return input_ids
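
To make the acceptance rule in _sample concrete, here is a toy, self-contained illustration (fabricated logits, not part of the PR): draft tokens are kept while the target model's greedy argmax agrees with them, and the first disagreement ends the accepted window.

import torch

draft_tokens = torch.tensor([[5, 7, 2, 9]])   # proposed by the small model
target_logits = torch.zeros(1, 4, 10)         # target model scores for the same positions
target_logits[0, 0, 5] = 1.0                  # argmax 5 -> agrees
target_logits[0, 1, 7] = 1.0                  # argmax 7 -> agrees
target_logits[0, 2, 3] = 1.0                  # argmax 3 -> disagrees, stop here
target_logits[0, 3, 9] = 1.0                  # never reached

accepted = []
for i in range(draft_tokens.shape[1]):
    valid_token = target_logits[:, i, :].argmax(dim=-1)
    if valid_token.item() == draft_tokens[0, i].item():
        accepted.append(valid_token.item())
    else:
        break
print(accepted)  # [5, 7] -- two draft tokens accepted in this iteration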
1 change: 1 addition & 0 deletions src/petals/utils/__init__.py
@@ -3,5 +3,6 @@
    AutoDistributedModel,
    AutoDistributedModelForCausalLM,
    AutoDistributedModelForSequenceClassification,
    AutoDistributedSpeculativeModel,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos
5 changes: 5 additions & 0 deletions src/petals/utils/auto_config.py
@@ -15,6 +15,7 @@ class _ModelClasses:
    config: Type[PretrainedConfig]
    model: Optional[Type[PreTrainedModel]] = None
    model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
    model_for_speculative: Optional[Type[PreTrainedModel]] = None
    model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None


@@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
_mapping_field = "model_for_causal_lm"


class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_speculative"


class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_sequence_classification"
54 changes: 52 additions & 2 deletions tests/test_speculative_generation.py
@@ -2,8 +2,14 @@

import pytest
import torch
import transformers

from petals import AutoDistributedConfig, RemoteSequential
from petals import (
    AutoDistributedConfig,
    AutoDistributedSpeculativeModel,
    DistributedLlamaForSpeculativeGeneration,
    RemoteSequential,
)
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@@ -26,10 +32,54 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
    with torch.inference_mode():
        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
            initial_outputs_inference = sess.step(inputs)
            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
            sess.position = 2
            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)

    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
    (outputs_local,) = ref_block(short_inputs)

    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)


@pytest.fixture
def noisy_model():
    noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    lm_head = noisy_model.get_output_embeddings()
    assert isinstance(lm_head, torch.nn.Linear)
    with torch.no_grad():
        lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
    return noisy_model


@pytest.fixture
def model():
    return transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )


@pytest.fixture
def tokenizer():
    # We set use_fast=False since LlamaTokenizerFast is slow on load
    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)


@pytest.mark.forked
@pytest.mark.skipif(
    "llama" not in MODEL_NAME.lower(),
    reason="Speculative generation now works only for llama models",
)
def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
    speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
    )

    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
    generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)

    assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)
