diff --git a/tests/test_speculative_generation.py b/tests/test_speculative_generation.py index 52044e5e..2834410d 100644 --- a/tests/test_speculative_generation.py +++ b/tests/test_speculative_generation.py @@ -87,7 +87,7 @@ def test_speculative_greedy_generation(tokenizer, model, model2, ref_model, max_ new_tokens[:, random_pos] = random.randrange(1, 100) combined_ids = torch.cat((generated_ids, new_tokens), dim=1) - logits = model(combined_ids, start_from_position=1).logits + logits = model(combined_ids).logits # Найти первую позицию, где токены совпали match_length = 0