Skip to content

Commit

Permalink
MPS workaround for inf values stemming from pytorch/pytorch#84364
Browse files Browse the repository at this point in the history
  • Loading branch information
Birch-san committed Nov 5, 2022
1 parent c56a015 commit fe96152
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion k_diffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def append_dims(x, target_dims):
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
expanded = x[(...,) + (None,) * dims_to_append]
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
# https://github.com/pytorch/pytorch/issues/84364
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded


def n_params(module):
Expand Down

0 comments on commit fe96152

Please sign in to comment.