Skip to content

Commit

Permalink
Merge branch 'master' into cb-by-default
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov authored Jan 4, 2025
2 parents 2f99472 + b4d0d3c commit 11f9714
Show file tree
Hide file tree
Showing 29 changed files with 709 additions and 771 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/mac.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: macOS (12, Python 3.9)
name: macOS (12, Python 3.10)
on:
workflow_dispatch:
pull_request:
Expand All @@ -16,7 +16,7 @@ concurrency:
cancel-in-progress: true

env:
PYTHON_VERSION: '3.9'
PYTHON_VERSION: '3.10'
OV_BRANCH: master
OV_TARBALL: ''

Expand Down
2 changes: 1 addition & 1 deletion samples/deployment-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
--extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
openvino_genai~=2025.0.0.0.dev
librosa==0.10.2.post1 # For Whisper
pillow==11.0.0 # Image processing for VLMs
pillow==11.1.0 # Image processing for VLMs
14 changes: 7 additions & 7 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
utils::apply_gather_before_matmul_transformation(model);

initialize_pipeline(model, scheduler_config, properties, device_config, core);
}
Expand Down Expand Up @@ -444,7 +445,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
const float * logits_data = logits.data<float>();
ov::Shape logits_shape = logits.get_shape();
OPENVINO_ASSERT(logits_shape.size() == 3);
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
size_t vocab_size = logits_shape[2];
for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
// requests not scheduled, in decoding phase or not echoing are not processed
Expand All @@ -454,26 +455,25 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

size_t num_running_sequences = sequence_group->num_running_seqs();
OPENVINO_ASSERT(num_running_sequences == 1);
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens();
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
size_t output_seq_len = sequence_group->get_output_seq_len();

const float * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;

size_t num_prompt_tokens_processed = sequence_group->get_num_processed_tokens();
OPENVINO_ASSERT(num_prompt_tokens_processed + actual_seq_len <= sequence_group->get_prompt_len());
OPENVINO_ASSERT(num_prompt_tokens_processed + output_seq_len <= sequence_group->get_prompt_len());

// if we processed the whole prompt we don't include last logprob as it will be processed by the sampler (it's already completion)
// otherwise we include it as it will be used in the next part of the prompt
int exclude_last_logprob = 1;
if (num_prompt_tokens_processed + actual_seq_len < sequence_group->get_prompt_len())
if (num_prompt_tokens_processed + output_seq_len < sequence_group->get_prompt_len())
exclude_last_logprob = 0;

// if we start processing the prompt we add "fake" log prob for the first position (begin of sequence)
if (num_prompt_tokens_processed == 0)
sequence_group->append_prompt_log_prob(1.0);

for (int token_logits_offset = 0, token_id_offset = num_prompt_tokens_processed + 1;
token_logits_offset < actual_seq_len - exclude_last_logprob;
token_logits_offset < output_seq_len - exclude_last_logprob;
token_logits_offset++, token_id_offset++) {

const float* token_logits = (sequence_group_logits_data + token_logits_offset * vocab_size);
Expand All @@ -498,7 +498,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

sequence_group->append_prompt_log_prob(token_logit - max_value - log_sum);
}
currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
currently_processed_tokens += output_seq_len * num_running_sequences;
// For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
sequence_group->notify_handle_echo_only();
Expand Down
14 changes: 7 additions & 7 deletions src/cpp/src/generation_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ void GenerationConfig::validate() const {
OPENVINO_ASSERT(temperature > 0, "When 'do_sample' is true, temperature must be a strictly positive float, but got ", temperature);
} else {
// parameters requiring multinomial
OPENVINO_ASSERT(top_k == std::numeric_limits<size_t>::max(), "When 'do_sample' is false, top_k must be max of size_t, but got ", top_k);
OPENVINO_ASSERT(top_p == 1.0f, "When 'do_sample' is false, top_p must be 1.0f, but got ", top_p);
OPENVINO_ASSERT(temperature == 1.0f, "When 'do_sample' is false, temperature must be a 1.0f, but got ", temperature);
// OPENVINO_ASSERT(top_k == std::numeric_limits<size_t>::max(), "When 'do_sample' is false, top_k must be max of size_t, but got ", top_k);
// OPENVINO_ASSERT(top_p == 1.0f, "When 'do_sample' is false, top_p must be 1.0f, but got ", top_p);
// OPENVINO_ASSERT(temperature == 1.0f, "When 'do_sample' is false, temperature must be a 1.0f, but got ", temperature);
}

