diff --git a/Dockerfile b/Dockerfile
index b73d907b87..665b23427d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM ubuntu:22.04
 ARG JOBS
 WORKDIR /workspace
-RUN apt-get update -y && apt-get install -y python3-pip python3-venv git
+RUN apt-get update -y && apt-get install -y --no-install-recommends python3-pip python3-venv git
 
 # Install OpenVINO
 RUN git clone --branch master https://github.com/openvinotoolkit/openvino.git && \
@@ -25,7 +25,7 @@ ENV OpenVINO_DIR=/workspace/openvino_build
 RUN wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
 
 # Build GenAI library with dependencies
-RUN git clone https://github.com/Wovchena/openvino.genai-public.git -b reuse-Tokenizer openvino.genai && \
+RUN git clone https://github.com/openvinotoolkit/openvino.genai.git && \
     cd /workspace/openvino.genai/thirdparty && git submodule update --remote --init && \
     mkdir /workspace/openvino.genai/build && cd /workspace/openvino.genai/build && \
     cmake -DCMAKE_BUILD_TYPE=Release .. && \
@@ -33,6 +33,6 @@ RUN git clone https://github.com/Wovchena/openvino.genai-public.git -b reuse-Tok
 # Install test dependencies
 RUN python3 -m pip install --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly/ /workspace/openvino.genai/thirdparty/openvino_tokenizers
-RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/openvino.genai/tests/python_tests/continuous_batching/requirements.txt
+RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/openvino.genai/tests/python_tests/requirements.txt
 ENV PYTHONPATH=/workspace/openvino.genai/build/
 ENV LD_LIBRARY_PATH=/workspace/openvino.genai/build/
diff --git a/src/cpp/src/logit_processor.hpp b/src/cpp/src/logit_processor.hpp
index cb3ffb37c0..2131ec01cb 100644
--- a/src/cpp/src/logit_processor.hpp
+++ b/src/cpp/src/logit_processor.hpp
@@ -16,12 +16,35 @@ struct Token {
     Token() = default;
 };
 
+struct Logits {
+    float * m_data = nullptr;
+    size_t m_size;
+    // Late initialized
+    std::vector<Token> m_vector;
+
+    Logits(float* data, size_t size): m_data(data), m_size(size) {}
+
+    void initialize_vector() {
+        OPENVINO_ASSERT(m_vector.size() == 0, "Logits vector already initialized");
+        m_vector.reserve(m_size);
+        for (size_t i = 0; i < m_size; i++)
+            m_vector.emplace_back(m_data[i], i);
+    }
+
+    void resize(size_t new_size) {
+        m_size = new_size;
+        m_vector.resize(new_size);
+    }
+};
+
 namespace LogitTransformers {
 
 using TokenIds = std::vector<int64_t>;
 
 class ILogitTransformer {
 public:
-    virtual void apply(std::vector<Token>& logits) = 0;
+    virtual void apply(Logits& logits) = 0;
 
     virtual bool is_applicable(size_t generated_tokens_cnt = 0) {
         return true;
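Note on the hunk above: the new `Logits` wrapper defers building the per-token (log-prob, index) vector until a transform actually needs ordering information; transforms that only rescale values work directly on the raw tensor buffer. A minimal self-contained sketch of that contract — `Token`/`Logits` here are trimmed-down stand-ins for the structs in the hunk, with `assert` in place of `OPENVINO_ASSERT`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Trimmed-down stand-ins illustrating the lazy-materialization idea.
struct Token {
    float m_log_prob = 0.0f;
    int64_t m_index = 0;
    Token(float log_prob, int64_t index) : m_log_prob(log_prob), m_index(index) {}
};

struct Logits {
    float* m_data;                // non-owning view into the model's output tensor
    size_t m_size;
    std::vector<Token> m_vector;  // built only if a transform needs (value, index) pairs

    Logits(float* data, size_t size) : m_data(data), m_size(size) {}

    void initialize_vector() {
        assert(m_vector.empty() && "already initialized");
        m_vector.reserve(m_size);
        for (size_t i = 0; i < m_size; i++)
            m_vector.emplace_back(m_data[i], i);
    }
};

int main() {
    float buffer[] = {1.0f, 3.0f, 2.0f};
    Logits logits{buffer, 3};

    // Value-only transforms (temperature, penalties) touch the raw buffer
    // and never pay for the vector:
    for (size_t i = 0; i < logits.m_size; ++i)
        logits.m_data[i] *= 0.5f;

    // Order-dependent transforms (top-p/top-k) materialize it exactly once:
    if (logits.m_vector.empty())
        logits.initialize_vector();
}
```

For greedy decoding no transform needs the vector at all, so the per-token materialization — an allocation plus a full vocabulary copy — disappears entirely.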
@@ -32,11 +55,15 @@ class TopPFilter : public ILogitTransformer {
 public:
     TopPFilter(double top_p) : m_top_p(top_p) {}
 
-    void apply(std::vector<Token>& logits) override {
-        std::sort(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob > rhs.m_log_prob; });
+    void apply(Logits& logits) override {
+        if (logits.m_vector.size() == 0) {
+            // Initialize and sort vector
+            logits.initialize_vector();
+            std::sort(logits.m_vector.begin(), logits.m_vector.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob > rhs.m_log_prob; });
+        }
         float probability_sum = 0.0f;
         size_t nucleus_size = 0;
-        for (const auto& probability : logits) {
+        for (const auto& probability : logits.m_vector) {
             probability_sum += probability.m_log_prob;
             nucleus_size += 1;
             if (probability_sum > m_top_p) break;
@@ -52,10 +79,17 @@ class TopKFilter : public ILogitTransformer {
 public:
     TopKFilter(size_t top_k) : m_top_k(top_k) {}
 
-    void apply(std::vector<Token>& logits) override {
-        std::sort(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob > rhs.m_log_prob; });
-        size_t top_k = logits.size() >= m_top_k ? m_top_k : logits.size();
-        logits.resize(top_k);
+    // If this transform is used together with top_p, it should be applied after it,
+    // since top_p sorts the entire vector while top_k only sorts it partially
+    void apply(Logits& logits) override {
+        if (m_top_k >= logits.m_size)
+            return;
+
+        if (logits.m_vector.size() == 0) {
+            // Initialize and partially sort vector
+            logits.initialize_vector();
+            std::partial_sort(logits.m_vector.begin(), logits.m_vector.begin() + m_top_k, logits.m_vector.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob > rhs.m_log_prob; });
+        }
+        logits.resize(m_top_k);
     }
 
 protected:
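Ordering between the two filters matters: top_p has to accumulate probability mass over the whole distribution, so it fully sorts the lazily built vector, while top_k only needs the first `m_top_k` elements via `std::partial_sort`; running top_p first makes the subsequent top_k a plain truncation. A self-contained sketch of the two steps over a toy distribution (values and names are illustrative only):

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Nucleus (top-p) filtering followed by top-k truncation on a normalized
// probability vector of (prob, token-id) pairs.
int main() {
    std::vector<std::pair<float, int>> probs = {
        {0.10f, 0}, {0.50f, 1}, {0.05f, 2}, {0.30f, 3}, {0.05f, 4}};
    const float top_p = 0.85f;
    const size_t top_k = 3;

    // top-p: full descending sort, then keep the smallest prefix whose mass exceeds top_p
    std::sort(probs.begin(), probs.end(),
              [](auto& a, auto& b) { return a.first > b.first; });
    float mass = 0.0f;
    size_t nucleus = 0;
    for (auto& p : probs) {
        mass += p.first;
        ++nucleus;
        if (mass > top_p) break;
    }
    probs.resize(nucleus);  // {0.50, 0.30, 0.10}: mass 0.90 > 0.85

    // top-k after top-p is a plain truncation: the prefix is already sorted
    if (probs.size() > top_k) probs.resize(top_k);

    for (auto& p : probs) std::printf("token %d: %.2f\n", p.second, p.first);
}
```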
@@ -66,18 +100,23 @@ class TemperatureLogitTransform : public ILogitTransformer {
 public:
     TemperatureLogitTransform(double temperature) : m_temperature(temperature) {};
 
-    void apply(std::vector<Token>& logits) override {
-        auto max_prob_token = std::max_element(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; });
-        float max_logit = max_prob_token->m_log_prob;
-
-        std::for_each(logits.begin(), logits.end(), [max_logit, this](Token& val) { val.m_log_prob = expf((val.m_log_prob - max_logit) / this->m_temperature); });
+    void apply(Logits& logits) override {
+        float max_logit = -std::numeric_limits<float>::infinity();
+        for (size_t i = 0; i < logits.m_size; i++) {
+            if (logits.m_data[i] > max_logit) {
+                max_logit = logits.m_data[i];
+            }
+        }
 
         float norm_sum = 0.0;
-        for (const auto& val : logits) {
-            norm_sum += val.m_log_prob;
+        for (size_t i = 0; i < logits.m_size; i++) {
+            logits.m_data[i] = expf((logits.m_data[i] - max_logit) / this->m_temperature);
+            norm_sum += logits.m_data[i];
         }
 
-        std::for_each(logits.begin(), logits.end(), [norm_sum](Token& val) { val.m_log_prob /= norm_sum; });
+        for (size_t i = 0; i < logits.m_size; i++) {
+            logits.m_data[i] /= norm_sum;
+        }
     }
 
 protected:
@@ -118,32 +157,28 @@ class RepetitionPenaltyTransform : public IPenaltyTransformer {
         m_penalty = repetition_penalty;
     };
 
-    void apply(std::vector<Token>& logits) override {
-        size_t vocab_size = logits.size();
+    void apply(Logits& logits) override {
+        size_t vocab_size = logits.m_size;
         for (const auto& prompt_id : *m_unique_prompt_token_ids) {
             OPENVINO_ASSERT((prompt_id >= 0) && (prompt_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(logits[prompt_id].m_index == prompt_id, "input_logits must have original index order");
-            auto logit_value = logits[prompt_id].m_log_prob;
-            if (logit_value >= 0) {
-                logits[prompt_id].m_log_prob /= m_penalty;
+            if (logits.m_data[prompt_id] >= 0) {
+                logits.m_data[prompt_id] /= m_penalty;
             } else {
-                logits[prompt_id].m_log_prob *= m_penalty;
+                logits.m_data[prompt_id] *= m_penalty;
             };
         }
         for (const auto& input_id_pair : *m_unique_generated_token_ids) {
             const auto& input_id = input_id_pair.first;
             OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
-            auto logit_value = logits[input_id].m_log_prob;
-            if (logit_value >= 0) {
-                logits[input_id].m_log_prob /= m_penalty;
+            if (logits.m_data[input_id] >= 0) {
+                logits.m_data[input_id] /= m_penalty;
             } else {
-                logits[input_id].m_log_prob *= m_penalty;
+                logits.m_data[input_id] *= m_penalty;
             };
         }
     }
 
-    void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
+    void apply(Logits& logits, const TokenIds& input_ids) {
         set_unique_prompt_token_ids(nullptr);
         extract_generated_tokens(input_ids);
         apply(logits);
@@ -166,10 +201,10 @@ class EOSPenaltyTransform : public ILogitTransformer {
     EOSPenaltyTransform(size_t eos_token_id, size_t min_generated_tokens)
         : m_eos_token_id(eos_token_id), m_applicable_tensor_len(min_generated_tokens) {}
 
-    void apply(std::vector<Token>& logits) override {
-        // Since EOS penalty is applied early, the token vector is not sorted
+    void apply(Logits& logits) override {
+        // Since EOS penalty is applied early, the token vector is not initialized yet
         // and we can assume element order matches token ids.
-        logits[m_eos_token_id].m_log_prob = 0.f;
+        logits.m_data[m_eos_token_id] = 0.f;
     }
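The rewritten temperature transform is the standard numerically stable softmax with temperature, computed directly on the buffer: subtract the running maximum before exponentiation so `expf` cannot overflow, fuse the exponentiation with the normalization sum, then divide. A self-contained sketch of the same computation (the free-standing function name is illustrative):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable temperature softmax, mirroring the fused loops in
// TemperatureLogitTransform above: subtracting the max keeps exp() from
// overflowing, and the exp and normalization sums share a single pass.
void temperature_softmax(std::vector<float>& logits, float temperature) {
    float max_logit = logits[0];
    for (float v : logits)
        if (v > max_logit) max_logit = v;

    float norm_sum = 0.0f;
    for (float& v : logits) {
        v = std::exp((v - max_logit) / temperature);
        norm_sum += v;
    }
    for (float& v : logits)
        v /= norm_sum;
}

int main() {
    std::vector<float> logits = {2.0f, 1.0f, 0.5f};
    temperature_softmax(logits, 0.7f);
    for (float p : logits)
        std::printf("%.3f ", p);  // probabilities summing to 1
}
```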
@@ -188,22 +223,20 @@ class FrequencyPenaltyTransform : public IPenaltyTransformer {
         m_penalty = value;
     };
 
-    void apply(std::vector<Token>& logits) override {
-        size_t vocab_size = logits.size();
+    void apply(Logits& logits) override {
+        size_t vocab_size = logits.m_size;
         for (const auto& input_id_pair : *m_unique_generated_token_ids) {
             const auto& input_id = input_id_pair.first;
             OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
-            auto logit_value = logits[input_id].m_log_prob;
-            if (logit_value >= 0) {
-                logits[input_id].m_log_prob -= m_penalty * input_id_pair.second;
+            if (logits.m_data[input_id] >= 0) {
+                logits.m_data[input_id] -= m_penalty * input_id_pair.second;
             } else {
-                logits[input_id].m_log_prob += m_penalty * input_id_pair.second;
+                logits.m_data[input_id] += m_penalty * input_id_pair.second;
             };
         }
     }
 
-    void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
+    void apply(Logits& logits, const TokenIds& input_ids) {
         extract_generated_tokens(input_ids);
         apply(logits);
     }
@@ -215,22 +248,20 @@ class PresencePenaltyTransform : public IPenaltyTransformer {
         m_penalty = value;
     };
 
-    void apply(std::vector<Token>& logits) override {
-        size_t vocab_size = logits.size();
+    void apply(Logits& logits) override {
+        size_t vocab_size = logits.m_size;
         for (const auto& input_id_pair : *m_unique_generated_token_ids) {
             const auto& input_id = input_id_pair.first;
             OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
-            OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
-            auto logit_value = logits[input_id].m_log_prob;
-            if (logit_value >= 0) {
-                logits[input_id].m_log_prob -= m_penalty;
+            if (logits.m_data[input_id] >= 0) {
+                logits.m_data[input_id] -= m_penalty;
             } else {
-                logits[input_id].m_log_prob += m_penalty;
+                logits.m_data[input_id] += m_penalty;
             };
         }
     }
 
-    void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
+    void apply(Logits& logits, const TokenIds& input_ids) {
         extract_generated_tokens(input_ids);
         apply(logits);
     }
@@ -286,14 +317,14 @@ class LogitProcessor {
             if (sampling_params.top_p != 1.0f) {
                 m_logit_transformers.emplace_back(new LogitTransformers::TopPFilter(sampling_params.top_p));
             }
-            if (sampling_params.top_k > 0) {
+            if (sampling_params.top_k > 0 && sampling_params.top_k < std::numeric_limits<size_t>::max()) {
                 m_logit_transformers.emplace_back(new LogitTransformers::TopKFilter(sampling_params.top_k));
             }
         }
     }
 }
 
-    void apply(std::vector<Token>& logits) {
+    void apply(Logits& logits) {
         for (const auto& transformer : m_logit_transformers) {
             if (transformer->is_applicable(m_generated_tokens)) {
                 transformer->apply(logits);
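For the additive penalties the per-token update is small enough to state directly: the frequency penalty scales the configured value by the token's occurrence count, the presence penalty applies it as a flat offset, and the sign branch mirrors the one used by the multiplicative repetition penalty above. A sketch of the frequency variant (the helper name and values are illustrative, not part of the patch):

```cpp
#include <cstdio>

// Additive frequency-penalty arithmetic matching the branches above:
// the adjustment grows with the occurrence count and is subtracted from
// non-negative logits, added to negative ones.
float frequency_penalize(float logit, float penalty, int occurrences) {
    float adjustment = penalty * occurrences;
    return logit >= 0 ? logit - adjustment : logit + adjustment;
}

int main() {
    // a token generated 3 times with penalty 0.3
    std::printf("%.2f\n", frequency_penalize( 2.0f, 0.3f, 3));  //  1.10
    std::printf("%.2f\n", frequency_penalize(-2.0f, 0.3f, 3));  // -1.10
}
```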
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index 6390fc8725..4c7ea52bed 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -18,6 +18,7 @@
 #include "logit_processor.hpp"
 #include "scheduler.hpp"
 #include "sequence_group.hpp"
+#include "timer.hpp"
 
 namespace ov::genai {
 // Modified Knuth–Morris–Pratt algorithm which returns tokens following every needle occurrence in the haystack
@@ -203,40 +204,49 @@ class GroupBeamSearcher {
 
 class Sampler {
-    std::vector<Token> _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
+    Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
         ov::Shape logits_shape = logits.get_shape();
         size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
         OPENVINO_ASSERT(batch_idx <= batch_size);
         size_t batch_offset = batch_idx * seq_len * vocab_size;
         size_t sequence_offset = (seq_len - 1) * vocab_size;
-        const float* logits_data = logits.data<float>() + batch_offset + sequence_offset;
+        float* logits_data = logits.data<float>() + batch_offset + sequence_offset;
 
-        std::vector<Token> logit_vector(vocab_size);
-        for (size_t i = 0; i < logit_vector.size(); i++) {
-            logit_vector[i] = Token(logits_data[i], i);
-        }
-        return logit_vector;
+        return Logits{logits_data, vocab_size};
     }
 
-    Token _greedy_sample(const std::vector<Token>& logit_vector) const {
-        Token max_token{-std::numeric_limits<float>::infinity(), 0};
-        for (const auto& logit : logit_vector) {
-            if (logit.m_log_prob > max_token.m_log_prob) {
-                max_token = logit;
+    Token _greedy_sample(const Logits& logits) const {
+        // For greedy sampling we do not expect sorting or shrinking of the considered tokens,
+        // so we can operate directly on the data buffer
+        float max_value = -std::numeric_limits<float>::infinity();
+        size_t max_index = 0;
+        for (size_t i = 0; i < logits.m_size; ++i) {
+            if (logits.m_data[i] > max_value) {
+                max_value = logits.m_data[i];
+                max_index = i;
             }
         }
-        return max_token;
+        return Token(logits.m_data[max_index], max_index);
     }
 
-    std::vector<Token> _multinomial_sample(const std::vector<Token>& logit_vector, size_t num_tokens_per_sequence) {
-        std::vector<float> multinomial_weights(logit_vector.size());
-        for (size_t i = 0; i < logit_vector.size(); i++) multinomial_weights[i] = logit_vector[i].m_log_prob;
+    std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {
+        // If top_p or top_k was applied we sample from the sorted vector; if not, we use the original buffer.
+        std::vector<float> multinomial_weights;
+        multinomial_weights.reserve(logits.m_size);
+        if (logits.m_vector.size() > 0)
+            for (auto& logit : logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
+        else
+            multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);
 
         auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1
+
         std::vector<Token> out_tokens;
         for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
             size_t element_to_pick = dist(rng_engine);
-            out_tokens.push_back(logit_vector[element_to_pick]);
+            if (logits.m_vector.size() > 0)
+                out_tokens.push_back(logits.m_vector[element_to_pick]);
+            else
+                out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
         }
         return out_tokens;
     }
@@ -294,17 +304,29 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr>& sequence_groups,
                 running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
             };
 
             for (size_t running_sequence_id = 0; running_sequence_id < num_running_sequences; ++running_sequence_id) {
+                static ManualTimer timer1("sample::_get_logit_vector");
+                timer1.start();
                 auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id);
+                timer1.end();
+
+                static ManualTimer timer2("sample::logit_processor.apply");
+                timer2.start();
                 logit_processor.apply(logit_vector);
+                timer2.end();
 
                 Token sampled_token_id;
                 if (sampling_params.is_greedy_decoding()) {
+                    static ManualTimer timer("sample::_greedy_sample");
+                    timer.start();
                     sampled_token_id = _greedy_sample(logit_vector);
+                    timer.end();
                 } else {
                     // is_multinomial()
                     const bool is_generate_n_tokens = sequence_group->num_total_seqs() == 1;
                     const size_t num_tokens_per_sequence = is_generate_n_tokens ? sampling_params.num_return_sequences : 1;
+                    static ManualTimer timer("sample::_multinomial_sample");
+                    timer.start();
                     auto sampled_token_ids = _multinomial_sample(logit_vector, num_tokens_per_sequence);
+                    timer.end();
                     sampled_token_id = sampled_token_ids[0];
 
                     if (is_generate_n_tokens) {
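`_multinomial_sample` hands the post-softmax weights to `std::discrete_distribution`, which normalizes them internally, so each `dist(rng_engine)` call is one multinomial trial; the returned index is a position in whichever container supplied the weights, which is why the pick is mapped back through `m_vector` when top-p/top-k shrank the candidate set. A self-contained sketch of the draw (fixed seed for reproducibility here; the real sampler keeps a persistent `rng_engine`):

```cpp
#include <cstdio>
#include <random>
#include <vector>

// Minimal sketch of the multinomial draw: discrete_distribution normalizes
// the weights internally, and each call to dist(engine) is one trial.
int main() {
    std::vector<float> weights = {0.1f, 0.6f, 0.3f};  // post-softmax probabilities
    std::mt19937 engine(42);
    std::discrete_distribution<size_t> dist(weights.begin(), weights.end());

    for (int i = 0; i < 5; ++i)
        std::printf("picked token %zu\n", dist(engine));  // index 1 dominates
}
```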
diff --git a/tests/python_tests/test_preemption.py b/tests/python_tests/test_preemption.py
index 8c9bda1d33..ae6830d768 100644
--- a/tests/python_tests/test_preemption.py
+++ b/tests/python_tests/test_preemption.py
@@ -53,7 +53,7 @@ def test_preemption(tmp_path, params):
     ref_texts=get_current_plarform_ref_texts({
         "linux": [
             [
-                "\n\nOpenVINO is a live platform that allows users to create and manage a new library for open source applications.\n\nOpenVINO is"
+                "\n\nOpenVINO is a programming language with a lot of benefits. It has been designed in such a way that it is probably not suitable for"
             ],
             [
                 " You're getting much better results from doing this, than you are by not doing this. I have a BH and I was so far"
@@ -109,12 +109,12 @@ def test_preemption_with_multinomial(tmp_path, dynamic_split_fuse):
     ref_texts=get_current_plarform_ref_texts({
         "linux": [
             [
-                "\nI've seen this expression used too many times without making sense.\nAs an AI engineer, and as a scientist, we should make everything easier"
+                " Buzzfeed ESPN CNBC MSNBC CBS\nFox News is on top of the list.\nIf a news station tries to afford real estate"
             ],
             [
-                " position of the Z-shaped groove?\n0.41\nWhat is the current position of the Z-shaped groove?\n0.11\n",
-                " status of all of this? I can't stop thinking about it.\nIt's been a while since I've seen it. I found it a",
-                " status of your blog? Do you accept feedback?\nYes, I’m happy to accept feedback at this time (I’m a"
+                " condition of the leg?\nIt's been quite a while since I've seen it, so I didn't really know if it was good or bad",
+                ' ratio of (-9)/(-12)*(-128)/(-32)?\n-1/2\nEvaluate (1*-9)/((',
+                ' ratio of (-4 + (-5)/(-5))*-3?\n-6\nEvaluate ((-108)/(-32))/('
             ],
             [
                 "\nIt's in the middle of nowhere if you haven’t seen one yet! It might be more convenient there than anywhere else.. maybe take",
diff --git a/tests/python_tests/test_sampling.py b/tests/python_tests/test_sampling.py
index f9b478bd14..aa1a473cff 100644
--- a/tests/python_tests/test_sampling.py
+++ b/tests/python_tests/test_sampling.py
@@ -124,7 +124,7 @@ class RandomSamplingTestStruct:
         prompts=["What is OpenVINO?"],
         ref_texts=[
             [
-                "\n\nOpenVINO is a software development platform developed by OpenVINO, a set of technology companies and startups that enables developers to use the most"
+                "\n\nOpenVINO is a new open source virtual cold storage solution for virtual machines under the Ubuntu Foundation. OpenVINO is a virtualized virtual"
             ]
         ],
     ),
@@ -174,7 +174,7 @@ class RandomSamplingTestStruct:
         prompts=["What is OpenVINO?"],
         ref_texts=[
             [
-                "\nOpen Vino's are a new and improved way to find cheap, fast-investment frozen vegetables that have no waste or calories. They're"
+                '\nOpen Vino (OLIN) was launched on April 12, 2016 and has resulted in around 15% of all virtual meetings being hosted by companies'
             ]
         ],
    ),
@@ -183,9 +183,9 @@ class RandomSamplingTestStruct:
         prompts=["What is location of"],
         ref_texts=[
             [
-                " the exact same image?\nI've tried multiple times to find it, but I'm still not sure. I am sure it's the exact same",
-                " your new house?\nAnywhere that has a GPS. It will be up to you.",
-                " your cat? He is more likely to be on the floor with him.\nTalduck"
+                " the sensor?\nIt's a sensor on the back of the phone.\nGotcha, very cool.\nGood job man!",
+                ' this website?\n\nTasty Big Fish, New York, NY\n\nFounded in 2018 by award-winning authors, Including the creative minds',
+                " this?\nIt's actually in this sub."
             ]
         ],
    ),
@@ -216,7 +216,7 @@ class RandomSamplingTestStruct:
         prompts=["What is OpenVINO?"],
         ref_texts=[
             [
-                "\n\nOpenVINO is a software development platform developed by OpenVINO, Inc., which uses a RESTful API for server-side web applications"
+                '\n\nOpenVINO is a new open source virtual application that lets you create and modify all kinds of virtual machines in your environment. OpenVINO'
             ]
         ],
    ),
@@ -225,7 +225,7 @@ class RandomSamplingTestStruct:
         prompts=["What is OpenVINO?"],
         ref_texts=[
             [
-                "\n\nOpenVINO is a software development platform developed by OpenVINO, Inc., which offers the Linux-based platform. OpenVINO's"
+                '\nOpenVINO is a technology for low-power video streaming by building high-efficiency, decoupling and shrinking cards. This strategy is important'
             ]
         ],
    ),