
correct block costs and flops
haeggee committed Aug 5, 2024
1 parent 57c58b7 commit 3967bee
Showing 1 changed file with 114 additions and 7 deletions.
src/nanotron/models/gpt3_moe.py — 121 changes: 114 additions & 7 deletions
@@ -207,9 +207,6 @@ def forward(
        return fp32_sharded_logits, hidden_encoder_states["aux_losses"]


class GPT3MoEForTraining(GPT3ForTraining):
    def __init__(
        self,
@@ -258,17 +255,127 @@ def forward(
            loss[key] = value
        return loss

    def get_block_compute_costs(self):
        """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
        model_config = self.config
        d_ff = model_config.n_inner if model_config.n_inner is not None else 4 * model_config.hidden_size
        d_qkv = model_config.hidden_size // model_config.num_attention_heads
        # active experts + routing
        mlp_cost = (
            2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok
            + model_config.hidden_size * model_config.moe_num_experts
        )
        att_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
        block_compute_costs = {
            # CausalSelfAttention (qkv proj + attn out) + MLP
            GPTBlock: att_cost + mlp_cost,
            # This is the last lm_head
            TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
        }
        return block_compute_costs
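
    # Sanity check of the block cost above with illustrative values (these are
    # not nanotron defaults): hidden_size=1024, 16 heads, n_inner=4096,
    # moe_num_experts=8, num_experts_per_tok=2.
    #   d_qkv    = 1024 // 16 = 64
    #   att_cost = 4 * 16 * 64 * 1024             = 4_194_304
    #   mlp_cost = 2 * 4096 * 1024 * 2 + 1024 * 8 = 16_785_408
    #   GPTBlock: att_cost + mlp_cost             = 20_979_712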

    def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
        """Get model and hardware TFLOPs per second per device for a given model."""
        world_size = self.parallel_context.world_pg.size()
        model_flops, hardware_flops = get_flops(
            num_layers=self.config.num_hidden_layers,
            hidden_size=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            vocab_size=self.config.vocab_size,
            ffn_hidden_size=self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size,
            seq_len=sequence_length,
            batch_size=global_batch_size,
            kv_channels=None,
            glu_activation=False,
            num_experts=self.config.moe_num_experts,
            num_experts_per_tok=self.config.num_experts_per_tok,
        )
        model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
        hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
        return model_flops_per_s, hardware_flops_per_s
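
    # Hedged usage sketch (the `model` instance and the peak throughput are
    # assumptions, not part of this file): the per-device model TFLOPs/s this
    # method returns can be turned into MFU by dividing by the accelerator's
    # peak, e.g. ~312 TFLOPs/s for dense BF16 on an A100:
    #   model_tflops, hw_tflops = model.get_flops_per_sec(
    #       iteration_time_in_sec=3.2, sequence_length=2048, global_batch_size=512
    #   )
    #   mfu = model_tflops / 312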


def get_flops(
    num_layers,
    hidden_size,
    num_heads,
    vocab_size,
    seq_len,
    kv_channels=None,
    ffn_hidden_size=None,
    batch_size=1,
    glu_activation=False,
    num_experts=1,
    num_experts_per_tok=1,
):
"""Counts flops in an decoder-only model
Args:
num_layers: number of decoder layers
hidden_size: hidden size of the model
num_heads: number of heads in the model
kv_channels: hidden size of the key and value heads
ffn_hidden_size: hidden size of the FFN
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info.
num_experts_per_tok: number of experts per token in the MoE layer
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
"""

    if kv_channels is None:
        assert hidden_size % num_heads == 0
        kv_channels = hidden_size // num_heads
    if ffn_hidden_size is None:
        ffn_hidden_size = 4 * hidden_size

    # In the following we mark the reduced dimension with parentheses
    # decoder
    # self attention (MQA)
    ## q projection
    decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels
    ## kv projection, shared across heads
    decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels
    ## qk logits
    decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len
    ### SWA (sliding window attention / local attention) variant:
    # window_size = 4096
    # decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size
    ## v logits
    decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels
    # decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels
    ## attn out
    decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size
    # FF
    ## 1st layer
    decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
    if glu_activation:
        # 3 matmuls instead of 2 in FFN
        # ref. https://arxiv.org/pdf/2002.05202.pdf
        # Used for example in T5 v1.1
        decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
    ## 2nd layer
    decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
    # MoE router: every token scores all experts
    decoder_ffn_router_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_experts

    decoder_flops_fwd = (
        decoder_q_proj_flops_fwd
        + decoder_kv_proj_flops_fwd
        + decoder_qk_logits_flops_fwd
        + decoder_v_logits_flops_fwd
        + decoder_attn_out_flops_fwd
        # FFN flops scale with the number of experts activated per token
        + decoder_ffn_1_flops_fwd * num_experts_per_tok
        + decoder_ffn_2_flops_fwd * num_experts_per_tok
        + decoder_ffn_router_flops_fwd
    )

    # lm head
    lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size

    # the bwd pass requires double the flops of the fwd pass, since matmul gradients are
    # computed with respect to both the input and the weight tensors
    model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd)  # 1x fwd + 2x bwd

    hardware_flops = model_flops  # TODO @nouamanetazi: This is a placeholder for now
    return model_flops, hardware_flops
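

# Minimal, self-contained sanity check of get_flops. The numbers are
# illustrative (roughly GPT-3-small shaped, with a hypothetical 8-expert,
# top-2 MoE); they are not taken from any nanotron config.
if __name__ == "__main__":
    model_flops, hardware_flops = get_flops(
        num_layers=12,
        hidden_size=768,
        num_heads=12,
        vocab_size=50257,
        seq_len=2048,
        batch_size=256,
        num_experts=8,
        num_experts_per_tok=2,
    )
    # model_flops already includes the 3x factor for fwd + bwd
    print(f"model TFLOPs per step: {model_flops / 1e12:.2f}")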
