global_buffer add hidden_states and attention_mask
qianhao0713 committed Aug 20, 2024
1 parent da7a8cc commit 9c0ce85
Showing 1 changed file with 86 additions and 36 deletions.
122 changes: 86 additions & 36 deletions src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py
@@ -48,8 +48,15 @@ def init(self, num_layers, offload_percent, shape, dtype, device):
self.layer_num = num_layers
self.gpu_layer_num = int(num_layers * offload_percent)
self.cpu_layer_num = num_layers - self.gpu_layer_num
self.gpu_buffer = [torch.empty(shape, dtype=dtype, device=device) for _ in range(self.gpu_layer_num)]
self.gpu_buffer = [None for _ in range(self.gpu_layer_num)]
self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16)
bs, num_heads, seq_len, emb_size = shape
shape_h = [bs, seq_len, num_heads * emb_size]
shape_a = [bs, seq_len]
self.hidden_state_gpu_buffer = [None for _ in range(self.gpu_layer_num)]
self.hidden_state_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_h, dtype=np.float16)
self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)]
self.position_id_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_a, dtype=np.float16)
self.d2h_stream = drv.Stream()
self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)]
self.initialized = True
@@ -61,26 +68,60 @@ def save_flash_attn_out(self, layer_idx, out):
drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream)
else:
idx = layer_idx - self.cpu_layer_num
drv.memcpy_dtod(self.gpu_buffer[idx].data_ptr(), out.data_ptr(), out.element_size() * out.nelement())
self.gpu_buffer[idx] = out

def save_hidden_states(self, layer_idx, *hs):
if layer_idx < 0:
layer_idx = self.layer_num + layer_idx
hidden_state = hs[0]
position_id = hs[1]
if layer_idx < self.cpu_layer_num:
drv.memcpy_dtoh_async(self.hidden_state_cpu_buffer[layer_idx], hidden_state.data_ptr(), self.d2h_stream)
drv.memcpy_dtoh_async(self.position_id_cpu_buffer[layer_idx], position_id.data_ptr(), self.d2h_stream)
else:
idx = layer_idx - self.cpu_layer_num
self.hidden_state_gpu_buffer[idx] = hidden_state
self.position_id_gpu_buffer[idx] = position_id

def get_flash_attn_out(self, layer_idx):
if layer_idx < 0:
layer_idx = self.layer_num + layer_idx
if layer_idx > self.cpu_layer_num:
if layer_idx >= self.cpu_layer_num:
return self.gpu_buffer[layer_idx - self.cpu_layer_num]
idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num
idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num
self.h2d_streams[idx].synchronize()
return self.gpu_buffer[idx]

def get_hidden_states(self, layer_idx):
if layer_idx < 0:
layer_idx = self.layer_num + layer_idx
if layer_idx >= self.cpu_layer_num:
return self.hidden_state_gpu_buffer[layer_idx - self.cpu_layer_num], self.position_id_gpu_buffer[layer_idx - self.cpu_layer_num]
idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num
self.h2d_streams[idx].synchronize()
return self.hidden_state_gpu_buffer[idx], self.position_id_gpu_buffer[idx]

def free_flash_attn_out(self, layer_idx):
def free_layer_gpu_buffer(self, layer_idx):
if layer_idx < 0:
layer_idx = self.layer_num + layer_idx
if layer_idx == self.layer_num - 1:
self.d2h_stream.synchronize()
cpu_layer_idx = layer_idx - self.gpu_layer_num
if layer_idx >= self.cpu_layer_num:
idx = layer_idx - self.cpu_layer_num
else:
idx = self.gpu_layer_num -1 - (self.cpu_layer_num - 1 - layer_idx) % self.gpu_layer_num
self.gpu_buffer[idx].grad = None
if cpu_layer_idx < 0:
self.gpu_buffer[idx] = None
self.hidden_state_gpu_buffer[idx] = None
self.position_id_gpu_buffer[idx] = None
return
idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num
self.gpu_buffer[idx].grad = None
drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx])
drv.memcpy_htod_async(self.hidden_state_gpu_buffer[idx].data_ptr(), self.hidden_state_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx])
drv.memcpy_htod_async(self.position_id_gpu_buffer[idx].data_ptr(), self.position_id_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx])

global_buffer = GlobalBufferManager()

def init_flash_attn_buffers(num_layers):
# update the global buffer according to number of layers
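
For readers following the buffer manager above: the last gpu_layer_num layers keep their attention output, hidden states and position ids resident on the GPU, while earlier layers are copied asynchronously into page-locked host memory on d2h_stream and prefetched back on per-slot h2d_streams during backward. The sketch below is a rough torch-only analogue of that copy pattern (illustrative shapes; the commit itself performs the copies with pycuda):

    import torch

    # Sketch: offload one activation to pinned host memory on a D2H stream
    # during forward, prefetch it back on an H2D stream before backward needs it.
    shape = (2, 32, 4096, 128)  # illustrative (bs, heads, seq, head_dim)
    act = torch.randn(shape, dtype=torch.float16, device="cuda")

    host_buf = torch.empty(shape, dtype=torch.float16, pin_memory=True)
    d2h_stream = torch.cuda.Stream()
    h2d_stream = torch.cuda.Stream()

    # forward: asynchronous device-to-host copy on a side stream
    d2h_stream.wait_stream(torch.cuda.current_stream())  # act must be produced first
    with torch.cuda.stream(d2h_stream):
        host_buf.copy_(act, non_blocking=True)

    # ... other layers run here ...

    # backward (later): make sure the offload finished, then prefetch back
    d2h_stream.synchronize()
    prefetched = torch.empty(shape, dtype=torch.float16, device="cuda")
    with torch.cuda.stream(h2d_stream):
        prefetched.copy_(host_buf, non_blocking=True)
    h2d_stream.synchronize()  # safe to use `prefetched` on the default stream
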
@@ -161,16 +202,24 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
ctx.tensor_indices = {}
tensor_inputs = []
global global_buffer
hidden_state = None
position_ids = None
for i, arg in enumerate(args):
if i == 0 and ctx.layer_idx != 0:
# flash attention output is saved to the global buffer during forward
ctx.inputs.append(None)
else:
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
# tensor_inputs.append(arg)
if len(arg.shape) == 3:
hidden_state = arg
ctx.tensor_indices[i] = 'hidden_state'
elif len(arg.shape) == 2:
position_ids = arg
ctx.tensor_indices[i] = 'position_ids'
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
@@ -185,16 +234,13 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):

# save flash attention output to global buffer
# save_flash_attn_out_to_global_buffer(ctx.layer_idx, out)
GlobalBufferManager().save_flash_attn_out(ctx.layer_idx, out)
tensor_inputs += [softmax_lse]

global_buffer.save_flash_attn_out(ctx.layer_idx, out)
global_buffer.save_hidden_states(ctx.layer_idx, hidden_state, position_ids)
# tensor_inputs += [softmax_lse]
ctx.softmax_scale = softmax_scale

ctx.save_for_backward(*tensor_inputs)
tensor_inputs_ma = 0
for ti in tensor_inputs:
tensor_inputs_ma += ti.element_size() * ti.nelement()
if int(os.environ['RANK']) == 0:
print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}, tensor_inputs_ma: {tensor_inputs_ma/(1<<30):2f}")
ctx.save_for_backward(softmax_lse)
return out, residual

@staticmethod
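
A note on the checkpointing change above: the forward now routes the large per-layer tensors (flash-attention output, hidden states, position ids) into the global buffer keyed by layer index, and only softmax_lse still goes through ctx.save_for_backward. A minimal, self-contained sketch of that general pattern (hypothetical names, not the commit's code):

    import torch

    # Large activations live in an external buffer keyed by layer index instead
    # of ctx.save_for_backward, so their lifetime can be managed (offloaded,
    # prefetched, freed) outside autograd.
    _buffer = {}

    class BufferedMatmul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, w, layer_idx):
            ctx.layer_idx = layer_idx
            _buffer[layer_idx] = x          # stash the big activation externally
            ctx.save_for_backward(w)        # only small tensors go through ctx
            return x @ w

        @staticmethod
        def backward(ctx, grad_out):
            (w,) = ctx.saved_tensors
            x = _buffer.pop(ctx.layer_idx)  # retrieve and free the stashed input
            return grad_out @ w.t(), x.t() @ grad_out, None

    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 3, requires_grad=True)
    BufferedMatmul.apply(x, w, 0).sum().backward()

Keeping the large tensors out of ctx is what lets the buffer manager above offload, prefetch and free them on its own schedule.
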
@@ -207,18 +253,24 @@ def backward(ctx, *args):
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
tensors, softmax_lse = tensors[:-1], tensors[-1]

# tensors = ctx.saved_tensors
softmax_lse = ctx.saved_tensors[0]
# tensors, softmax_lse = tensors[:-1], tensors[-1]
global global_buffer
hidden_state, position_ids = global_buffer.get_hidden_states(ctx.layer_idx)
# Fill in inputs with appropriate saved tensors.
# Fill the flash attention output first
if ctx.layer_idx > 0:
# inputs[0] should be flash attention output
# inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1)
inputs[0] = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1)
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]

inputs[0] = global_buffer.get_flash_attn_out(ctx.layer_idx-1)
# for i, idx in enumerate(tensor_indices):
# inputs[idx] = tensors[i]
for k, v in tensor_indices.items():
if v == 'hidden_state':
inputs[k] = hidden_state
if v == 'position_ids':
inputs[k] = position_ids
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
@@ -247,9 +299,8 @@ def backward(ctx, *args):
#dk = torch.empty(k.shape, dtype=q.dtype, device=q.device)
#dv = torch.empty(v.shape, dtype=q.dtype, device=q.device)
# out = get_flash_attn_out_from_global_buffer(ctx.layer_idx)
out = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx)
if int(os.environ['RANK']) == 0:
print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}")
out = global_buffer.get_flash_attn_out(ctx.layer_idx)

# todo get dout
dout = args[0]

@@ -263,13 +314,12 @@ def backward(ctx, *args):

grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)

# write flash attention output gradients to buffer
if ctx.layer_idx > 0:
# write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad)
GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad
global_buffer.get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad
# free_flash_attn_out_buffer(ctx.layer_idx)
GlobalBufferManager().free_flash_attn_out(ctx.layer_idx)
global_buffer.free_layer_gpu_buffer(ctx.layer_idx)
return (None, None, None) + grads
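
The free_layer_gpu_buffer call above both releases the GPU slot the current layer used and starts the host-to-device prefetch for the next offloaded layer that backward will need. The slot it reuses follows the modular mapping from get_flash_attn_out; a small worked example with assumed sizes (32 layers, offload_percent=0.25, so 8 resident slots and 24 offloaded layers):

    gpu_layer_num, cpu_layer_num = 8, 24  # illustrative values

    def slot(layer_idx):
        # Same mapping as in the buffer manager above.
        return gpu_layer_num - 1 - (cpu_layer_num - 1 - layer_idx) % gpu_layer_num

    # Offloaded layers 23..16 land in slots 7..0, then 15..8 reuse slots 7..0,
    # and so on, matching the order in which backward frees the resident slots.
    print([(i, slot(i)) for i in (23, 22, 16, 15, 8, 7, 0)])
    # -> [(23, 7), (22, 6), (16, 0), (15, 7), (8, 0), (7, 7), (0, 0)]
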


@@ -341,12 +391,12 @@ def backward(ctx, *args):
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors

global global_buffer
# Fill in inputs with appropriate saved tensors.
# Fill the flash attention output first
# inputs[0] should be flash attention output
# inputs[0] = get_flash_attn_out_from_global_buffer(-1)
inputs[0] = GlobalBufferManager().get_flash_attn_out(-1)
inputs[0] = global_buffer.get_flash_attn_out(-1)
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]

@@ -384,10 +434,9 @@ def backward(ctx, *args):
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)

# write flash attention output gradients to buffer
# write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad)
GlobalBufferManager().get_flash_attn_out(-1).grad = detached_inputs[0].grad
global_buffer.get_flash_attn_out(-1).grad = detached_inputs[0].grad

return (None, None) + grads

@@ -549,7 +598,8 @@ def forward(
pass
# initialize the global buffer
# init_flash_attn_buffers(len(self.layers))
GlobalBufferManager().init(
global global_buffer
global_buffer.init(
self.config.num_hidden_layers,
offload_percent=0.25,
shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads],
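
For a rough sense of scale, the buffer sizes implied by the global_buffer.init call above can be estimated as follows (the model dimensions here are assumed for illustration, not taken from the commit):

    # Back-of-the-envelope sizing of the page-locked host buffers.
    num_layers, offload_percent = 32, 0.25
    bs, heads, seq, head_dim = 1, 32, 65536, 128

    gpu_layers = int(num_layers * offload_percent)   # 8 layers stay resident on GPU
    cpu_layers = num_layers - gpu_layers             # 24 layers are offloaded

    attn_out_bytes = bs * heads * seq * head_dim * 2       # fp16 attention output
    hidden_bytes = bs * seq * heads * head_dim * 2          # fp16 hidden states (same element count)
    # The position_id buffer is only [bs, seq] per layer and is negligible here.

    pinned_gib = cpu_layers * (attn_out_bytes + hidden_bytes) / 2**30
    print(f"page-locked host memory for offloaded layers: ~{pinned_gib:.1f} GiB")
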
