Skip to content

Commit

Permalink
fix the unit tests expectations
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Jirmasek committed Dec 18, 2024
1 parent 5fdd5e3 commit 62daa78
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions coremltools/converters/mil/mil/passes/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7471,8 +7471,8 @@ def verify_sdpa_outputs(self, example_inputs: Dict[str, torch.Tensor]):

assert ops_counts[0] == 1 or ops_counts[0] == 3 # (attn_mask might be cast to bool from input fp16 dtype)
assert ops_counts[1] == 1 or ops_counts[1] == 3 # the Q seq length is less than the default min seq length
assert ops_counts[2] >= 26 * 16
assert ops_counts[3] >= 26 * 32
assert ops_counts[2] >= 11 * 16 # 11 ops (without consts) per slice
assert ops_counts[3] >= 11 * 32

predict_inputs = copy.deepcopy(example_inputs)
if "attn_mask" in predict_inputs:
Expand Down

0 comments on commit 62daa78

Please sign in to comment.