if (is_beam_search()) {
Expand All @@ -252,10 +252,10 @@ void GenerationConfig::validate() const {
}
} else {
// parameters requiring beam search
OPENVINO_ASSERT(num_beam_groups == 1, "'num_beam_groups' is supported by beam search only and should be 1 otherwise, but got ", num_beam_groups);
OPENVINO_ASSERT(no_repeat_ngram_size == std::numeric_limits<size_t>::max(), "'no_repeat_ngram_size' is supported only by beam search, otherwise should be set to max of size_t, but got ", no_repeat_ngram_size);
OPENVINO_ASSERT(diversity_penalty == 0.0f, "'diversity_penalty' is set to ", diversity_penalty, " (default is 0.0f), which is supported only by beam search sampling");
OPENVINO_ASSERT(length_penalty == 1.0f, "'length_penalty' is set to ", length_penalty, " (default is 1.0f), which is supported only by beam search sampling");
// OPENVINO_ASSERT(num_beam_groups == 1, "'num_beam_groups' is supported by beam search only and should be 1 otherwise, but got ", num_beam_groups);
// OPENVINO_ASSERT(no_repeat_ngram_size == std::numeric_limits<size_t>::max(), "'no_repeat_ngram_size' is supported only by beam search, otherwise should be set to max of size_t, but got ", no_repeat_ngram_size);
// OPENVINO_ASSERT(diversity_penalty == 0.0f, "'diversity_penalty' is set to ", diversity_penalty, " (default is 0.0f), which is supported only by beam search sampling");
// OPENVINO_ASSERT(length_penalty == 1.0f, "'length_penalty' is set to ", length_penalty, " (default is 1.0f), which is supported only by beam search sampling");
}

// assistant generation
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/llm_pipeline_stateful.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config)
: LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
utils::slice_matmul_stateful_model(model);
utils::apply_slice_before_matmul_transformation(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

ov::CompiledModel compiled_model;
Expand Down
96 changes: 67 additions & 29 deletions src/cpp/src/llm_pipeline_static.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "llm_pipeline_static.hpp"

#include "sampler.hpp"

#include <fstream>
#include <regex>

Expand Down Expand Up @@ -235,12 +237,12 @@ enum class GenerateHint {

std::string to_string(GenerateHint h) {
switch(h) {
case GenerateHint::FAST_COMPILE :
case GenerateHint::FAST_COMPILE :
return "FAST_COMPILE";
case GenerateHint::BEST_PERF :
case GenerateHint::BEST_PERF :
return "BEST_PERF";
default:
OPENVINO_THROW("Unsupported value for type GenerateHint provided");
OPENVINO_THROW("Unsupported value for type GenerateHint provided");
}
}

Expand Down Expand Up @@ -632,6 +634,19 @@ void copy_columns_by_row_chunks(const ov::Tensor& src, ov::Tensor& dst) {
}
}

void stream_generated_tokens(std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
ov::genai::GenerationHandle& handle) {
if (streamer_ptr && handle->can_read()) {
std::unordered_map<uint64_t, ov::genai::GenerationOutput> token = handle->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
handle->drop();
break;
}
}
}
}

} // anonymous namespace

namespace ov {
Expand All @@ -643,7 +658,8 @@ StaticLLMPipeline::StaticLLMPipeline(
const std::string& device,
const ov::AnyMap& config
) : LLMPipelineImplBase(tokenizer,
utils::from_config_json_if_exists(models_path)) {
utils::from_config_json_if_exists(models_path)),
m_sampler(m_tokenizer) {
auto properties = config;
/* NB: Static LLM pipeline consists of two models,
first to process the input prompt (prefill),
Expand Down Expand Up @@ -672,6 +688,8 @@ StaticLLMPipeline::StaticLLMPipeline(
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}

m_sampler.set_seed(m_generation_config.rng_seed);
};

StaticLLMPipeline::StaticLLMPipeline(
Expand All @@ -688,8 +706,7 @@ StaticLLMPipeline::StaticLLMPipeline(
const std::string& device,
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config
) : LLMPipelineImplBase(tokenizer, generation_config) {

) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
bool use_blobs = false;
auto anyopt = get_option<bool>(properties, "USE_BLOBS");
if (anyopt.has_value()) {
Expand All @@ -708,6 +725,8 @@ StaticLLMPipeline::StaticLLMPipeline(
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}

m_sampler.set_seed(m_generation_config.rng_seed);
}

void StaticLLMPipeline::setupAndCompileModels(
Expand Down Expand Up @@ -955,7 +974,10 @@ EncodedResults StaticLLMPipeline::generate(
attention_mask = data->attention_mask;
}

if (input_ids.get_shape().at(0) > 1u) {
ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];

if (batch_size > 1u) {
OPENVINO_THROW("Currently only batch size=1 is supported");
}

Expand All @@ -974,12 +996,14 @@ EncodedResults StaticLLMPipeline::generate(
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

if (!config.is_greedy_decoding()) {
OPENVINO_THROW("Currently only greedy decoding is supported");
if (!config.is_greedy_decoding() && !config.is_multinomial()) {
OPENVINO_THROW("Currently only greedy and multinomial decoding are supported");
}

if (config.num_return_sequences != 1u) {
OPENVINO_THROW("Currently only \"num_return_sequences\" equal to 1 is supported!");
}

ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];
ov::genai::EncodedResults results;
auto& raw_perf_counters = results.perf_metrics.raw_metrics;
// NB: Only batch=1 is supported now
Expand Down Expand Up @@ -1016,11 +1040,21 @@ EncodedResults StaticLLMPipeline::generate(

// NB: Now there are prompt_len tokens in KV-cache
m_kvcache_desc.num_stored_tokens += static_cast<uint32_t>(prompt_len);
int64_t last_token = utils::argmax(m_prefill_request.get_tensor("logits"), 0);
results.tokens[0].push_back(last_token);
if (streamer_ptr && streamer_ptr->put(last_token)) {
return results;
}

auto logits = m_prefill_request.get_tensor("logits");
int64_t output_sequence_len = logits.get_shape().at(1);

auto sequence_group = std::make_shared<SequenceGroup>(
0 /* request_id */, padded_input_ids, config, 1 /* block_size */);
sequence_group->update_processed_tokens_num(m_kvcache_desc.max_prompt_size - output_sequence_len);
sequence_group->schedule_tokens(output_sequence_len);

// NB: Controls what tokens are ready to be pushed into the streamer
GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());

