Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Dec 30, 2024
1 parent 0d92772 commit aeb7dd2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
55 changes: 27 additions & 28 deletions src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <vector>
#include <cassert>
#include <set>
#include <cstdlib>
#include <string_view>
Expand Down Expand Up @@ -41,34 +42,33 @@ class Sequence {
GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE;
float m_cumulative_log_prob = 0.0f;
std::vector<int64_t> m_prefix_hashes;
std::weak_ptr<SequenceGroup> m_sequence_group;
SequenceGroup* m_sequence_group = nullptr;
static std::mutex m_counter_mutex;

size_t _make_hash(size_t content_length);

public:
using Ptr = std::shared_ptr<Sequence>;
using CPtr = std::shared_ptr<const Sequence>;

// don't use directly
Sequence(const uint64_t id) : m_grouped_id(id) {};

// don't use directly
Sequence(const Sequence& seq, const uint64_t id) :
m_generated_ids(seq.m_generated_ids),
m_grouped_id(id),
m_status(seq.m_status),
m_cumulative_log_prob(seq.m_cumulative_log_prob){
m_cumulative_log_prob(seq.m_cumulative_log_prob),
m_sequence_group(seq.m_sequence_group) {
OPENVINO_ASSERT(seq.m_id != m_id);
set_sequence_group_weak_ptr(seq.get_sequence_group_ptr());
}

public:
using Ptr = std::shared_ptr<Sequence>;
using CPtr = std::shared_ptr<const Sequence>;

// TODO: move to private section once Speculative decoding is fixed
explicit Sequence(const uint64_t id) : m_grouped_id(id) {}

// Factory for a fresh Sequence; goes through the raw-new form (not make_shared)
// so the constructor can later be made private without breaking this helper.
static Sequence::Ptr create(const uint64_t id) {
    return Sequence::Ptr(new Sequence(id));
}

// Factory for a copy of `sequence` under a new grouped ID (used by beam search).
static Sequence::Ptr fork(Sequence::CPtr sequence, const uint64_t id) {
    return Sequence::Ptr(new Sequence(*sequence, id));
}

bool operator ==(const Sequence& other) const {
Expand Down Expand Up @@ -182,14 +182,12 @@ class Sequence {
}

// Each KV block can be uniquely identified by
void set_sequence_group_weak_ptr(std::weak_ptr<SequenceGroup> sequence_group) {
void set_sequence_group_ptr(SequenceGroup* sequence_group) {
assert(sequence_group != nullptr);
m_sequence_group = sequence_group;
}

// Declared here, defined after SequenceGroup: the definition needs
// SequenceGroup to be a complete type to call shared_from_this().
std::shared_ptr<SequenceGroup> get_sequence_group_ptr() const;

// Each KV block can be uniquely identified by
// the tokens within the block and the tokens in the prefix before the block.
Expand All @@ -201,7 +199,7 @@ class Sequence {
// - each sequence shares the same prompt and KV-caches for promp
// - in case of beam search each sequence also shares specific part of generic phase
// via reference counter mechanism on BlockManager level
class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
uint64_t m_request_id;
std::vector<Sequence::Ptr> m_sequences;
ov::genai::GenerationConfig m_sampling_params;
Expand All @@ -228,7 +226,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {

size_t m_num_streamed_tokens = 0, m_stream_window_size = 0;


SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size)
: m_request_id(request_id),
m_sampling_params(sampling_params),
Expand All @@ -245,15 +242,16 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {

SequenceGroup(uint64_t request_id, const ov::Tensor input_ids, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size)
    : SequenceGroup(request_id, sampling_params, block_size) {
    m_prompt_ids.resize(input_ids.get_size());
    std::copy_n(input_ids.data<int64_t>(), input_ids.get_size(), m_prompt_ids.begin());
    m_prompt_log_probs.reserve(m_prompt_ids.size());

    // Create the single initial sequence only after the group is fully set up,
    // since add_sequence() wires the sequence back to `this`.
    add_sequence(Sequence::create(m_next_sequence_id++));
}

// Registers `sequence` in this group and gives it a non-owning back-pointer
// to the group (this group must outlive the sequence).
void add_sequence(const Sequence::Ptr & sequence) {
    sequence->set_sequence_group_ptr(this);
    m_sequences.emplace_back(sequence);
}

Expand Down Expand Up @@ -323,7 +321,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return it != m_sequences.end();
}


/**
* @param seq_id Sequence identifier
* @return Pointer to the sequence with this ID.
Expand Down Expand Up @@ -410,7 +407,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
m_num_evicted_tokens += num_evicted_tokens;
}


/**
* Resets the eviction tracking on this sequence to the state prior to any eviction taking place.
*/
Expand All @@ -435,7 +431,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return get_num_processed_tokens() + get_num_scheduled_tokens();
}


bool requires_sampling() const {
return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.max_new_tokens > 0;
}
Expand Down Expand Up @@ -514,7 +509,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return (get_context_len() - get_num_evicted_tokens() + m_block_size - 1) / m_block_size;
}


// requires number of physical blocks for next generation
size_t get_num_blocks() const {
return get_num_logical_blocks();
Expand All @@ -526,7 +520,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {

// Clones `sequence` under a fresh grouped ID and registers the clone.
// The fork copy-constructor already propagates the group back-pointer,
// so no extra wiring is needed here.
Sequence::Ptr fork_sequence(Sequence::CPtr sequence) {
    auto forked_sequence = Sequence::fork(sequence, m_next_sequence_id++);
    m_sequences.emplace_back(forked_sequence);
    return forked_sequence;
}
Expand Down Expand Up @@ -679,4 +672,10 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
m_generation_stream->push(std::move(outputs));
}
};

// Defined out-of-line (below SequenceGroup) because shared_from_this()
// requires SequenceGroup to be a complete type at the point of use.
inline std::shared_ptr<SequenceGroup> Sequence::get_sequence_group_ptr() const {
    // The sequence must have been attached via set_sequence_group_ptr() first.
    OPENVINO_ASSERT(m_sequence_group != nullptr);
    return m_sequence_group->shared_from_this();
}

}
11 changes: 7 additions & 4 deletions tests/cpp/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ void clear_finished_sequences(std::vector<SequenceGroup::Ptr>& requests) {
});
requests.erase(new_end, requests.end());
}
std::shared_ptr<ov::Model> get_model(size_t num_layers) {
std::shared_ptr<ov::Model> get_model(ov::Core core, size_t num_layers) {
ov::NodeVector keys;
ov::NodeVector values;
ov::ParameterVector params;
ov::element::Type inference_precision = core.get_property("CPU", ov::hint::inference_precision);

auto shape = ov::PartialShape({ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic()});
for (size_t i = 0; i < num_layers; i++) {
auto key = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape);
auto value = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape);
auto key = std::make_shared<ov::op::v0::Parameter>(inference_precision, shape);
auto value = std::make_shared<ov::op::v0::Parameter>(inference_precision, shape);
key->get_output_tensor(0).set_names({"key_cache." + std::to_string(i)});
value->get_output_tensor(0).set_names({"value_cache." + std::to_string(i)});
keys.push_back(key);
Expand All @@ -42,7 +44,7 @@ std::shared_ptr<ov::Model> get_model(size_t num_layers) {
std::shared_ptr<CacheManager> init_cache_manager(SchedulerConfig scheduler_config) {
ov::Core core = ov::Core();
size_t num_decoder_layers = 12;
ov::InferRequest request = core.compile_model(get_model(num_decoder_layers)).create_infer_request();
ov::InferRequest request = core.compile_model(get_model(core, num_decoder_layers)).create_infer_request();
size_t head_size = 64, head_size_u8 = head_size + 8;
std::vector<size_t> num_kv_heads(12, 12);
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
Expand Down Expand Up @@ -326,6 +328,7 @@ TEST(TestScheduler, test_partial_preemption_beam_search) {
SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(0, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
ov::genai::beam_search(), 4);
std::vector<SequenceGroup::Ptr> requests = {sequence_group};
EXPECT_NO_THROW(requests[0]->get_running_sequences()[0]->get_sequence_group_ptr());

Scheduler scheduler = Scheduler(4, init_cache_manager(scheduler_config), scheduler_config);
auto out = scheduler.schedule(requests);
Expand Down

0 comments on commit aeb7dd2

Please sign in to comment.