Skip to content

Commit

Permalink
is_vector_initialized
Browse files Browse the repository at this point in the history
  • Loading branch information
mzegla committed Jul 24, 2024
1 parent 6d8117e commit f4af2b6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/cpp/src/logit_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct Logits {
m_vector.emplace_back(m_data[i], i);
}

bool vector_initialized() const {
bool is_vector_initialized() const {
return m_vector.size() > 0;
}

Expand All @@ -59,7 +59,7 @@ class TopPFilter : public ILogitTransformer {
TopPFilter(double top_p) : m_top_p(top_p) {}

void apply(Logits& logits) override {
if (!logits.vector_initialized()) {
if (!logits.is_vector_initialized()) {
// Initialize and sort vector
logits.initialize_vector();
std::sort(logits.m_vector.begin(), logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
Expand Down Expand Up @@ -92,7 +92,7 @@ class TopKFilter : public ILogitTransformer {
return;
*/

if (!logits.vector_initialized()) {
if (!logits.is_vector_initialized()) {
// Initialize and partially sort vector
logits.initialize_vector();
// TODO: Uncommenting below requires uncommenting section above
Expand Down
4 changes: 2 additions & 2 deletions src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class Sampler {
// If top_p or top_k was applied we use sorted vector, if not we go with original buffer.
std::vector<float> multinomial_weights;
multinomial_weights.reserve(logits.m_size);
if (logits.vector_initialized())
if (logits.is_vector_initialized())
for (auto& logit: logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
else
multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);
Expand All @@ -242,7 +242,7 @@ class Sampler {
std::vector<Token> out_tokens;
for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
size_t element_to_pick = dist(rng_engine);
if (logits.vector_initialized())
if (logits.is_vector_initialized())
out_tokens.push_back(logits.m_vector[element_to_pick]);
else
out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
Expand Down
16 changes: 8 additions & 8 deletions tests/cpp/logit_filtering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ TEST_P(TemperatureTransformTest, TransformResultEqualToReference) {
auto logits = Logits(test_struct.input, TemperatureTransformTestStruct::size);
auto transform = TemperatureLogitTransform(test_struct.temperature);
transform.apply(logits);
ASSERT_FALSE(logits.vector_initialized());
ASSERT_FALSE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, TemperatureTransformTestStruct::size); // temperature transfrom should not change buffer size
for (size_t i = 0; i < logits.m_size; i++) {
EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6);
Expand Down Expand Up @@ -58,7 +58,7 @@ TEST_P(TopPFilteringTest, FilterResultEqualToReference) {
auto logits = Logits(test_struct.input, TopPTestStruct::size);
auto transform = TopPFilter(test_struct.top_p);
transform.apply(logits);
ASSERT_TRUE(logits.vector_initialized());
ASSERT_TRUE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, logits.m_vector.size());
ASSERT_EQ(logits.m_size, test_struct.expected_output.size());
for (size_t i = 0; i < logits.m_vector.size(); i++) {
Expand Down Expand Up @@ -94,7 +94,7 @@ TEST_P(TopKFilteringTest, FilterResultEqualToReference) {
auto logits = Logits(test_struct.input, TopKTestStruct::size);
auto transform = TopKFilter(test_struct.top_k);
transform.apply(logits);
ASSERT_TRUE(logits.vector_initialized());
ASSERT_TRUE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, logits.m_vector.size());
ASSERT_EQ(logits.m_size, test_struct.expected_output.size());
for (size_t i = 0; i < logits.m_vector.size(); i++) {
Expand Down Expand Up @@ -123,7 +123,7 @@ TEST(TopKFilteringTest, FilterNotAppliedTopKGreaterThanInputSize) {
auto logits = Logits(input, 3);
auto transform = TopKFilter(top_k);
transform.apply(logits);
ASSERT_FALSE(logits.vector_initialized());
ASSERT_FALSE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, 3);
for (size_t i = 0; i < logits.m_size; i++) {
EXPECT_EQ(logits.m_data[i], expected_output[i]);
Expand All @@ -147,7 +147,7 @@ TEST_P(RepetitionPenaltyTransformTest, TransformResultEqualToReference) {
auto logits = Logits(test_struct.input, RepetitionPenaltyTransformTestStruct::size);
auto transform = RepetitionPenaltyTransform(test_struct.penalty);
transform.apply(logits, test_struct.input_ids);
ASSERT_FALSE(logits.vector_initialized());
ASSERT_FALSE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, RepetitionPenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size
for (size_t i = 0; i < logits.m_size; i++) {
EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6);
Expand Down Expand Up @@ -206,7 +206,7 @@ TEST_P(FrequencyPenaltyTransformTest, TransformResultEqualToReference) {
auto logits = Logits(test_struct.input, FrequencyPenaltyTransformTestStruct::size);
auto transform = FrequencyPenaltyTransform(test_struct.penalty);
transform.apply(logits, test_struct.input_ids);
ASSERT_FALSE(logits.vector_initialized());
ASSERT_FALSE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, FrequencyPenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size
for (size_t i = 0; i < logits.m_size; i++) {
EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6);
Expand Down Expand Up @@ -265,7 +265,7 @@ TEST_P(PresencePenaltyTransformTest, TransformResultEqualToReference) {
auto logits = Logits(test_struct.input, PresencePenaltyTransformTestStruct::size);
auto transform = PresencePenaltyTransform(test_struct.penalty);
transform.apply(logits, test_struct.input_ids);
ASSERT_FALSE(logits.vector_initialized());
ASSERT_FALSE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, PresencePenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size
for (size_t i = 0; i < logits.m_size; i++) {
EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6);
Expand Down Expand Up @@ -322,7 +322,7 @@ TEST_P(EOSPenaltyTransformTest, TransformResultEqualToReference) {
auto logits = Logits(test_struct.input, EOSPenaltyTransformTestStruct::size);
auto transform = EOSPenaltyTransform(test_struct.eos_token_id, std::numeric_limits<size_t>::max());
transform.apply(logits);
ASSERT_FALSE(logits.vector_initialized());
ASSERT_FALSE(logits.is_vector_initialized());
ASSERT_EQ(logits.m_size, EOSPenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size
for (size_t i = 0; i < logits.m_size; i++) {
EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6);
Expand Down

0 comments on commit f4af2b6

Please sign in to comment.