SamplerOutput sampler_output = m_sampler.sample({sequence_group}, logits);
stream_generated_tokens(streamer_ptr, handle);

// Outputs: logits, ...
const auto kStartOutputKVCacheLayers = 1u;
Expand Down Expand Up @@ -1061,30 +1095,28 @@ EncodedResults StaticLLMPipeline::generate(
std::fill(attention_mask_data, attention_mask_data + m_kvcache_desc.num_stored_tokens - 1u, 1u);
attention_mask_data[m_kvcache_desc.total_size - 1] = 1u;

const size_t max_tokens = config.get_max_new_tokens(prompt_len);
for (int i = 0; i < max_tokens - 1; ++i) {
input_ids_data[0] = last_token;
while (sequence_group->is_running()) {
sequence_group->schedule_tokens(1);
const auto running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1u);

input_ids_data[0] = running_sequences.front()->get_generated_ids().back();
position_ids_data[0] = m_kvcache_desc.num_stored_tokens;
attention_mask_data[m_kvcache_desc.num_stored_tokens - 1] = 1u;

m_kvcache_request.infer();
m_kvcache_desc.num_stored_tokens += 1;

last_token = utils::argmax(m_kvcache_request.get_tensor("logits"), 0);
results.tokens[0].push_back(last_token);

raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
if (streamer_ptr && streamer_ptr->put(last_token)) {
break;
}

if (last_token == config.eos_token_id && !config.ignore_eos) {
break;
}
SamplerOutput sampler_output = m_sampler.sample(
{sequence_group}, m_kvcache_request.get_tensor("logits"));
stream_generated_tokens(streamer_ptr, handle);

// NB: KV-cache is full, further generation is impossible
if (m_kvcache_desc.num_stored_tokens == m_kvcache_desc.total_size) {
sequence_group->set_out_of_memory();
break;
}

Expand All @@ -1108,6 +1140,12 @@ EncodedResults StaticLLMPipeline::generate(
streamer_ptr->end();
}

OPENVINO_ASSERT(sequence_group->get_finished_sequences().size() == 1u);
auto sequence = sequence_group->get_finished_sequences().front();
results.tokens[0] = sequence->get_generated_ids();
results.scores[0] = sequence->get_cumulative_log_prob();
m_sampler.clear_request_info(sequence_group->get_request_id());

auto stop_time = std::chrono::steady_clock::now();
// If is called without tokenization then that stat will not be reported.
auto& metrics = results.perf_metrics;
Expand Down
5 changes: 4 additions & 1 deletion src/cpp/src/llm_pipeline_static.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <filesystem>

#include "llm_pipeline_base.hpp"
#include "sampler.hpp"

namespace ov {
namespace genai {
Expand Down Expand Up @@ -77,6 +78,8 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
bool v_tensors_transposed;
};

Sampler m_sampler;

KVCacheDesc m_kvcache_desc;
ov::InferRequest m_kvcache_request;
ov::InferRequest m_prefill_request;
Expand Down
Loading

0 comments on commit 11f9714

Please sign in to comment.