Skip to content

Commit

Permalink
Speculative decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Aug 30, 2024
1 parent 36cb20a commit a538e24
Show file tree
Hide file tree
Showing 13 changed files with 899 additions and 33 deletions.
4 changes: 3 additions & 1 deletion samples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ add_subdirectory(cpp/beam_search_causal_lm)
add_subdirectory(cpp/chat_sample)
add_subdirectory(cpp/continuous_batching_accuracy)
add_subdirectory(cpp/continuous_batching_benchmark)
# add_subdirectory(cpp/continuous_batching_speculative_decoding)
# todo: iefode
# add_subdirectory(cpp/continuous_batching_prompt_lookup)
add_subdirectory(cpp/continuous_batching_speculative_decoding)
add_subdirectory(cpp/greedy_causal_lm)
add_subdirectory(cpp/multinomial_causal_lm)
add_subdirectory(cpp/prompt_lookup_decoding_lm)
Expand Down
25 changes: 25 additions & 0 deletions samples/cpp/continuous_batching_prompt_lookup/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# start of dependencies

include(FetchContent)

FetchContent_Declare(cxxopts
URL https://github.com/jarro2783/cxxopts/archive/refs/tags/v3.1.1.tar.gz
URL_HASH SHA256=523175f792eb0ff04f9e653c90746c12655f10cb70f1d5e6d6d9491420298a08)

FetchContent_Declare(nlohmann_json
URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
URL_HASH SHA256=0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406)

FetchContent_MakeAvailable(cxxopts)
FetchContent_MakeAvailable(nlohmann_json)

find_package(OpenVINO REQUIRED COMPONENTS Runtime)

# end of dependencies

