Add option to rollback inference for a certain number of steps (#588)
* fix

* fix

* fix

* fix

* fix

* fix

* style
xtinkt authored Jul 9, 2024
1 parent 6858586 commit c0a4d2e
Showing 3 changed files with 71 additions and 3 deletions.
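
The change adds a start_from_position keyword to the client inference API: a caller can rewind an open inference session to an earlier position and resubmit replacement inputs, and the servers overwrite their cached state from that point on. This is the building block exercised by the speculative-generation test added below. A minimal usage sketch, illustrative only: it assumes a running swarm, and MODEL_NAME / INITIAL_PEERS are placeholders as in the tests.

import torch
from petals import AutoDistributedConfig, RemoteSequential

config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_block = RemoteSequential(config)[0]  # a single remote transformer block

draft = torch.randn(1, 8, config.hidden_size)      # 8 speculative positions
corrected = torch.randn(1, 6, config.hidden_size)  # replacements for positions 2..7

with torch.inference_mode():
    with remote_block.inference_session(max_length=8) as sess:
        out_draft = sess.step(draft)  # session position is now 8
        # Only the first 2 positions were accepted: rewind to position 2 and
        # resubmit the corrected suffix; cached state past position 2 is discarded.
        out_fixed = sess.step(corrected, start_from_position=2)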
32 changes: 29 additions & 3 deletions src/petals/client/inference_session.py
@@ -84,7 +84,13 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
                 break  # this message means "done sending"

     def step(
-        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
+        self,
+        inputs: torch.Tensor,
+        prompts: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        *,
+        step_id: str,
+        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -94,6 +100,12 @@ def step(
         if self.closed:
             raise Exception("Session is closed, cannot perform step")

+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -115,6 +127,8 @@ request_metadata = dict(session_id=self.session_id, step_id=step_id)
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
+        if start_from_position is not None:
+            request_metadata["start_from_position"] = start_from_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -257,8 +271,16 @@ def __enter__(self) -> "InferenceSession":
         return self

     def step(
-        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+        self,
+        inputs: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
+
+        if start_from_position is not None:
+            self._position = start_from_position
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -303,7 +325,11 @@ def step(

                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
+                        inputs,
+                        prompts[server_session.span.start : server_session.span.end],
+                        hypo_ids,
+                        step_id=step_id,
+                        start_from_position=start_from_position,
                     )

                     server_idx += 1
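
For readers following the client-side hunks above: the rollback is plain bookkeeping on the session's position counter and its cached history of inputs. A self-contained sketch of that logic follows; the free-standing function and tensor layout are assumptions for illustration, the real code lives in _ServerInferenceSession.step.

import torch
from typing import Optional, Tuple

def rollback_history(
    history: Optional[torch.Tensor],  # cached inputs, shape (batch, seq_len, hidden)
    position: int,                    # number of positions processed so far
    start_from_position: int,         # position to rewind to
) -> Tuple[Optional[torch.Tensor], int]:
    # A session may only rewind to a position it has already reached.
    assert start_from_position <= position
    position = start_from_position
    if history is not None and history.shape[1] >= start_from_position:
        # Keep only the prefix; rewinding to 0 clears the cache entirely.
        history = history[:, :start_from_position, :] if start_from_position > 0 else None
    return history, position

# Example: 10 cached positions, rewind to position 4.
h, pos = rollback_history(torch.zeros(1, 10, 16), position=10, start_from_position=4)
assert pos == 4 and h.shape[1] == 4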
7 changes: 7 additions & 0 deletions src/petals/server/block_functions.py
@@ -160,6 +160,13 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0

     async for request, step_metadata in input_iterator:
+        if "start_from_position" in step_metadata:
+            start_from_position = step_metadata["start_from_position"]
+            assert (
+                prefix_length >= start_from_position
+            ), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
+            prefix_length = start_from_position
+
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
         if args_structure is not None:
             # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
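
On the server side, the hunk above rewinds prefix_length whenever a step carries the start_from_position metadata, so everything past that offset is treated as not yet computed and gets overwritten on the next step. A minimal sketch of that contract; the helper name and the plain-dict metadata are assumptions for illustration.

def apply_rollback(prefix_length: int, step_metadata: dict) -> int:
    """Return the (possibly rewound) prefix length for the next inference step."""
    if "start_from_position" in step_metadata:
        start_from_position = step_metadata["start_from_position"]
        # A client may only rewind within what it has already processed.
        assert prefix_length >= start_from_position, (
            f"prefix_length={prefix_length}, start_from_position={start_from_position}"
        )
        prefix_length = start_from_position
    return prefix_length

# Example: 12 positions processed; the client asks to redo everything after position 5.
assert apply_rollback(12, {"start_from_position": 5}) == 5
assert apply_rollback(12, {}) == 12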
35 changes: 35 additions & 0 deletions tests/test_speculative_generation.py
@@ -0,0 +1,35 @@
+import random
+
+import pytest
+import torch
+
+from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
+from petals.server.from_pretrained import load_pretrained_block
+from test_utils import *
+
+
+@pytest.mark.forked
+def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
+
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
+
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs[:, :2, :] = inputs[:, :2, :]
+
+    initial_outputs_inference = None
+    secondary_outputs_inference = None
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            initial_outputs_inference = sess.step(inputs)
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(short_inputs)
+
+    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
