From bcb3faa1c29f0f5a7e33a7f6813ab590bdbe67a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 20 Jun 2024 09:56:04 +0200 Subject: [PATCH] Factor out sharding of packed tensors (#2059) For Phi-3-Small I need to shard a packed QKV bias tensor, for which I implemented the `Weights.get_packed_sharded` method. However, this method can also replace the `Weights._get_qweight` method and the custom sharding code from `Weights.get_weights_col_packed`. --- .../text_generation_server/utils/weights.py | 99 +++++++++++-------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 45cfc073ca3..e61425254a0 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -130,29 +130,57 @@ def get_sharded(self, tensor_name: str, dim: int): ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) - def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]): - slice_ = self._get_slice(name) - total_size = slice_.get_shape()[1] + def get_packed_sharded( + self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + ) -> torch.Tensor: + """ + Get a shard from a tensor that packs multiple tensors. + + When a tensor packs multiple tensors (such as QKV or an up + projection + gate projection), sharding with `get_sharded` is not + safe since it would not split the packed tensors across shards. + + This method shards a tensor, such that the packed tensors are + split across shards. + + The columns are split in equally sized blocks when blocks is an `int`, or + in blocks proportional given to the sizes. For instance `[2, 1, 1]` will + divide an input with dimensionality `1024` in `[512, 256, 256]`. This is + convenient for e.g. splitting QKV without knowing the storage details of + quantized weights. + """ + slice_ = self._get_slice(tensor_name) + total_size = slice_.get_shape()[dim] block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes) world_size = self.process_group.size() rank = self.process_group.rank() - weights = [] + tensors = [] block_offset = 0 for block_size in block_sizes: assert ( block_size % world_size == 0 - ), f"Prepacked qkv cannot be sharded across {world_size} shards" + ), f"Prepacked tensor cannot be sharded across {world_size} shards" shard_block_size = block_size // world_size start = rank * shard_block_size stop = (rank + 1) * shard_block_size - weights.append(slice_[:, block_offset + start : block_offset + stop]) + if dim == 0: + tensor = slice_[block_offset + start : block_offset + stop] + elif dim == 1: + tensor = slice_[:, block_offset + start : block_offset + stop] + else: + raise NotImplementedError("Currently only dim=0 or dim=1 is supported") + tensors.append(tensor) block_offset += block_size + tensor = torch.cat(tensors, dim=dim) + tensor = tensor.to(device=self.device) - weight = torch.cat(weights, dim=1) - weight = weight.to(device=self.device) - return weight + # Avoid casting quantizer dtypes. + if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + tensor = tensor.to(dtype=self.dtype) + + return tensor def get_weights_col_packed_qkv( self, @@ -185,7 +213,9 @@ def get_weights_col_packed( from text_generation_server.layers.gptq import GPTQWeight try: - qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) + qweight = self.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." @@ -193,8 +223,12 @@ def get_weights_col_packed( gptq_params = self._get_gptq_params() - qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes) - scales = self._get_qweight(f"{prefix}.scales", block_sizes) + qzeros = self.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) scales = scales.to(dtype=self.dtype) if quantize == "gptq" and gptq_params.quant_method == "gptq": @@ -237,13 +271,17 @@ def get_weights_col_packed( if quant_method == "gptq": gptq_params = self._get_gptq_params() try: - qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) + qweight = self.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" ) - scales = self._get_qweight(f"{prefix}.scales", block_sizes) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) g_idx = self.get_tensor(f"{prefix}.g_idx") weight = repack_gptq_for_marlin( qweight=qweight, @@ -257,34 +295,17 @@ def get_weights_col_packed( ) else: - B = self._get_qweight(f"{prefix}.B", block_sizes) - s = self._get_qweight(f"{prefix}.s", block_sizes) + B = self.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = self.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) weight = MarlinWeight(B=B, s=s) else: - slice_ = self._get_slice(f"{prefix}.weight") - total_size = slice_.get_shape()[0] - block_sizes = _blocks_to_block_sizes( - total_size=total_size, blocks=block_sizes + weight = self.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes ) - - world_size = self.process_group.size() - rank = self.process_group.rank() - - tensors = [] - block_offset = 0 - for block_size in block_sizes: - assert ( - block_size % world_size == 0 - ), f"Prepacked weights cannot be sharded across {world_size} shards" - shard_block_size = block_size // world_size - start = rank * shard_block_size - stop = (rank + 1) * shard_block_size - tensor = slice_[block_offset + start : block_offset + stop] - tensors.append(tensor) - block_offset += block_size - weight = torch.cat(tensors, dim=0) - weight = weight.to(device=self.device) - weight = weight.to(dtype=self.dtype) return weight def get_weights_col(self, prefix: str, quantize: str):