[ Speculative decoding ] Split assistant generation and speculative d…
iefode authored Oct 28, 2024
1 parent c30c4fe commit f617099
Showing 6 changed files with 53 additions and 24 deletions.
samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp

@@ -12,7 +12,7 @@ int main(int argc, char* argv[]) try {

     ov::genai::GenerationConfig config;
     config.max_new_tokens = 100;
-    // Speculative decoding generation parameters are mutually excluded
+    // Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually exclusive
     // add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration
     config.num_assistant_tokens = 5;
     // add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold`
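For context, the surrounding sample wires these parameters into a pipeline roughly as follows. This is a minimal sketch assuming the OpenVINO GenAI `LLMPipeline` constructor that accepts the `ov::genai::draft_model` property; the model paths and prompt are placeholders, not taken from this commit:

```cpp
#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // placeholder paths to an exported main model and a smaller draft model
    std::string main_model_path = "path/to/main_model";
    std::string draft_model_path = "path/to/draft_model";

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    // fixed-candidate mode: the draft model proposes 5 tokens per iteration
    config.num_assistant_tokens = 5;

    // attaching a draft model to the pipeline enables speculative decoding
    ov::genai::LLMPipeline pipe(main_model_path, "CPU",
                                ov::genai::draft_model(draft_model_path, "CPU"));
    std::string result = pipe.generate("Why is the sky blue?", config);
    std::cout << result << std::endl;
    return 0;
}
```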
samples/python/speculative_decoding_lm/speculative_decoding_lm.py

@@ -36,7 +36,11 @@ def main():

     config = openvino_genai.GenerationConfig()
     config.max_new_tokens = 100
+    # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually exclusive
+    # add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration
     config.num_assistant_tokens = 5
+    # add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold`
+    # config.assistant_confidence_threshold = 0.4

     # Since the streamer is set, the results will be printed
     # every time a new token is generated and put into the streamer queue.
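The commented-out `config.assistant_confidence_threshold = 0.4` line above is the second, mutually exclusive mode. Sketched in C++ to match the core API (the 0.4 value simply mirrors the sample's commented-out default), enabling it instead of `num_assistant_tokens` would look like:

```cpp
#include "openvino/genai/llm_pipeline.hpp"

ov::genai::GenerationConfig make_threshold_config() {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    // dynamic mode: the draft model keeps proposing candidates for as long as
    // its own probability for the next candidate stays above the threshold
    config.assistant_confidence_threshold = 0.4f;
    // `num_assistant_tokens` is deliberately left unset; setting both parameters
    // at once is invalid because the two modes are mutually exclusive
    return config;
}
```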
2 changes: 1 addition & 1 deletion src/cpp/src/generation_config.cpp
@@ -120,7 +120,7 @@ bool GenerationConfig::is_multinomial() const {
 }

 bool GenerationConfig::is_speculative_decoding() const {
-    return assistant_confidence_threshold > 0 || num_assistant_tokens > 0;
+    return (assistant_confidence_threshold > 0 || num_assistant_tokens > 0);
 }

 void GenerationConfig::validate() const {
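`validate()` appears above only as trailing context, but the sample comments in this commit state that the two trigger parameters are mutually exclusive, which implies a check along these lines. A hypothetical sketch, not the body from this commit:

```cpp
void GenerationConfig::validate() const {
    // hypothetical: reject configs that set both speculative-decoding triggers at once
    if (is_speculative_decoding()) {
        OPENVINO_ASSERT(num_assistant_tokens == 0 || assistant_confidence_threshold == 0,
                        "`num_assistant_tokens` and `assistant_confidence_threshold` are mutually exclusive");
    }
    // ... remaining validation omitted
}
```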
65 changes: 44 additions & 21 deletions src/cpp/src/sampler.cpp
@@ -652,26 +652,48 @@ align_all_sequence_len(SequenceGroup::Ptr& sequence_group,
     logit_processor.update_generated_len(min_generated_tokens);
 }

-bool
-validate_candidate(Sequence::Ptr running_sequence,
-                   size_t& token_idx,
-                   Token& sampled_token,
-                   bool& is_extend_sequence,
-                   size_t& max_removed_tokens) {
-    if (token_idx > 0) {
-        const auto& generated_tokens = running_sequence->get_generated_ids();
-        auto it = generated_tokens.rbegin();
-        std::advance(it, token_idx - 1);
-        // to validate candidates from assisting model and remove incorrect ones from generated sequence
-        if (*it != sampled_token.m_index) {
-            running_sequence->remove_last_tokens(token_idx);
-            max_removed_tokens = std::max(max_removed_tokens, token_idx);
-            is_extend_sequence = true;
-            return false;
-        } else {
-            sampled_token.m_index = *it;
-        }
-    }
+bool Sampler::validate_candidate(
+    Sequence::Ptr running_sequence,
+    size_t& token_idx,
+    Token& sampled_token,
+    bool& is_extend_sequence,
+    size_t& max_removed_tokens,
+    bool do_sample) {
+    OPENVINO_ASSERT(token_idx > 0);
+    const auto& generated_tokens = running_sequence->get_generated_ids();
+    auto it_token_id = generated_tokens.rbegin();
+    std::advance(it_token_id, token_idx - 1);
+
+    bool is_candidate_accepted = false;
+    // the first tokens in speculative decoding should be generated by the main model
+    if (do_sample &&
+        running_sequence->get_generated_len() != running_sequence->get_sequence_group_ptr()->get_num_tokens_to_validate()) {
+        const auto& generated_log_probs = running_sequence->get_generated_log_probs();
+        auto it_log_prob = generated_log_probs.rbegin();
+        std::advance(it_log_prob, token_idx - 1);
+
+        float p_i = std::exp(*it_log_prob),
+              q_i = std::exp(sampled_token.m_log_prob),
+              probability_ratio = p_i / q_i;
+
+        auto dist = std::uniform_int_distribution<>(0, 100); // equivalent to multinomial with number of trials == 1
+        float r_i = dist(rng_engine);
+        r_i /= 100;
+        is_candidate_accepted = r_i <= probability_ratio;
+    } else {
+        is_candidate_accepted = *it_token_id == sampled_token.m_index;
+    }
+
+    // validate candidates from the assisting model and remove incorrect ones from the generated sequence
+    if (!is_candidate_accepted) {
+        running_sequence->remove_last_tokens(token_idx);
+        max_removed_tokens = std::max(max_removed_tokens, token_idx);
+        is_extend_sequence = true;
+        return false;
+    } else {
+        sampled_token.m_index = *it_token_id;
+    }
+
     return true;
 }
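The `do_sample` branch above is a speculative-sampling acceptance test: drawing r uniformly and accepting the candidate when r <= p_i / q_i accepts it with probability min(1, p_i / q_i), where p_i comes from the candidate's stored log prob and q_i from the token the main model just sampled. The draw uses `rng_engine`, presumably a `Sampler` member, which would explain why `validate_candidate` becomes a member function in this commit. Note that `uniform_int_distribution<>(0, 100)` divided by 100 quantizes r to 1% steps. A standalone sketch of the same test with a continuous r, as an illustration under that reading rather than the committed code:

```cpp
#include <cmath>
#include <random>

// Accept a candidate with probability min(1, p / q), where both probabilities
// arrive as log probs; exp(a - b) == exp(a) / exp(b) and is numerically safer.
bool accept_candidate(float candidate_log_prob, float sampled_log_prob, std::mt19937& rng) {
    float probability_ratio = std::exp(candidate_log_prob - sampled_log_prob);
    std::uniform_real_distribution<float> r(0.0f, 1.0f); // continuous, no 1% grid
    return r(rng) <= probability_ratio;
}
```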
@@ -759,8 +781,9 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             // flag to add sampled token to generated sequence or extend logit processors only
             bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens,
                  is_validation_passed = true;
-            if (is_validation_mode_enabled && !is_generate_n_tokens) {
-                is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token_id, is_extend_sequence, max_removed_tokens_per_request);
+            if (is_validation_mode_enabled && !is_extend_sequence) {
+                is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token_id,
+                                                          is_extend_sequence, max_removed_tokens_per_request, sampling_params.do_sample);
                 // update log probs only during the validation process
                 if (!is_extend_sequence) {
                     OPENVINO_ASSERT(generated_and_verified_len < running_sequences[running_sequence_id]->get_generated_len());
3 changes: 3 additions & 0 deletions src/cpp/src/sampler.hpp
@@ -48,6 +48,9 @@ class Sampler {
     std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence);
     std::vector<int64_t> _try_finish_generation(SequenceGroup::Ptr & sequence_group);

+    bool validate_candidate(Sequence::Ptr running_sequence, size_t& token_idx, Token& sampled_token,
+                            bool& is_extend_sequence, size_t& max_removed_tokens, bool do_sample);
+
     // request ID => beam search tracking information
     std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;
1 change: 0 additions & 1 deletion src/python/openvino_genai/__init__.py
@@ -40,5 +40,4 @@
     Generator,
     CppStdGenerator,
     draft_model
-
 )