Skip to content

Commit

Permalink
Fix #898, stopping criteria update. Remove any from logic after testi…
Browse files Browse the repository at this point in the history
…ng batching, all() should be used. Fix device mismatch for stopping criteria token.
  • Loading branch information
rwightman committed Jul 2, 2024
1 parent 1be2c89 commit 1598cd1
Showing 1 changed file with 5 additions and 21 deletions.
26 changes: 5 additions & 21 deletions src/open_clip/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,26 +215,22 @@ def generate(
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
device = image.device

with torch.no_grad():
sot_token_id = 49406 if sot_token_id is None else sot_token_id
eos_token_id = 49407 if eos_token_id is None else eos_token_id
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
MinLengthLogitsProcessor(min_seq_len, eos_token_id, device=device),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)

if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]

stopping_criteria = StoppingCriteriaList(
stopping_criteria
)

device = image.device
stopping_criteria = StoppingCriteriaList(stopping_criteria)

if generation_type == "beam_search":
output = self._generate_beamsearch(
Expand Down Expand Up @@ -313,12 +309,7 @@ def generate(

cur_len += 1

is_done = False
if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria:
is_done = stopping_criteria(out, None).all()
else:
is_done = stopping_criteria(out, None).any()
if is_done:
if all(stopping_criteria(out, None)):
break

if num_dims == 1:
Expand Down Expand Up @@ -460,14 +451,7 @@ def _generate_beamsearch(

# increase cur_len
cur_len = cur_len + 1
is_done = False
if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria:
is_done = stopping_criteria(input_ids, None).all()
else:
is_done = stopping_criteria(input_ids, None).any()
if is_done:
break
if beam_scorer.is_done or is_done:
if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
break

final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
Expand Down

0 comments on commit 1598cd1

Please sign in to comment.