From 02bbd85ed89cd55e09fc37d0a07480b9f5f52634 Mon Sep 17 00:00:00 2001 From: Anton Sinitsin <30695750+xtinkt@users.noreply.github.com> Date: Wed, 24 Jul 2024 17:49:45 +0300 Subject: [PATCH] Added primitives for speculative decoding and tests (#598) This PR creates a DistributedLlamaModelForSpeculativeGeneration that implements basic speculative decoding (currently for greedy inference only). --- src/petals/client/inference_session.py | 36 +++--- src/petals/models/llama/__init__.py | 2 + src/petals/models/llama/speculative_model.py | 111 +++++++++++++++++++ src/petals/utils/__init__.py | 1 + src/petals/utils/auto_config.py | 5 + tests/test_speculative_generation.py | 54 ++++++++- 6 files changed, 192 insertions(+), 17 deletions(-) create mode 100644 src/petals/models/llama/speculative_model.py diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 4d94e7a76..5472d68ae 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -83,6 +83,17 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[ if not next_input_message.uid and not next_input_message.tensors: break # this message means "done sending" + @property + def position(self): + return self._position + + @position.setter + def position(self, start_from_position: int): + assert start_from_position <= self._position + self._position = start_from_position + if self.history is not None and self.history.shape[1] >= start_from_position: + self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None + def step( self, inputs: torch.Tensor, @@ -90,7 +101,6 @@ def step( hypo_ids: torch.LongTensor, *, step_id: str, - start_from_position: int, ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs @@ -100,12 +110,6 @@ def step( if self.closed: raise Exception("Session is closed, cannot perform step") - if start_from_position is not None: - assert start_from_position <= self._position - self._position = start_from_position - if self.history is not None and self.history.shape[1] >= start_from_position: - self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None - n_input_tokens = inputs.shape[1] if self.history is None: self.history = inputs @@ -127,8 +131,8 @@ def step( request_metadata = dict(session_id=self.session_id, step_id=step_id) if not self.stepped: request_metadata.update(self.session_metadata) - if start_from_position is not None: - request_metadata["start_from_position"] = start_from_position + if self._position is not None: + request_metadata["start_from_position"] = self._position elif self.config.use_server_to_server: next_servers = self._collect_next_servers() if next_servers: @@ -235,6 +239,13 @@ def num_blocks(self) -> int: def position(self) -> int: return self._position + @position.setter + def position(self, start_from_position: int) -> None: + self._position = start_from_position + for session in self._server_sessions: + assert isinstance(session, _ServerInferenceSession) + session.position = start_from_position + def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]: server_sessions = [] try: @@ -275,12 +286,7 @@ def step( inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None, - start_from_position: Optional[int] = None, ) -> torch.Tensor: - - if start_from_position is not None: - self._position = start_from_position - assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") @@ -324,12 +330,12 @@ def step( self._update_sequence(server_idx, block_idx, attempt_no) server_session = self._server_sessions[server_idx] + assert server_session.position == self.position, f"{server_session.position} and {self.position}" inputs = server_session.step( inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id, - start_from_position=start_from_position, ) server_idx += 1 diff --git a/src/petals/models/llama/__init__.py b/src/petals/models/llama/__init__.py index e5d8aa4f9..2f8f597bc 100644 --- a/src/petals/models/llama/__init__.py +++ b/src/petals/models/llama/__init__.py @@ -5,11 +5,13 @@ DistributedLlamaForSequenceClassification, DistributedLlamaModel, ) +from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration from petals.utils.auto_config import register_model_classes register_model_classes( config=DistributedLlamaConfig, model=DistributedLlamaModel, model_for_causal_lm=DistributedLlamaForCausalLM, + model_for_speculative=DistributedLlamaForSpeculativeGeneration, model_for_sequence_classification=DistributedLlamaForSequenceClassification, ) diff --git a/src/petals/models/llama/speculative_model.py b/src/petals/models/llama/speculative_model.py new file mode 100644 index 000000000..f8b8faea3 --- /dev/null +++ b/src/petals/models/llama/speculative_model.py @@ -0,0 +1,111 @@ +from typing import Optional, Union + +import torch +from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList +from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama import LlamaForCausalLM + +from petals.models.llama.config import DistributedLlamaConfig +from petals.models.llama.model import DistributedLlamaForCausalLM + + +class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin): + def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM): + DistributedLlamaForCausalLM.__init__(self, config) + self.small_model = small_model + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + logits_warper: Optional[LogitsProcessorList], + speculative_inference_iteration_size: int = 10, + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + assert not generation_config.do_sample, "sample is not working for speculative generation now" + assert not synced_gpus, "synced_gpus is not working for speculative generation now" + assert ( + not generation_config.return_dict_in_generate + ), "return_dict_in_generate is not working for speculative generation now" + + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + finished = False + firsts = True + + while not finished: + speculative_inference_iteration_size = min( + speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1] + ) + with torch.no_grad(): + speculative_outputs = self.small_model.generate( + input_ids, + max_new_tokens=speculative_inference_iteration_size, + do_sample=False, + ) + speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:] + + full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1) + assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1] + + input_for_validation = full_sequence + if not firsts: + self.active_session.position = input_ids.shape[1] - 1 + input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :] + else: + firsts = False + input_for_validation = input_for_validation[:, :-1] + with torch.no_grad(): + precise_model_outputs = self(input_for_validation) + full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone() + + all_valid_tokens = [] + first_token = None + for i in range(speculative_inference_iteration_size): + token_logits = full_token_logits[:, i, :] + token_scores = logits_processor( + input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits + ) + valid_token = torch.argmax(token_scores, dim=-1) + + if first_token is None: + first_token = valid_token + + if valid_token.item() == speculative_tokens[:, i].item(): + all_valid_tokens.append(valid_token.unsqueeze(-1)) + else: + break + + if not all_valid_tokens and first_token is not None: + all_valid_tokens.append(first_token.unsqueeze(-1)) + all_valid_tokens = torch.cat(all_valid_tokens, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * ( + 1 - unfinished_sequences + ) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1) + + if streamer is not None: + streamer.put(all_valid_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None) + finished = unfinished_sequences.max() == 0 + + del precise_model_outputs + + if streamer is not None: + streamer.end() + + return input_ids diff --git a/src/petals/utils/__init__.py b/src/petals/utils/__init__.py index c8aa4844a..ac25ed2e2 100644 --- a/src/petals/utils/__init__.py +++ b/src/petals/utils/__init__.py @@ -3,5 +3,6 @@ AutoDistributedModel, AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification, + AutoDistributedSpeculativeModel, ) from petals.utils.dht import declare_active_modules, get_remote_module_infos diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py index 0cec83d87..e6adfcee5 100644 --- a/src/petals/utils/auto_config.py +++ b/src/petals/utils/auto_config.py @@ -15,6 +15,7 @@ class _ModelClasses: config: Type[PretrainedConfig] model: Optional[Type[PreTrainedModel]] = None model_for_causal_lm: Optional[Type[PreTrainedModel]] = None + model_for_speculative: Optional[Type[PreTrainedModel]] = None model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None @@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase _mapping_field = "model_for_causal_lm" +class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase): + _mapping_field = "model_for_speculative" + + class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_sequence_classification" diff --git a/tests/test_speculative_generation.py b/tests/test_speculative_generation.py index e3045dea3..5d436bb88 100644 --- a/tests/test_speculative_generation.py +++ b/tests/test_speculative_generation.py @@ -2,8 +2,14 @@ import pytest import torch +import transformers -from petals import AutoDistributedConfig, RemoteSequential +from petals import ( + AutoDistributedConfig, + AutoDistributedSpeculativeModel, + DistributedLlamaForSpeculativeGeneration, + RemoteSequential, +) from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS from petals.server.from_pretrained import load_pretrained_block from test_utils import * @@ -26,10 +32,54 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato with torch.inference_mode(): with remote_block.inference_session(max_length=inputs.shape[1]) as sess: initial_outputs_inference = sess.step(inputs) - secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2) + sess.position = 2 + secondary_outputs_inference = sess.step(short_inputs[:, 2:, :]) result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) (outputs_local,) = ref_block(short_inputs) assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference) + + +@pytest.fixture +def noisy_model(): + noisy_model = transformers.AutoModelForCausalLM.from_pretrained( + REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 + ) + lm_head = noisy_model.get_output_embeddings() + assert isinstance(lm_head, torch.nn.Linear) + with torch.no_grad(): + lm_head.weight += torch.randn_like(lm_head.weight) * 0.02 + return noisy_model + + +@pytest.fixture +def model(): + return transformers.AutoModelForCausalLM.from_pretrained( + MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 + ) + + +@pytest.fixture +def tokenizer(): + # We set use_fast=False since LlamaTokenizerFast is slow on load + return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False) + + +@pytest.mark.forked +@pytest.mark.skipif( + "llama" not in MODEL_NAME.lower(), + reason="Speculative generation now works only for llama models", +) +def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3): + speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained( + MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model + ) + + inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] + + generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False) + generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False) + + assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)