Refactor turbomind attention by precomputing rotary embed #2801

Open · wants to merge 17 commits into base: main
10 changes: 5 additions & 5 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -368,15 +368,15 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)

// compute rope scaling factor
if (r->start_flag) {
seq.rope_theta = model_->attn_param_.rotary_embedding_base;
if (model_->attn_param_.use_dynamic_ntk) {
auto scaling_factor = model_->attn_param_.rope_scaling_factor;
seq.rope_theta = model_->attn_param_.rope.base;
if (model_->attn_param_.rope.type == RotaryScalingType::kDynamic) {
auto scaling_factor = model_->attn_param_.rope.factor;
if (scaling_factor >= 1.f) { // infer by current context length
auto max_seq_len = state.h_context_length[idx];
auto max_pos_emb = model_->attn_param_.max_position_embeddings;
auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings;
if (max_seq_len > max_pos_emb) {
scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
float rope_dim = model_->attn_param_.rotary_embedding_dim;
float rope_dim = model_->attn_param_.rope.dim;
seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
(long)seq.id,
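For reference, the dynamic-NTK branch above grows the effective RoPE base with the current context length. The following standalone sketch repeats the same arithmetic outside the engine; the parameter values are illustrative, not taken from this PR.

#include <cmath>
#include <cstdio>

// Illustrative sketch of the dynamic-NTK base adjustment performed in
// ProcessInferRequests above; all values are made up for demonstration.
int main()
{
    float rope_theta     = 10000.f;  // attn_param_.rope.base
    float scaling_factor = 4.f;      // attn_param_.rope.factor
    int   max_seq_len    = 16384;    // current context length
    int   max_pos_emb    = 8192;     // attn_param_.rope.max_position_embeddings
    float rope_dim       = 128.f;    // attn_param_.rope.dim

    if (max_seq_len > max_pos_emb) {
        scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
        rope_theta *= std::pow(scaling_factor, rope_dim / (rope_dim - 2.f));
    }
    std::printf("adjusted rope_theta = %f\n", rope_theta);
    return 0;
}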
53 changes: 38 additions & 15 deletions src/turbomind/models/llama/llama_params.h
@@ -59,22 +59,45 @@ struct MoeParam {
std::vector<int> expert_num;
};

enum class RotaryScalingType
{
kDefault,
kLinear,
kDynamic,
kYarn,
kLlama3,
};

struct YarnRopeParam {
float attention_factor;
float beta_fast;
float beta_slow;
};

struct Llama3RopeParam {
float low_freq_factor;
float high_freq_factor;
int original_max_position_embeddings;
};

struct AttentionParam {
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
float softmax_scale;
std::string rope_scaling_type;
int original_max_position_embeddings;
float rope_scaling_factor;
float low_freq_factor;
float high_freq_factor;
float attention_factor;
float beta_fast;
float beta_slow;
bool use_dynamic_ntk;
bool use_logn_attn;
int cache_block_seq_len;
float softmax_scale;
int cache_block_seq_len;
bool use_logn_attn;
// rope
struct {
// common
RotaryScalingType type;
int dim;
float base;
float factor;
int max_position_embeddings;
// special
union {
YarnRopeParam yarn;
Llama3RopeParam llama3;
};
} rope;
};

struct EngineParam {
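As orientation for the reworked layout, here is a minimal sketch of filling the new rope block for a YaRN-scaled model. It assumes the definitions above are visible via llama_params.h and live in the turbomind namespace like the rest of this PR; the numbers are purely illustrative. Since yarn and llama3 share a union, only the member matching rope.type should be written.

#include "src/turbomind/models/llama/llama_params.h"

// Illustrative only: populating the refactored AttentionParam::rope for a
// YaRN configuration. The values are examples, not defaults from this PR.
turbomind::AttentionParam make_yarn_attn_param()
{
    turbomind::AttentionParam p{};
    p.rope.type                    = turbomind::RotaryScalingType::kYarn;
    p.rope.dim                     = 128;
    p.rope.base                    = 10000.f;
    p.rope.factor                  = 4.f;
    p.rope.max_position_embeddings = 32768;
    // write only the union member that matches `type`
    p.rope.yarn.attention_factor = 1.f;
    p.rope.yarn.beta_fast        = 32.f;
    p.rope.yarn.beta_slow        = 1.f;
    return p;
}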
99 changes: 53 additions & 46 deletions src/turbomind/models/llama/rotary_emb.cu
@@ -114,8 +114,7 @@ RotaryScalingType GetRoPEType(const std::string& type)
{"linear", RotaryScalingType::kLinear},
{"dynamic", RotaryScalingType::kDynamic},
{"yarn", RotaryScalingType::kYarn},
{"llama3", RotaryScalingType::kLlama3},
{"mrope", RotaryScalingType::kMrope}};
{"llama3", RotaryScalingType::kLlama3}};
return lookup.at(type);
}

@@ -132,42 +131,52 @@ void RotaryEmbeddingV2::allocateBuffer(size_t token_num)
RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator):
stream_(stream), allocator_(allocator)
{
type_ = GetRoPEType(param.rope_scaling_type);
dim_ = param.rotary_embedding_dim;
rope_scaling_factor_ = 1.0f;
attention_factor_ = 1.0f;
type_ = param.rope.type;
dim_ = param.rope.dim;

if (type_ == RotaryScalingType::kLinear) {
rope_scaling_factor_ /= param.rope_scaling_factor;
}
else if (type_ == RotaryScalingType::kLlama3) {
const double PI = 3.14159265358979323846;
float inv_diff_freq_factor = 1.0 / (param.high_freq_factor - param.low_freq_factor);
llama3_inv_scaling_factor_ = 1.0 / param.rope_scaling_factor;
llama3_alpha_ = param.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
llama3_beta_ = param.low_freq_factor * inv_diff_freq_factor;
}
else if (type_ == RotaryScalingType::kYarn) {
const double PI = 3.14159265358979323846;
auto find_correction_dim = [&](float num_rotations) {
return (param.rotary_embedding_dim * std::log(param.max_position_embeddings / (num_rotations * 2 * PI)))
/ (2 * std::log(param.rotary_embedding_base));
};
auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
low = std::floor(find_correction_dim(low_rot));
high = std::ceil(find_correction_dim(high_rot));
low = std::max(low, 0.f);
high = std::min(high, param.rotary_embedding_dim - 1.f);
};
float low, high;
find_correction_range(param.beta_fast, param.beta_slow, low, high);
if (low == high) {
high += 0.01f;
switch (type_) {
case RotaryScalingType::kDefault:
break;
case RotaryScalingType::kLinear:
inv_factor_ = 1.0f / param.rope.factor;
break;
case RotaryScalingType::kDynamic:
inv_factor_ = param.rope.factor;
break;
case RotaryScalingType::kYarn: {
const double PI = 3.14159265358979323846;
auto find_correction_dim = [&](float num_rotations) {
return (param.rope.dim * std::log(param.rope.max_position_embeddings / (num_rotations * 2 * PI)))
/ (2 * std::log(param.rope.base));
};
auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
low = std::floor(find_correction_dim(low_rot));
high = std::ceil(find_correction_dim(high_rot));
low = std::max(low, 0.f);
high = std::min(high, param.rope.dim - 1.f);
};
float low, high;
find_correction_range(param.rope.yarn.beta_fast, param.rope.yarn.beta_slow, low, high);
if (low == high) {
high += 0.01f;
}
yarn_.yarn_ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0;
yarn_.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
yarn_.yarn_inv_scaling_factor = (1 - 1.0 / param.rope.factor);
yarn_.attention_factor = param.rope.yarn.attention_factor;
break;
}
case RotaryScalingType::kLlama3: {
const double PI = 3.14159265358979323846;
float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor);
llama3_.llama3_inv_scaling_factor = 1.0 / param.rope.factor;
llama3_.llama3_alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
llama3_.llama3_beta = param.rope.llama3.low_freq_factor * inv_diff_freq_factor;
break;
}
yarn_ramp_inv_factor_div_2_ = 1.0 / (high - low) / 2.0;
yarn_ramp_inv_factor_mul_min_ = 1.0 / (high - low) * low;
yarn_inv_scaling_factor_ = (1 - 1.0 / param.rope_scaling_factor);
attention_factor_ = param.attention_factor;
default:
FT_CHECK(0);
break;
}
}

@@ -188,7 +197,7 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
rope_scaling_factor_,
inv_factor_,
cos_sin_);
break;
case RotaryScalingType::kLlama3:
@@ -198,9 +207,9 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
llama3_inv_scaling_factor_,
llama3_alpha_,
llama3_beta_,
llama3_.llama3_inv_scaling_factor,
llama3_.llama3_alpha,
llama3_.llama3_beta,
cos_sin_);
break;
case RotaryScalingType::kYarn:
@@ -210,14 +219,12 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
yarn_ramp_inv_factor_div_2_,
yarn_ramp_inv_factor_mul_min_,
yarn_inv_scaling_factor_,
attention_factor_,
yarn_.yarn_ramp_inv_factor_div_2,
yarn_.yarn_ramp_inv_factor_mul_min,
yarn_.yarn_inv_scaling_factor,
yarn_.attention_factor,
cos_sin_);
break;
case RotaryScalingType::kMrope:
FT_CHECK(0);
default:
FT_CHECK(0);
}
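For readers unfamiliar with YaRN, the correction range computed in the constructor above picks the band of rotary dimensions over which YaRN blends between the scaled and unscaled frequencies. The standalone sketch below repeats that computation with illustrative parameters, not values from this PR.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Standalone sketch of the YaRN correction-range computation in
// RotaryEmbeddingV2's constructor; parameters are illustrative.
int main()
{
    const double PI        = 3.14159265358979323846;
    const float  dim       = 128.f;    // param.rope.dim
    const float  max_pos   = 32768.f;  // param.rope.max_position_embeddings
    const float  base      = 10000.f;  // param.rope.base
    const float  beta_fast = 32.f;     // param.rope.yarn.beta_fast
    const float  beta_slow = 1.f;      // param.rope.yarn.beta_slow

    auto find_correction_dim = [&](float num_rotations) {
        return (dim * std::log(max_pos / (num_rotations * 2 * PI))) / (2 * std::log(base));
    };

    float low  = std::floor(find_correction_dim(beta_fast));
    float high = std::ceil(find_correction_dim(beta_slow));
    low        = std::max(low, 0.f);
    high       = std::min(high, dim - 1.f);
    if (low == high) {
        high += 0.01f;
    }
    std::printf("correction range: [%f, %f]\n", low, high);
    return 0;
}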
49 changes: 23 additions & 26 deletions src/turbomind/models/llama/rotary_emb.h
@@ -5,15 +5,7 @@

namespace turbomind {

enum class RotaryScalingType
{
kDefault,
kLinear,
kDynamic,
kYarn,
kLlama3,
kMrope
};
RotaryScalingType GetRoPEType(const std::string& type);

struct RotaryEmbeddingV2Params {
float* rope_theta;
@@ -23,6 +15,19 @@ struct RotaryEmbeddingV2Params {
int token_num;
};

struct InnerYarnRopeParam {
float attention_factor;
float yarn_ramp_inv_factor_div_2;
float yarn_ramp_inv_factor_mul_min;
float yarn_inv_scaling_factor;
};

struct InnerLlama3RopeParam {
float llama3_inv_scaling_factor;
float llama3_alpha;
float llama3_beta;
};

struct RotaryEmbeddingV2 {

RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator);
@@ -38,28 +43,20 @@ struct RotaryEmbeddingV2 {

void forward(const RotaryEmbeddingV2Params& params);

RotaryScalingType type_;
cudaStream_t const stream_;
IAllocator* const allocator_;

int dim_;
RotaryScalingType type_;
float inv_factor_{1.0};

union {
InnerYarnRopeParam yarn_;
InnerLlama3RopeParam llama3_;
};

// output
float* cos_sin_; // num_token x dim, (cos, sin, ...)

int dim_;
// default, linear, dynamic
float attention_factor_;
float rope_scaling_factor_;
float inv_scale_factor_;
// llama3
float llama3_inv_scaling_factor_;
float llama3_alpha_;
float llama3_beta_;
// yarn
float yarn_ramp_inv_factor_div_2_;
float yarn_ramp_inv_factor_mul_min_;
float yarn_inv_scaling_factor_;
// mrope
int3 mrope_section_;
};

}; // namespace turbomind
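The union above mirrors the one added to AttentionParam: only the member matching type_ is meaningful. As a rough worked example of the values that end up in InnerLlama3RopeParam, the sketch below repeats the precomputation from the rotary_emb.cu constructor with illustrative llama3-style settings, not values from this PR.

#include <cstdio>

// Illustrative recomputation of the llama3 rope constants derived in
// RotaryEmbeddingV2's constructor; the inputs are example values only.
int main()
{
    const double PI           = 3.14159265358979323846;
    const float  factor       = 8.f;   // param.rope.factor
    const float  low_freq     = 1.f;   // param.rope.llama3.low_freq_factor
    const float  high_freq    = 4.f;   // param.rope.llama3.high_freq_factor
    const int    original_max = 8192;  // param.rope.llama3.original_max_position_embeddings

    const float inv_diff_freq_factor = 1.f / (high_freq - low_freq);
    const float inv_scaling_factor   = 1.f / factor;                         // llama3_inv_scaling_factor
    const float alpha = original_max / (2 * PI) * inv_diff_freq_factor;      // llama3_alpha
    const float beta  = low_freq * inv_diff_freq_factor;                     // llama3_beta

    std::printf("inv_scaling=%f alpha=%f beta=%f\n", inv_scaling_factor, alpha, beta);
    return 0;
}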
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/unified_attention_layer.cc
@@ -313,8 +313,8 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
}

// rope
params.rotary_embedding_dim = param_.rotary_embedding_dim;
params.max_position_embeddings = param_.max_position_embeddings;
params.rotary_embedding_dim = param_.rope.dim;
params.max_position_embeddings = param_.rope.max_position_embeddings;
params.cos_sin = cos_sin;
params.use_logn_attn = param_.use_logn_attn;

6 changes: 1 addition & 5 deletions src/turbomind/models/llama/unified_decoder.cc
@@ -88,11 +88,7 @@ void UnifiedDecoder<T>::forwardSelfAttn(T* attn_io,
inputs.insert("cu_k_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, cu_k_len_});
inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_});
inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_});

if (rotary_emb_) {
inputs.insert("cos_sin",
{MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});
}
inputs.insert("cos_sin", {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});

TensorMap outputs(*_outputs);
outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io});
41 changes: 22 additions & 19 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -140,10 +140,10 @@ void LlamaTritonModel<T>::handleMissingParams()
(int)model_param_.vocab_size);
}

if (!attn_param_.max_position_embeddings) {
attn_param_.max_position_embeddings = 2048;
if (!attn_param_.rope.max_position_embeddings) {
attn_param_.rope.max_position_embeddings = 2048;
TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.",
(int)attn_param_.max_position_embeddings);
(int)attn_param_.rope.max_position_embeddings);
}

if (!engine_param_.max_batch_size) {
Expand All @@ -153,7 +153,7 @@ void LlamaTritonModel<T>::handleMissingParams()
}

if (!engine_param_.session_len) {
engine_param_.session_len = attn_param_.max_position_embeddings;
engine_param_.session_len = attn_param_.rope.max_position_embeddings;
TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)engine_param_.session_len);
}

@@ -277,22 +277,25 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
model_param_.attn_bias = model_reader["attn_bias"].as<int>(0);
model_param_.group_size = model_reader["group_size"].as<int>(0);

attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);
// rotary embedding parameters
attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as<int>();
attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as<float>(10000.0f);
attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
attn_param_.attention_factor = attention_reader["attention_factor"].as<float>(-1.f);
attn_param_.beta_fast = attention_reader["beta_fast"].as<float>(32.f);
attn_param_.beta_slow = attention_reader["beta_slow"].as<float>(1.f);
attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as<std::string>("");
attn_param_.rope_scaling_factor = attention_reader["rope_scaling_factor"].as<float>(0.f);
attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as<float>(1.0);
attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as<int>(0);
attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);

attn_param_.original_max_position_embeddings = attention_reader["original_max_position_embeddings"].as<int>(0);
attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as<std::string>(""));
attn_param_.rope.dim = attention_reader["rotary_embedding"].as<int>();
attn_param_.rope.base = attention_reader["rope_theta"].as<float>(10000.0f);
attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
attn_param_.rope.factor = attention_reader["rope_scaling_factor"].as<float>(0.f);
if (attn_param_.rope.type == RotaryScalingType::kYarn) {
attn_param_.rope.yarn.attention_factor = attention_reader["attention_factor"].as<float>(-1.f);
attn_param_.rope.yarn.beta_fast = attention_reader["beta_fast"].as<float>(32.f);
attn_param_.rope.yarn.beta_slow = attention_reader["beta_slow"].as<float>(1.f);
}
else if (attn_param_.rope.type == RotaryScalingType::kLlama3) {
attn_param_.rope.llama3.low_freq_factor = attention_reader["low_freq_factor"].as<float>(1.0);
attn_param_.rope.llama3.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
attn_param_.rope.llama3.original_max_position_embeddings =
attention_reader["original_max_position_embeddings"].as<int>(0);
}

engine_param_.max_batch_size = engine_reader["max_batch_size"].as<int>(0);
engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as<int>(0);