From fe96152b13c6816378a7890c67673f3ea90ff130 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 3 Sep 2022 15:54:11 +0100 Subject: [PATCH] MPS workaround for inf values stemming from https://github.com/pytorch/pytorch/issues/84364 --- k_diffusion/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 9afedb99..ce6014be 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -42,7 +42,10 @@ def append_dims(x, target_dims): dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] + expanded = x[(...,) + (None,) * dims_to_append] + # MPS will get inf values if it tries to index into the new axes, but detaching fixes this. + # https://github.com/pytorch/pytorch/issues/84364 + return expanded.detach().clone() if expanded.device.type == 'mps' else expanded def n_params(module):