
[CB] Simplify SequenceGroup API (#1456)
- Removed `enable_prefix_caching` parameter from `SequenceGroup` ctor
- Removed the necessity to call `set_sequence_group_ptr` after creating a
sequence group
- Renamed `get_cumulative_log_probs` to `get_cumulative_log_prob`, as it
returns a single floating-point value
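
Taken together, these changes reduce a typical call site to one constructor call with no follow-up wiring. The sketch below is illustrative only: the helper name `make_sequence_group` and the include paths are assumptions made for the example, while the constructor and accessor signatures are the ones visible in this diff.

#include <cstdint>
#include <memory>

#include "sequence_group.hpp"                        // internal header touched by this commit
#include "openvino/genai/generation_config.hpp"

// Hypothetical helper showing the post-#1456 call site.
ov::genai::SequenceGroup::Ptr make_sequence_group(uint64_t request_id,
                                                  const ov::Tensor& input_ids,
                                                  const ov::genai::GenerationConfig& sampling_params,
                                                  std::size_t block_size) {
    // Before this change the constructor also took enable_prefix_caching, and the caller
    // then had to run: sequence_group->set_sequence_group_ptr(sequence_group);
    // Now four arguments are enough, and sequences are wired to their group internally.
    auto sequence_group = std::make_shared<ov::genai::SequenceGroup>(
        request_id, input_ids, sampling_params, block_size);

    // The per-sequence scoring accessor is now singular, since it returns one float:
    //     float score = sequence->get_cumulative_log_prob();
    return sequence_group;
}
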
ilya-lavrenov authored Dec 30, 2024
1 parent 4be813e commit 71dc893
Showing 11 changed files with 129 additions and 150 deletions.
6 changes: 2 additions & 4 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -105,9 +105,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request

SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids,
sampling_params,
m_scheduler->get_block_size(),
m_scheduler->get_config().enable_prefix_caching);
sequence_group->set_sequence_group_ptr(sequence_group);
m_scheduler->get_block_size());

if (m_scheduler->get_config().enable_prefix_caching) {
m_scheduler->restore_cached_blocks(sequence_group);
@@ -353,7 +351,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs();
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();
const auto & generated_ids = sequence->get_generated_ids();

if (sampling_params.echo)
6 changes: 2 additions & 4 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -300,23 +300,21 @@ EncodedResults StatefulLLMPipeline::generate(

std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);

sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size);
}

sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
}

11 changes: 7 additions & 4 deletions src/cpp/src/lm_encoding.cpp
@@ -119,10 +119,13 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(

auto logits = m_llm.get_tensor("logits");

int64_t sequence_len = logits.get_shape().at(1);
// since we have applied a `Slice` operation to the last MatMul, the model output sequence length is 1
// so, we need to update sequence groups as if they had already processed all prompt tokens except the last ones
// and schedule only `output_sequence_len` of them
int64_t output_sequence_len = logits.get_shape().at(1);
for (auto& sequence_group : sequence_groups) {
sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - sequence_len);
sequence_group->schedule_tokens(sequence_len);
sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - output_sequence_len);
sequence_group->schedule_tokens(output_sequence_len);
}

std::map<size_t, size_t> beam_offets;
@@ -217,7 +220,7 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(

for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) {
const auto & sequence = sequences[seq_id];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs();
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();

results.tokens.push_back(sequence->get_generated_ids());
results.scores.push_back(score);
89 changes: 41 additions & 48 deletions src/cpp/src/sequence_group.hpp
@@ -4,9 +4,11 @@
#pragma once

#include <vector>
#include <cassert>
#include <set>
#include <cstdlib>
#include <string_view>
#include <memory>

#include "openvino/genai/generation_handle.hpp"
#include "openvino/genai/generation_config.hpp"
@@ -40,32 +42,32 @@ class Sequence {
GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE;
float m_cumulative_log_prob = 0.0f;
std::vector<int64_t> m_prefix_hashes;
std::weak_ptr<SequenceGroup> m_sequence_group;
SequenceGroup* m_sequence_group = nullptr;
static std::mutex m_counter_mutex;

size_t _make_hash(size_t content_length);
public:
using Ptr = std::shared_ptr<Sequence>;
using CPtr = std::shared_ptr<const Sequence>;

// don't use directly
Sequence(const uint64_t id) : m_grouped_id(id) {};
explicit Sequence(const uint64_t id) : m_grouped_id(id) {}

// don't use directly
Sequence(const Sequence& seq, const uint64_t id) :
m_generated_ids(seq.m_generated_ids),
m_grouped_id(id),
m_status(seq.m_status),
m_cumulative_log_prob(seq.m_cumulative_log_prob){
m_cumulative_log_prob(seq.m_cumulative_log_prob),
m_sequence_group(seq.m_sequence_group) {
OPENVINO_ASSERT(seq.m_id != m_id);
}

public:
using Ptr = std::shared_ptr<Sequence>;
using CPtr = std::shared_ptr<const Sequence>;

static Sequence::Ptr create(const uint64_t id) {
return std::make_shared<Sequence>(id);
return Sequence::Ptr(new Sequence(id));
}

static Sequence::Ptr fork(Sequence::CPtr sequence, const uint64_t id) {
return std::make_shared<Sequence>(*sequence, id);
return Sequence::Ptr(new Sequence(*sequence, id));
}

bool operator ==(const Sequence& other) const {
@@ -130,7 +132,7 @@ class Sequence {
GenerationOutput output;
if (token_cnt > 0) {
OPENVINO_ASSERT(m_generated_ids.size());
output.score = get_cumulative_log_probs();
output.score = get_cumulative_log_prob();

auto generated_token_id = get_generated_ids();
auto generated_log_probs = get_generated_log_probs();
@@ -163,7 +165,7 @@ class Sequence {
return m_generated_log_probs;
}

float get_cumulative_log_probs() const {
float get_cumulative_log_prob() const {
return m_cumulative_log_prob;
}

@@ -173,20 +175,18 @@
}

float get_beam_search_score(const ov::genai::GenerationConfig& sampling_params) const {
float cumulative_log_prob = get_cumulative_log_probs(), current_length = get_generated_len();
float cumulative_log_prob = get_cumulative_log_prob(), current_length = get_generated_len();
float score = cumulative_log_prob / std::pow(current_length, sampling_params.length_penalty);
return score;
}

// Each KV block can be uniquely identified by
void set_sequence_group_ptr(std::shared_ptr<SequenceGroup> sequence_group) {
void set_sequence_group_ptr(SequenceGroup* sequence_group) {
assert(sequence_group != nullptr);
m_sequence_group = sequence_group;
}

std::shared_ptr<SequenceGroup> get_sequence_group_ptr() const {
OPENVINO_ASSERT(!m_sequence_group.expired());
return m_sequence_group.lock();
}
std::shared_ptr<SequenceGroup> get_sequence_group_ptr() const;

// Each KV block can be uniquely identified by
// the tokens within the block and the tokens in the prefix before the block.
@@ -198,15 +198,14 @@
// - each sequence shares the same prompt and KV-caches for prompt
// - in case of beam search each sequence also shares specific part of generic phase
// via reference counter mechanism on BlockManager level
class SequenceGroup {
class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
uint64_t m_request_id;
std::vector<Sequence::Ptr> m_sequences;
ov::genai::GenerationConfig m_sampling_params;
std::size_t m_block_size;
TokenIds m_prompt_ids;
std::vector<float> m_prompt_log_probs;
GenerationStream::Ptr m_generation_stream;
bool m_enable_prefix_caching;
size_t m_num_evicted_tokens = 0;
bool m_has_echoed = false;

@@ -226,33 +225,32 @@

size_t m_num_streamed_tokens = 0, m_stream_window_size = 0;


SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size)
: m_request_id(request_id),
m_sampling_params(sampling_params),
m_block_size(block_size),
m_enable_prefix_caching(enable_prefix_caching) {
m_generation_stream = GenerationStream::create();
}
m_generation_stream(GenerationStream::create()) { }

public:
using Ptr = std::shared_ptr<SequenceGroup>;
using CPtr = std::shared_ptr<const SequenceGroup>;

SequenceGroup(uint64_t request_id, const TokenIds& input_ids, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
: SequenceGroup(request_id, ov::Tensor(ov::element::i64, ov::Shape{input_ids.size()}, (void *)input_ids.data()), sampling_params, block_size, enable_prefix_caching) {
SequenceGroup(uint64_t request_id, const TokenIds& input_ids, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size)
: SequenceGroup(request_id, ov::Tensor(ov::element::i64, ov::Shape{input_ids.size()}, (void *)input_ids.data()), sampling_params, block_size) {
}

SequenceGroup(uint64_t request_id, const ov::Tensor input_ids, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
: SequenceGroup(request_id, sampling_params, block_size, enable_prefix_caching) {
add_sequence(Sequence::create(m_next_sequence_id++));

SequenceGroup(uint64_t request_id, const ov::Tensor input_ids, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size)
: SequenceGroup(request_id, sampling_params, block_size) {
m_prompt_ids.resize(input_ids.get_size());
std::copy_n(input_ids.data<int64_t>(), input_ids.get_size(), m_prompt_ids.begin());
m_prompt_log_probs.reserve(m_prompt_ids.size());

// create a single sequence
add_sequence(Sequence::create(m_next_sequence_id++));
}

void add_sequence(const Sequence::Ptr & sequence) {
sequence->set_sequence_group_ptr(this);
m_sequences.emplace_back(sequence);
}

@@ -322,7 +320,6 @@ class SequenceGroup {
return it != m_sequences.end();
}


/**
* @param seq_id Sequence identifier
* @return Pointer to the sequence with this ID.
@@ -344,8 +341,8 @@ class SequenceGroup {

std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) -> bool {
bool is_beam_search = m_sampling_params.is_beam_search();
const float score_1 = is_beam_search ? s1->get_beam_search_score(m_sampling_params) : s1->get_cumulative_log_probs();
const float score_2 = is_beam_search ? s2->get_beam_search_score(m_sampling_params) : s2->get_cumulative_log_probs();
const float score_1 = is_beam_search ? s1->get_beam_search_score(m_sampling_params) : s1->get_cumulative_log_prob();
const float score_2 = is_beam_search ? s2->get_beam_search_score(m_sampling_params) : s2->get_cumulative_log_prob();
return score_1 > score_2;
});

@@ -409,7 +406,6 @@ class SequenceGroup {
m_num_evicted_tokens += num_evicted_tokens;
}


/**
* Resets the eviction tracking on this sequence to the state prior to any eviction taking place.
*/
@@ -434,7 +430,6 @@
return get_num_processed_tokens() + get_num_scheduled_tokens();
}


bool requires_sampling() const {
return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.max_new_tokens > 0;
}
@@ -513,7 +508,6 @@ class SequenceGroup {
return (get_context_len() - get_num_evicted_tokens() + m_block_size - 1) / m_block_size;
}


// requires number of physical blocks for next generation
size_t get_num_blocks() const {
return get_num_logical_blocks();
@@ -524,10 +518,9 @@
}

Sequence::Ptr fork_sequence(Sequence::CPtr sequence) {
auto ptr = sequence->get_sequence_group_ptr();
m_sequences.emplace_back(Sequence::fork(std::move(sequence), m_next_sequence_id++));
set_sequence_group_ptr(ptr);
return m_sequences.back();
auto forked_sequence = Sequence::fork(sequence, m_next_sequence_id++);
m_sequences.emplace_back(forked_sequence);
return forked_sequence;
}

const ov::genai::GenerationConfig& get_sampling_parameters() const {
@@ -568,12 +561,6 @@ class SequenceGroup {
return m_is_gen_paused;
}

void set_sequence_group_ptr(std::shared_ptr<SequenceGroup> sequence_group) {
for (auto sequence: m_sequences) {
sequence->set_sequence_group_ptr(sequence_group);
}
}

GenerationStream::Ptr get_generation_stream() {
return m_generation_stream;
}
@@ -600,7 +587,7 @@
output.generated_ids.insert(output.generated_ids.begin(), m_prompt_ids.begin(), m_prompt_ids.end());
output.generated_log_probs.insert(output.generated_log_probs.begin(), m_prompt_log_probs.begin(), m_prompt_log_probs.end());
}
output.score = m_sampling_params.is_beam_search() ? sequence->get_beam_search_score(m_sampling_params) : sequence->get_cumulative_log_probs();
output.score = m_sampling_params.is_beam_search() ? sequence->get_beam_search_score(m_sampling_params) : sequence->get_cumulative_log_prob();
output.finish_reason = sequence->get_finish_reason();
outputs.emplace(sequence->get_grouped_id(), output);
}
Expand Down Expand Up @@ -684,4 +671,10 @@ class SequenceGroup {
m_generation_stream->push(std::move(outputs));
}
};

inline std::shared_ptr<SequenceGroup> Sequence::get_sequence_group_ptr() const {
assert(m_sequence_group != nullptr);
return m_sequence_group->shared_from_this();
}

}
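
The sequence_group.hpp changes above also explain why the extra wiring call could be dropped: the `std::weak_ptr<SequenceGroup>` back-reference becomes a raw pointer that `SequenceGroup::add_sequence` assigns, and `std::enable_shared_from_this` is used to hand out a `shared_ptr` on demand in `get_sequence_group_ptr`. A minimal, self-contained sketch of that ownership pattern (generic `Group`/`Item` names, not the GenAI types) would be:

#include <cassert>
#include <memory>
#include <vector>

class Group;  // plays the role of SequenceGroup

class Item {  // plays the role of Sequence
    Group* m_group = nullptr;  // non-owning back-pointer, assigned by the parent
public:
    void set_group(Group* group) { assert(group != nullptr); m_group = group; }
    std::shared_ptr<Group> group() const;  // defined after Group, like get_sequence_group_ptr
};

class Group : public std::enable_shared_from_this<Group> {
    std::vector<std::shared_ptr<Item>> m_items;
public:
    void add(const std::shared_ptr<Item>& item) {
        item->set_group(this);   // wiring happens where the item is added, not at the call site
        m_items.push_back(item);
    }
};

inline std::shared_ptr<Group> Item::group() const {
    assert(m_group != nullptr);
    return m_group->shared_from_this();  // valid only while a shared_ptr already owns the Group
}

int main() {
    auto group = std::make_shared<Group>();  // shared ownership established up front
    group->add(std::make_shared<Item>());
    return 0;
}

The usual shared_from_this caveat applies: the group must already be owned by a `std::shared_ptr` (as it is when created through `std::make_shared`) before any sequence asks for the back-reference.
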
@@ -159,7 +159,7 @@ init_request(
for (const auto& candidate_sequence : candidates) {
Sequence::Ptr sequence;
if (is_init_all_sequences_in_request && candidate_sequence.first > 0) {
sequence = Sequence::Ptr(new Sequence(candidate_sequence.first));
sequence = Sequence::create(candidate_sequence.first);
sequence->set_status(ov::genai::SequenceStatus::RUNNING);
request->add_sequence(sequence);
} else {
4 changes: 1 addition & 3 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -175,7 +175,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
std::vector<SequenceGroup::Ptr> requests;
size_t request_id = 0;
size_t block_size = 1; // not used
bool enable_prefix_caching = false;

size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1) - to_remove_from_hist;
size_t inputs_embeds_size = inputs_embeds.get_shape().at(1);
@@ -185,8 +184,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
std::fill_n(prompt_ids.data<int64_t>(), prompt_ids.get_size(), m_tokenizer.get_pad_token_id());
std::copy(tokenized_history.begin(), tokenized_history.end(), prompt_ids.data<int64_t>());

SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, prompt_ids, generation_config, block_size, enable_prefix_caching);
sequence_group->set_sequence_group_ptr(sequence_group);
SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, prompt_ids, generation_config, block_size);
requests.push_back(sequence_group);

std::shared_ptr<StreamerBase> streamer_ptr = std::visit(overloaded{
17 changes: 6 additions & 11 deletions tests/cpp/block_manager.cpp
@@ -13,12 +13,11 @@ TEST(TestBlockManager, general_test) {
ov::genai::TokenIds prompt_ids;

ov::genai::SequenceGroup::Ptr sequence_group = std::make_shared<ov::genai::SequenceGroup>(
0,
0,
ov::Tensor(ov::element::i64, {
prompt_ids.size()}, prompt_ids.data()),
ov::genai::beam_search(),
4,
false);
4);
auto sequence = sequence_group->get_not_finished_sequences()[0];
bm.allocate(sequence, 6);
auto seq_id = sequence->get_id();
@@ -46,13 +45,11 @@ TEST(TestBlockManager, required_blocks_count) {

std::vector<uint64_t> tokens = {0,1,2,3,4};
ov::genai::SequenceGroup::Ptr sequence_group = std::make_shared<ov::genai::SequenceGroup>(
0,
0,
ov::Tensor(ov::element::i64, {
tokens.size()}, tokens.data()),
ov::genai::beam_search(),
4,
false);
sequence_group->set_sequence_group_ptr(sequence_group);
4);
sequence_group->schedule_tokens(5);
auto required_blocks = bm.required_blocks_count(sequence_group);
EXPECT_EQ(required_blocks, 2);
@@ -62,7 +59,7 @@
EXPECT_EQ(bm.get_number_of_blocks_occupied_by_sequence(sequence_group), 2);

sequence_group->finish_iteration();
auto sequence_to_fork = sequence_group->get_running_sequences()[0];
auto sequence_to_fork = sequence_group->get_running_sequences()[0];
for (size_t i = 0; i < 4; ++i) {
const auto forked_sequence = sequence_group->fork_sequence(sequence_to_fork);
bm.fork_sequence(sequence_to_fork->get_id(), forked_sequence->get_id());
@@ -98,9 +95,7 @@ TEST(TestBlockManager, CanFreeBlocksFromSequence) {
ov::Tensor(ov::element::i64, {
tokens.size()}, tokens.data()),
ov::genai::beam_search(),
BLOCK_SIZE,
false);
sequence_group->set_sequence_group_ptr(sequence_group);
BLOCK_SIZE);
sequence_group->schedule_tokens(5);
bm.append_slots(sequence_group);
ASSERT_EQ(bm.num_free_blocks(), 5);