From 1598cd1bfe239774ec4724da5f4ff2843b43a02a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 2 Jul 2024 14:46:49 -0700 Subject: [PATCH] Fix #898, stopping criteria update. Remove any() from the stopping logic after testing batching; all() should be used. Fix device mismatch for stopping criteria token. --- src/open_clip/coca_model.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 8eeaf6b90..618614e96 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -215,6 +215,7 @@ def generate( # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + device = image.device with torch.no_grad(): sot_token_id = 49406 if sot_token_id is None else sot_token_id @@ -222,19 +223,14 @@ def generate( pad_token_id = self.pad_id if pad_token_id is None else pad_token_id logit_processor = LogitsProcessorList( [ - MinLengthLogitsProcessor(min_seq_len, eos_token_id), + MinLengthLogitsProcessor(min_seq_len, eos_token_id, device=device), RepetitionPenaltyLogitsProcessor(repetition_penalty), ] ) if stopping_criteria is None: stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] - - stopping_criteria = StoppingCriteriaList( - stopping_criteria - ) - - device = image.device + stopping_criteria = StoppingCriteriaList(stopping_criteria) if generation_type == "beam_search": output = self._generate_beamsearch( @@ -313,12 +309,7 @@ def generate( cur_len += 1 - is_done = False - if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria: - is_done = stopping_criteria(out, None).all() - else: - is_done = stopping_criteria(out, None).any() - if is_done: + if all(stopping_criteria(out, None)): break if num_dims == 1: @@ -460,14 +451,7 @@ def 
_generate_beamsearch( # increase cur_len cur_len = cur_len + 1 - is_done = False - if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria: - is_done = stopping_criteria(input_ids, None).all() - else: - is_done = stopping_criteria(input_ids, None).any() - if is_done: - break - if beam_scorer.is_done or is_done: + if beam_scorer.is_done or all(stopping_criteria(input_ids, None)): break final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None