use global_buffer to load/unload activation
qianhao0713 committed Aug 12, 2024
1 parent f4f9659 commit da7a8cc
Showing 1 changed file with 84 additions and 15 deletions.
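
In outline, the commit replaces the old per-layer global buffer helpers with a pycuda-backed GlobalBufferManager singleton: a fraction of each layer's flash-attention output stays resident on the GPU, while the rest is copied to pinned host memory on a device-to-host stream and prefetched back during the backward pass. Below is a minimal sketch of the call pattern introduced by the diff; the layer count, buffer shape, and dtype are illustrative, and a CUDA device with pycuda installed is assumed:

import torch
from llamafactory.easy_context.dist_flash_attn.monkey_patch import GlobalBufferManager

mgr = GlobalBufferManager()          # singleton shared by all layers
mgr.init(
    num_layers=32,                   # illustrative model depth
    offload_percent=0.25,            # fraction of layers whose outputs stay on the GPU
    shape=[1, 32, 4096, 128],        # [batch, num_heads, seq_len, head_dim], illustrative
    dtype=torch.float16,
    device=torch.device("cuda"),
)

# forward pass: stash each layer's attention output (GPU slot or pinned host memory)
#   mgr.save_flash_attn_out(layer_idx, out)
# backward pass (layers visited in reverse order): read it back, then release and prefetch
#   out = mgr.get_flash_attn_out(layer_idx)
#   mgr.free_flash_attn_out(layer_idx)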
99 changes: 84 additions & 15 deletions src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py
@@ -14,9 +14,12 @@

from .lightseq_async_attn import _lightseq_forward, _lightseq_backward
from .async_communication import initialize_distributed, reset_global_memory_buffer

import deepspeed as ds
from transformers.cache_utils import Cache
import pycuda
import pycuda.driver as drv
import pycuda.autoinit
import numpy as np
import time, os

# define a global buffer to save flash attention outputs
# it's called global because it saves the outputs for all layers
@@ -29,6 +32,56 @@
# hooks for the gradients of residual
global_hooks = []

class Singleton(object):
    _instance = None
    def __new__(class_, *args, **kwargs):
        # create the single shared instance lazily and reuse it on every later call
        if not isinstance(class_._instance, class_):
            class_._instance = object.__new__(class_)
        return class_._instance

class GlobalBufferManager(Singleton):

    def init(self, num_layers, offload_percent, shape, dtype, device):
        torch.cuda.empty_cache()
        if hasattr(self, 'initialized'):
            return
        self.layer_num = num_layers
        self.gpu_layer_num = int(num_layers * offload_percent)
        self.cpu_layer_num = num_layers - self.gpu_layer_num
        self.gpu_buffer = [torch.empty(shape, dtype=dtype, device=device) for _ in range(self.gpu_layer_num)]
        self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16)
        self.d2h_stream = drv.Stream()
        self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)]
        self.initialized = True

    def save_flash_attn_out(self, layer_idx, out):
        if layer_idx < 0:
            layer_idx = self.layer_num + layer_idx
        if layer_idx < self.cpu_layer_num:
            drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream)
        else:
            idx = layer_idx - self.cpu_layer_num
            drv.memcpy_dtod(self.gpu_buffer[idx].data_ptr(), out.data_ptr(), out.element_size() * out.nelement())

    def get_flash_attn_out(self, layer_idx):
        if layer_idx < 0:
            layer_idx = self.layer_num + layer_idx
        if layer_idx > self.cpu_layer_num:
            return self.gpu_buffer[layer_idx - self.cpu_layer_num]
        idx = self.gpu_layer_num - 1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num
        self.h2d_streams[idx].synchronize()
        return self.gpu_buffer[idx]

    def free_flash_attn_out(self, layer_idx):
        if layer_idx < 0:
            layer_idx = self.layer_num + layer_idx
        cpu_layer_idx = layer_idx - self.gpu_layer_num
        if cpu_layer_idx < 0:
            return
        idx = self.gpu_layer_num - 1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num
        self.gpu_buffer[idx].grad = None
        drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx])
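
# How the manager above moves data (as written): at save time, layers
# [0, cpu_layer_num) are copied asynchronously to pinned host memory on
# d2h_stream, while layers [cpu_layer_num, layer_num) stay resident in
# gpu_buffer. During the backward pass, free_flash_attn_out() reuses the
# just-freed GPU slot to prefetch the next offloaded layer on that slot's
# own h2d stream, and get_flash_attn_out() synchronizes that stream before
# handing the tensor back.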

def init_flash_attn_buffers(num_layers):
# update the global buffer according to number of layers
global global_flash_attn_out_buffer
@@ -131,13 +184,17 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):
        rng_state = None

        # save flash attention output to global buffer
        save_flash_attn_out_to_global_buffer(ctx.layer_idx, out)
        ds.runtime.utils.see_memory_usage(f"forward, layer={ctx.layer_idx}", force=True)
        # save_flash_attn_out_to_global_buffer(ctx.layer_idx, out)
        GlobalBufferManager().save_flash_attn_out(ctx.layer_idx, out)
        tensor_inputs += [softmax_lse]
        ctx.softmax_scale = softmax_scale

        ctx.save_for_backward(*tensor_inputs)

        tensor_inputs_ma = 0
        for ti in tensor_inputs:
            tensor_inputs_ma += ti.element_size() * ti.nelement()
        if int(os.environ['RANK']) == 0:
            print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}, tensor_inputs_ma: {tensor_inputs_ma/(1<<30):.2f}")
        return out, residual

    @staticmethod
@@ -157,7 +214,8 @@ def backward(ctx, *args):
        # Fill the flash attention output first
        if ctx.layer_idx > 0:
            # inputs[0] should be flash attention output
            inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1)
            # inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1)
            inputs[0] = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1)
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

@@ -188,7 +246,10 @@ def backward(ctx, *args):
        #dq = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device)
        #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device)
        out = get_flash_attn_out_from_global_buffer(ctx.layer_idx)
        # out = get_flash_attn_out_from_global_buffer(ctx.layer_idx)
        out = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx)
        if int(os.environ['RANK']) == 0:
            print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}")
        # todo get dout
        dout = args[0]

@@ -205,9 +266,10 @@ def backward(ctx, *args):

        # write flash attention output gradients to buffer
        if ctx.layer_idx > 0:
            write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad)
        free_flash_attn_out_buffer(ctx.layer_idx)
        ds.runtime.utils.see_memory_usage(f"backward, layer={ctx.layer_idx}", force=True)
            # write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad)
            GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad
        # free_flash_attn_out_buffer(ctx.layer_idx)
        GlobalBufferManager().free_flash_attn_out(ctx.layer_idx)
        return (None, None, None) + grads


@@ -266,7 +328,6 @@ def forward(ctx, run_function, preserve_rng_state, *args):

        with torch.no_grad():
            outputs = run_function(*args)
        ds.runtime.utils.see_memory_usage(f"forward, layer=last", force=True)
        return outputs

    @staticmethod
@@ -284,7 +345,8 @@ def backward(ctx, *args):
        # Fill in inputs with appropriate saved tensors.
        # Fill the flash attention output first
        # inputs[0] should be flash attention output
        inputs[0] = get_flash_attn_out_from_global_buffer(-1)
        # inputs[0] = get_flash_attn_out_from_global_buffer(-1)
        inputs[0] = GlobalBufferManager().get_flash_attn_out(-1)
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

@@ -324,7 +386,8 @@ def backward(ctx, *args):
                      for inp in detached_inputs)

        # write flash attention output gradients to buffer
        write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad)
        # write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad)
        GlobalBufferManager().get_flash_attn_out(-1).grad = detached_inputs[0].grad

        return (None, None) + grads

@@ -485,7 +548,14 @@ def forward(
    except:
        pass
    # initialize the global buffer
    init_flash_attn_buffers(len(self.layers))
    # init_flash_attn_buffers(len(self.layers))
    GlobalBufferManager().init(
        self.config.num_hidden_layers,
        offload_percent=0.25,
        shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads],
        dtype=hidden_states.dtype,
        device=hidden_states.device
    )

    if use_cache:
        try:
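    # Note: with offload_percent=0.25 above, int(num_hidden_layers * 0.25) layers keep
    # their attention outputs on the GPU and the remaining layers are staged through
    # pinned host memory; the buffer shape passed here is
    # [batch_size, num_attention_heads, seq_length, head_dim].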
@@ -654,7 +724,6 @@ def llama_model_forward(
    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output
    ds.runtime.utils.see_memory_usage(f"forward end", force=True)
    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