set(TARGET_NAME continuous_batching_prompt_lookup)
add_executable(${TARGET_NAME} ${TARGET_NAME}.cpp "prompt_lookup_pipeline.hpp" "prompt_lookup_pipeline.cpp")
target_link_libraries(${TARGET_NAME} PRIVATE openvino::genai cxxopts::cxxopts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <openvino/openvino.hpp>
#include <cxxopts.hpp>

#include "openvino/genai/generation_config.hpp"

#include "prompt_lookup_pipeline.hpp"

void print_generation_result(const ov::genai::GenerationResult& generation_result) {
for (size_t output_id = 0; output_id < generation_result.m_generation_ids.size(); ++output_id) {
std::cout << "Answer " << output_id << " (" << generation_result.m_scores[output_id] << ") : " << generation_result.m_generation_ids[output_id] << std::endl;
}
}

int main(int argc, char* argv[]) try {
// Command line options

cxxopts::Options options("accuracy_sample", "Help command");

options.add_options()
("n,num_prompts", "A number of prompts", cxxopts::value<size_t>()->default_value("1"))
("dynamic_split_fuse", "Whether to use dynamic split-fuse or vLLM scheduling", cxxopts::value<bool>()->default_value("false"))
("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
("k,candidates_number", "candidates_number", cxxopts::value<size_t>()->default_value("5"))
("ngram", "Ngram", cxxopts::value<size_t>()->default_value("5"))
("g,generated_len", "generated_len", cxxopts::value<size_t>()->default_value("30"))
("h,help", "Print usage");

cxxopts::ParseResult result;
try {
result = options.parse(argc, argv);
} catch (const cxxopts::exceptions::exception& e) {
std::cout << e.what() << "\n\n";
std::cout << options.help() << std::endl;
return EXIT_FAILURE;
}

if (result.count("help")) {
std::cout << options.help() << std::endl;
return EXIT_SUCCESS;
}

const size_t num_prompts = result["num_prompts"].as<size_t>();
const bool dynamic_split_fuse = result["dynamic_split_fuse"].as<bool>();
const std::string models_path = result["model"].as<std::string>();
const size_t k = result["candidates_number"].as<size_t>();
const size_t g = result["generated_len"].as<size_t>();
const size_t n = result["ngram"].as<size_t>();

// create dataset

std::vector<std::string> prompt_examples = {
// "What is OpenVINO?",
// "How are you?",
"code: ```for (const auto& a : b) { std::cout << a << std::endl; }```",
"Tell me something about Canada",
"What is OpenVINO?",
};

auto greedy = ov::genai::greedy();
greedy.max_new_tokens = g;

std::vector<ov::genai::GenerationConfig> sampling_params_examples {
// ov::genai::beam_search(),
greedy,
// ov::genai::multinomial(),
};

std::vector<std::string> prompts(num_prompts);
std::vector<ov::genai::GenerationConfig> sampling_params(num_prompts);

for (size_t request_id = 0; request_id < num_prompts; ++request_id) {
prompts[request_id] = prompt_examples[request_id % prompt_examples.size()];
sampling_params[request_id] = sampling_params_examples[request_id % sampling_params_examples.size()];
}

// Perform the inference

ov::genai::SchedulerConfig scheduler_config;
// batch size
scheduler_config.max_num_batched_tokens = 256;
// cache params
scheduler_config.num_kv_blocks = 364;
scheduler_config.block_size = 32;
// mode - vLLM or dynamic_split_fuse
scheduler_config.dynamic_split_fuse = dynamic_split_fuse;
// vLLM specific params
scheduler_config.max_num_seqs = 2;

// It's possible to construct a Tokenizer from a different path.
// If the Tokenizer isn't specified, it's loaded from the same folder.
PromptLookupPipeline pipe(models_path, k, n, ov::genai::Tokenizer{models_path}, scheduler_config, "CPU");
auto start_time = std::chrono::system_clock::now();
std::vector<ov::genai::GenerationResult> generation_results = pipe.generate(prompts, sampling_params);

for (size_t request_id = 0; request_id < generation_results.size(); ++request_id) {
const ov::genai::GenerationResult & generation_result = generation_results[request_id];
std::cout << "Question: " << prompts[request_id] << std::endl;
switch (generation_result.m_status)
{
case ov::genai::GenerationStatus::FINISHED:
print_generation_result(generation_result);
break;
case ov::genai::GenerationStatus::IGNORED:
std::cout << "Request was ignored due to lack of memory." <<std::endl;
if (generation_result.m_generation_ids.size() > 0) {
std::cout << "Partial result:" << std::endl;
print_generation_result(generation_result);
}
break;
case ov::genai::GenerationStatus::DROPPED_BY_PIPELINE:
std::cout << "Request was aborted." <<std::endl;
if (generation_result.m_generation_ids.size() > 0) {
std::cout << "Partial result:" << std::endl;
print_generation_result(generation_result);
}
break;
default:
break;
}
std::cout << std::endl;
}
auto end_time = std::chrono::system_clock::now();
std::chrono::duration<double> duration = end_time - start_time;
std::cout << std::endl;
std::cout << "Duration: " << duration.count() << std::endl;
std::cout << "Infer number: " << pipe.infer_cnt << std::endl;
std::cout << "MAX matches number: " << pipe.max_matches << std::endl;
std::cout << "AVG matches number: " << (float(pipe.avg_matches) / pipe.infer_cnt) << std::endl;
} catch (const std::exception& error) {
std::cerr << error.what() << '\n';
return EXIT_FAILURE;
} catch (...) {
std::cerr << "Non-exception object thrown\n";
return EXIT_FAILURE;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "prompt_lookup_pipeline.hpp"

PromptLookupPipeline::PromptLookupPipeline(const std::string& models_path,
size_t candidates_number,
size_t ngram_size,
const ov::genai::SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config) {
ov::genai::Tokenizer tokenizer(models_path);
PromptLookupPipeline(models_path, candidates_number, max_ngram_size, tokenizer, scheduler_config, device, plugin_config);
};

PromptLookupPipeline::PromptLookupPipeline(const std::string& models_path,
size_t candidates_number,
size_t ngram_size,
const ov::genai::Tokenizer& tokenizer,
const ov::genai::SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config) {
m_tokenizer = tokenizer;
set_k(candidates_number);
max_ngram_size = ngram_size;

model_pipeline = ov::genai::ContinuousBatchingPipeline(models_path, m_tokenizer, scheduler_config, device, plugin_config);
model_pipeline.enable_validation_mode();
}

ov::genai::PipelineMetrics PromptLookupPipeline::get_metrics() const {
return model_pipeline.get_metrics();
}

void PromptLookupPipeline::step() {
std::cout << "=======STEP==================" << std::endl;
bool is_updated = false;
if (is_speculative_mode) {
// predict tokens using prompt
std::cout << "num_candidates: " << candidates_number << std::endl;
for (const auto& whole_input : model_pipeline.get_prompts_with_generated_tokens()) {
auto updated_input = whole_input;
const auto& input_ids = whole_input.token_ids;
const size_t input_length = input_ids.size();
for (int32_t ngram_size = max_ngram_size; ngram_size > 0; ngram_size--) {
std::vector<int64_t> ngram = std::vector<int64_t>{input_ids.cend() - ngram_size, input_ids.cend()};
std::cout << "ngram: " << std::endl;
for (const auto& a : ngram) {
std::cout << a;
}
std::cout << std::endl;

// find ngram match in input_ids
size_t ngram_i = 0;
for (size_t input_i = 0; input_i < input_length - ngram_size; input_i++) {
if (ngram[ngram_i] != input_ids[input_i]) {
ngram_i = 0;
continue;
}
ngram_i++;

if (ngram_i < ngram_size) {
continue;
}

// match found with the end at input_i
size_t avaliable_num_pred = std::min(input_length - (input_i + 1), candidates_number);

// return candidates with length of avaliable_num_pred
std::vector<int64_t> candidate{input_ids.cbegin() + input_i + 1,
input_ids.cbegin() + input_i + 1 + avaliable_num_pred};
updated_input.token_ids = candidate;
updated_input.log_probs = std::vector<float>(candidate.size(), 0);

model_pipeline.update_generated_sequence(updated_input);
break;
}
if (whole_input.token_ids != updated_input.token_ids) {
is_updated = true;
break;
}
}
}

// put candidates to model cache
auto candidate_sequences = model_pipeline.get_generated_sequences();
// todo: remove debug code
for (const auto& s : candidate_sequences) {
std::cout << "ASSISTANT: ";
for (const auto& d : s.token_ids) {
std::cout << d << " ";
}
// std::cout << std::endl;
// for (const auto& d : s.log_probs) {
// std::cout << d << " ";
// }
std::cout << std::endl;
std::cout << decode(s.token_ids) << std::endl;
}
}

const auto gen_seq_before = model_pipeline.get_generated_sequences();

// validate candidates and generate 1 new token
model_pipeline.step();

if (is_speculative_mode && is_updated) {
// todo: remove debug code
for (const auto& s : model_pipeline.get_generated_sequences()) {
std::cout << "MODEL: ";
for (const auto& d : s.token_ids) {
std::cout << d << " ";
}
// std::cout << std::endl;
// for (const auto& d : s.log_probs) {
// std::cout << d << " ";
// }
std::cout << std::endl;
std::cout << decode(s.token_ids) << std::endl;
std::cout << std::endl;
}

// todo: iefode: remove debug prints
for (const auto& gen_seq_after : model_pipeline.get_generated_sequences()) {
const auto& candidate_seq = gen_seq_before[gen_seq_after.request_id];
size_t before_len = candidate_seq.token_ids.size(),
after_len = gen_seq_after.token_ids.size();
size_t dist = is_updated ? (after_len <= before_len ? (before_len - after_len) : candidates_number) : 0;
update_strategy(dist);
}
// ov::genai::ContinuousBatchingPipeline::UpdateSeqResult update_result;
// for (const auto& checked_sequence : checked_sequences) {
// update_result = assisting_pipeline.update_generated_sequence(checked_sequence);
// }

// OPENVINO_ASSERT(candidates_number >= update_result.to_remove);
// if (update_result.to_remove) {
// std::cout << "to_remove: " << update_result.to_remove << std::endl;
// }
// update_strategy(candidates_number - update_result.to_remove);
// std::cout << "=========================" << std::endl;
}
}

void PromptLookupPipeline::update_strategy(size_t num_matches) {
std::cout << "num_matches: " << num_matches << std::endl;
max_matches = std::max(max_matches, num_matches);
avg_matches += num_matches;
if (max_candidates_number == 0) {
return;
}
if (num_matches == candidates_number) {
candidates_number = std::min(candidates_number + 2, max_candidates_number);
} else {
candidates_number = std::max(int64_t(candidates_number) - 1, int64_t(1));
}
}


void PromptLookupPipeline::set_k(size_t new_default_k) {
candidates_number = new_default_k;
max_candidates_number = new_default_k * 2;
is_speculative_mode = candidates_number > 0;
}

bool PromptLookupPipeline::has_non_finished_requests() {
return model_pipeline.has_non_finished_requests();
}


std::vector<ov::genai::GenerationHandle>
PromptLookupPipeline::generate_sequences(
const std::vector<ov::Tensor> prompts,
std::vector<ov::genai::GenerationConfig> sampling_params) {
OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
OPENVINO_ASSERT(prompts.size() == sampling_params.size());

std::vector<ov::genai::GenerationHandle> generations, assisting_generations;
for (size_t request_id = 0; request_id < prompts.size(); ++request_id) {
generations.push_back(model_pipeline.add_request(request_id, prompts[request_id], sampling_params[request_id]));
}

while (has_non_finished_requests()) {
step();
infer_cnt++;
}

return generations;
}
Loading

0 comments on commit a538e24

Please sign in to comment.