Commit
Fix beam search in GPU clients (#531)
Fixes #503.
borzunov authored Oct 23, 2023
1 parent 47d50e1 commit 82a97d6
Showing 2 changed files with 24 additions and 36 deletions.
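For context, a minimal sketch of the client-side call this commit repairs, assuming the public Petals API (AutoDistributedModelForCausalLM) and the bigscience/bloom-560m test model used in the CI workflow below; the snippet is an illustration, not part of the diff. With num_beams > 1, generate() reorders beam hypotheses on every inference step via hypo_ids, which is the path that previously broke when the client ran on a GPU.

import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "bigscience/bloom-560m"  # test model from the CI workflow below
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
if torch.cuda.is_available():
    model = model.cuda()  # a GPU client is where the bug showed up

input_ids = tokenizer("A cat sat on", return_tensors="pt")["input_ids"].to(model.device)
# num_beams > 1 triggers beam search, so each step sends hypo_ids to the servers
outputs = model.generate(input_ids, max_new_tokens=5, num_beams=2, do_sample=False)
print(tokenizer.decode(outputs[0]))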
22 changes: 9 additions & 13 deletions .github/workflows/run-tests.yaml
@@ -48,7 +48,6 @@ jobs:
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
# [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
@@ -61,27 +60,25 @@ jobs:
until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
--mean_balance_check_period 10 \
--initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
--device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
SERVER1_PID=$!
# ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
sleep 10 # wait for the 1st server to choose blocks
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
--identity_path tests/server2.id \
--initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
$RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
SERVER2_PID=$!
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
SERVER3_PID=$!
# ^-- chunking test
python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
--initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
$RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet)
@@ -121,4 +118,3 @@ jobs:
# [Step 4] Clean up
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
echo "Done!"
38 changes: 15 additions & 23 deletions src/petals/client/inference_session.py
@@ -84,12 +84,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
break # this message means "done sending"

def step(
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
*,
step_id: str,
self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -114,21 +109,6 @@ def step(
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further

if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]

if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64

# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)

@@ -275,7 +255,9 @@ def __enter__(self) -> "InferenceSession":
assert not self._closed and not self._server_sessions
return self

def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
def step(
self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -285,11 +267,21 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]

if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64

inputs_device = inputs.device
inputs_dtype = inputs.dtype
inputs = inputs.cpu()
prompts = prompts.cpu()
hypo_ids = hypo_ids.cpu()
step_id = str(uuid.uuid4())

n_input_tokens = inputs.shape[1]
@@ -310,7 +302,7 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k

server_session = self._server_sessions[server_idx]
inputs = server_session.step(
inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
)

server_idx += 1
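The substance of the fix is visible in the hunks above: hypo_ids validation and the .cpu() conversion now happen in the outer InferenceSession.step, next to inputs and prompts, before anything is serialized for the servers. A simplified sketch of that contract (illustration only, not the Petals source; the identity default below stands in for the library's dummy tensor):

from typing import Optional

import torch

def prepare_step_tensors(
    inputs: torch.Tensor,
    prompts: Optional[torch.Tensor],
    hypo_ids: Optional[torch.Tensor],
):
    """Move every tensor the servers will deserialize off the GPU before sending."""
    if hypo_ids is None:
        hypo_ids = torch.arange(len(inputs), dtype=torch.int64)  # identity reordering
    else:
        assert len(hypo_ids) == len(inputs)
        assert hypo_ids.dtype == torch.int64
    # Before this commit only inputs and prompts were moved to CPU, so a CUDA
    # hypo_ids tensor reached serialization and broke beam search on GPU clients.
    return inputs.cpu(), prompts.cpu() if prompts is not None else None, hypo_ids.cpu()

# Example: two beam hypotheses on a GPU, swapped at this step
if torch.cuda.is_available():
    hidden = torch.randn(2, 1, 1024, device="cuda")
    order = torch.tensor([1, 0], dtype=torch.int64, device="cuda")
    hidden_cpu, _, order_cpu = prepare_step_tensors(hidden, None, order)
    assert order_cpu.device.type == "cpu"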
