diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 5606a23f78..d4080b657d 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -48,8 +48,15 @@ def init(self, num_layers, offload_percent, shape, dtype, device): self.layer_num = num_layers self.gpu_layer_num = int(num_layers * offload_percent) self.cpu_layer_num = num_layers - self.gpu_layer_num - self.gpu_buffer = [torch.empty(shape, dtype=dtype, device=device) for _ in range(self.gpu_layer_num)] + self.gpu_buffer = [None for _ in range(self.gpu_layer_num)] self.cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape, dtype=np.float16) + bs, num_heads, seq_len, emb_size = shape + shape_h = [bs, seq_len, num_heads * emb_size] + shape_a = [bs, seq_len] + self.hidden_state_gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.hidden_state_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_h, dtype=np.float16) + self.position_id_gpu_buffer = [None for _ in range(self.gpu_layer_num)] + self.position_id_cpu_buffer = drv.pagelocked_empty([self.cpu_layer_num] + shape_a, dtype=np.float16) self.d2h_stream = drv.Stream() self.h2d_streams = [drv.Stream() for _ in range(self.gpu_layer_num)] self.initialized = True @@ -61,26 +68,60 @@ def save_flash_attn_out(self, layer_idx, out): drv.memcpy_dtoh_async(self.cpu_buffer[layer_idx], out.data_ptr(), self.d2h_stream) else: idx = layer_idx - self.cpu_layer_num - drv.memcpy_dtod(self.gpu_buffer[idx].data_ptr(), out.data_ptr(), out.element_size() * out.nelement()) + self.gpu_buffer[idx] = out + + def save_hidden_states(self, layer_idx, *hs): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + hidden_state = hs[0] + position_id = hs[1] + if layer_idx < self.cpu_layer_num: + drv.memcpy_dtoh_async(self.hidden_state_cpu_buffer[layer_idx], hidden_state.data_ptr(), self.d2h_stream) + drv.memcpy_dtoh_async(self.position_id_cpu_buffer[layer_idx], position_id.data_ptr(), self.d2h_stream) + else: + idx = layer_idx - self.cpu_layer_num + self.hidden_state_gpu_buffer[idx] = hidden_state + self.position_id_gpu_buffer[idx] = position_id def get_flash_attn_out(self, layer_idx): if layer_idx < 0: layer_idx = self.layer_num + layer_idx - if layer_idx > self.cpu_layer_num: + if layer_idx >= self.cpu_layer_num: return self.gpu_buffer[layer_idx - self.cpu_layer_num] idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num self.h2d_streams[idx].synchronize() return self.gpu_buffer[idx] + + def get_hidden_states(self, layer_idx): + if layer_idx < 0: + layer_idx = self.layer_num + layer_idx + if layer_idx >= self.cpu_layer_num: + return self.hidden_state_gpu_buffer[layer_idx - self.cpu_layer_num], self.position_id_gpu_buffer[layer_idx - self.cpu_layer_num] + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num + self.h2d_streams[idx].synchronize() + return self.hidden_state_gpu_buffer[idx], self.position_id_gpu_buffer[idx] - def free_flash_attn_out(self, layer_idx): + def free_layer_gpu_buffer(self, layer_idx): if layer_idx < 0: layer_idx = self.layer_num + layer_idx + if layer_idx == self.layer_num - 1: + self.d2h_stream.synchronize() cpu_layer_idx = layer_idx - self.gpu_layer_num + if layer_idx >= self.cpu_layer_num: + idx = layer_idx - self.cpu_layer_num + else: + idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num + self.gpu_buffer[idx].grad = None if cpu_layer_idx < 0: + self.gpu_buffer[idx] = None + self.hidden_state_gpu_buffer[idx] = None + self.position_id_gpu_buffer[idx] = None return - idx = self.gpu_layer_num -1 - (self.cpu_layer_num - layer_idx) % self.gpu_layer_num - self.gpu_buffer[idx].grad = None drv.memcpy_htod_async(self.gpu_buffer[idx].data_ptr(), self.cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + drv.memcpy_htod_async(self.hidden_state_gpu_buffer[idx].data_ptr(), self.hidden_state_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + drv.memcpy_htod_async(self.position_id_gpu_buffer[idx].data_ptr(), self.position_id_cpu_buffer[cpu_layer_idx], self.h2d_streams[idx]) + +global_buffer = GlobalBufferManager() def init_flash_attn_buffers(num_layers): # update the global buffer according to number of layers @@ -161,16 +202,24 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): # Save non-tensor inputs in ctx, keep a placeholder None for tensors # to be filled out during the backward. ctx.inputs = [] - ctx.tensor_indices = [] + ctx.tensor_indices = {} tensor_inputs = [] + global global_buffer + hidden_state = None + position_ids = None for i, arg in enumerate(args): if i == 0 and ctx.layer_idx != 0: # flash attention output is saved to the global buffer during forward ctx.inputs.append(None) else: if torch.is_tensor(arg): - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) + # tensor_inputs.append(arg) + if len(arg.shape) == 3: + hidden_state = arg + ctx.tensor_indices[i] = 'hidden_state' + elif len(arg.shape) == 2: + position_ids = arg + ctx.tensor_indices[i] = 'position_ids' ctx.inputs.append(None) else: ctx.inputs.append(arg) @@ -185,16 +234,13 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): # save flash attention output to global buffer # save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) - GlobalBufferManager().save_flash_attn_out(ctx.layer_idx, out) - tensor_inputs += [softmax_lse] + + global_buffer.save_flash_attn_out(ctx.layer_idx, out) + global_buffer.save_hidden_states(ctx.layer_idx, hidden_state, position_ids) + # tensor_inputs += [softmax_lse] ctx.softmax_scale = softmax_scale - ctx.save_for_backward(*tensor_inputs) - tensor_inputs_ma = 0 - for ti in tensor_inputs: - tensor_inputs_ma += ti.element_size() * ti.nelement() - if int(os.environ['RANK']) == 0: - print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}, tensor_inputs_ma: {tensor_inputs_ma/(1<<30):2f}") + ctx.save_for_backward(softmax_lse) return out, residual @staticmethod @@ -207,18 +253,24 @@ def backward(ctx, *args): # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - tensors, softmax_lse = tensors[:-1], tensors[-1] - + # tensors = ctx.saved_tensors + softmax_lse = ctx.saved_tensors[0] + # tensors, softmax_lse = tensors[:-1], tensors[-1] + global global_buffer + hidden_state, position_ids = global_buffer.get_hidden_states(ctx.layer_idx) # Fill in inputs with appropriate saved tensors. # Fill the flash attention output first if ctx.layer_idx > 0: # inputs[0] should be flash attention output # inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) - inputs[0] = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1) - for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] - + inputs[0] = global_buffer.get_flash_attn_out(ctx.layer_idx-1) + # for i, idx in enumerate(tensor_indices): + # inputs[idx] = tensors[i] + for k, v in tensor_indices.items(): + if v == 'hidden_state': + inputs[k] = hidden_state + if v == 'position_ids': + inputs[k] = position_ids # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. @@ -247,9 +299,8 @@ def backward(ctx, *args): #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) # out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) - out = GlobalBufferManager().get_flash_attn_out(ctx.layer_idx) - if int(os.environ['RANK']) == 0: - print(f"layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated()/(1<<30):.2f}, MR: {torch.cuda.memory_reserved()/(1<<30):.2f}") + out = global_buffer.get_flash_attn_out(ctx.layer_idx) + # todo get dout dout = args[0] @@ -263,13 +314,12 @@ def backward(ctx, *args): grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) - # write flash attention output gradients to buffer if ctx.layer_idx > 0: # write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) - GlobalBufferManager().get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad + global_buffer.get_flash_attn_out(ctx.layer_idx-1).grad = detached_inputs[0].grad # free_flash_attn_out_buffer(ctx.layer_idx) - GlobalBufferManager().free_flash_attn_out(ctx.layer_idx) + global_buffer.free_layer_gpu_buffer(ctx.layer_idx) return (None, None, None) + grads @@ -341,12 +391,12 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - + global global_buffer # Fill in inputs with appropriate saved tensors. # Fill the flash attention output first # inputs[0] should be flash attention output # inputs[0] = get_flash_attn_out_from_global_buffer(-1) - inputs[0] = GlobalBufferManager().get_flash_attn_out(-1) + inputs[0] = global_buffer.get_flash_attn_out(-1) for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] @@ -384,10 +434,9 @@ def backward(ctx, *args): torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) - # write flash attention output gradients to buffer # write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) - GlobalBufferManager().get_flash_attn_out(-1).grad = detached_inputs[0].grad + global_buffer.get_flash_attn_out(-1).grad = detached_inputs[0].grad return (None, None) + grads @@ -549,7 +598,8 @@ def forward( pass # initialize the global buffer # init_flash_attn_buffers(len(self.layers)) - GlobalBufferManager().init( + global global_buffer + global_buffer.init( self.config.num_hidden_layers, offload_percent=0.25, shape=[batch_size, self.config.num_attention_heads, seq_length, self.config.hidden_size // self.config.num_attention_heads],