diff --git a/.github/workflows/package-release.yml b/.github/workflows/package-release.yml
index 504310a9..e725ef65 100644
--- a/.github/workflows/package-release.yml
+++ b/.github/workflows/package-release.yml
@@ -24,8 +24,6 @@ jobs:
             filename_suffix: ''
           - python: '3.11'
             filename_suffix: '-4-1'
-        exclude:
-          - { platform: { requirements: win-dml.txt }, version: { python: '3.11' } }
     runs-on: ${{ matrix.platform.os }}
     steps:
       - name: Checkout repository
diff --git a/generator_process/directml_patches.py b/generator_process/directml_patches.py
index 6c0e2110..e8afd63f 100644
--- a/generator_process/directml_patches.py
+++ b/generator_process/directml_patches.py
@@ -7,16 +7,6 @@
 active_dml_patches: list | None = None
 
 
-def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None, pre_patch):
-    if input.device.type == "dml" and beta == 0:
-        if out is not None:
-            torch.bmm(batch1, batch2, out=out)
-            out *= alpha
-            return out
-        return alpha * (batch1 @ batch2)
-    return pre_patch(input, batch1, batch2, beta=beta, alpha=alpha, out=out)
-
-
 def pad(input, pad, mode="constant", value=None, *, pre_patch):
     if input.device.type == "dml" and mode == "constant":
         pad_dims = torch.tensor(pad, dtype=torch.int32).view(-1, 2).flip(0)
@@ -39,10 +29,10 @@ def pad(input, pad, mode="constant", value=None, *, pre_patch):
     return pre_patch(input, pad, mode=mode, value=value)
 
 
-def getitem(self, key, *, pre_patch):
-    if isinstance(key, Tensor) and "dml" in [self.device.type, key.device.type] and key.numel() == 1:
-        return pre_patch(self, int(key))
-    return pre_patch(self, key)
+def layer_norm(input, normalized_shape, weight = None, bias = None, eps = 1e-05, *, pre_patch):
+    if input.device.type == "dml":
+        return pre_patch(input.contiguous(), normalized_shape, weight, bias, eps)
+    return pre_patch(input, normalized_shape, weight, bias, eps)
 
 
 def retry_OOM(module):
@@ -110,17 +100,9 @@ def dml_patch_method(object, name, patched):
         setattr(object, name, functools.partialmethod(patched, pre_patch=original))
         active_dml_patches.append({"object": object, "name": name, "original": original})
 
-    # Not all places where the patches have an effect are necessarily listed.
-
-    # diffusers.models.attention_processor.Attention.get_attention_scores()
-    # diffusers.models.attention.AttentionBlock.forward()
-    # Diffusers implementation gives torch.empty() tensors with beta=0 to baddbmm(), which may contain NaNs.
-    # DML implementation doesn't properly ignore input argument with beta=0 and causes NaN propagation.
-    dml_patch(torch, "baddbmm", baddbmm)
-
     dml_patch(torch.nn.functional, "pad", pad)
-    # DDIMScheduler.step(), PNDMScheduler.step(), No error messages or crashes, just may randomly freeze.
-    dml_patch_method(Tensor, "__getitem__", getitem)
+
+    dml_patch(torch.nn.functional, "layer_norm", layer_norm)
 
 def decorate_forward(name, module):
     """Helper function to better find which modules DML fails in as it often does
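
Note (not part of the diff): the snippet below is a minimal, standalone sketch of the patching pattern `generator_process/directml_patches.py` relies on, included for context. `dml_patch()` captures the original module attribute and binds it to the replacement's keyword-only `pre_patch` parameter via `functools.partial`, so the replacement can delegate to the stock implementation. The new `layer_norm` wrapper differs from stock `torch.nn.functional.layer_norm` only by forcing a contiguous input on DirectML ("dml") devices; the PR does not state the underlying DirectML issue, so the reasoning in the comment is an assumption. The sketch also omits the `active_dml_patches` bookkeeping the real module uses so patches can be reverted.

```python
# Standalone sketch of the patch mechanism (simplified; the real module also
# records each patch in active_dml_patches so it can be undone later).
import functools

import torch
import torch.nn.functional


def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05, *, pre_patch):
    # Assumption: DirectML appears to mishandle non-contiguous inputs to
    # layer_norm, so the patch hands the original implementation a contiguous
    # copy; on every other device it is a plain pass-through.
    if input.device.type == "dml":
        return pre_patch(input.contiguous(), normalized_shape, weight, bias, eps)
    return pre_patch(input, normalized_shape, weight, bias, eps)


def dml_patch(module, name, patched):
    """Replace module.<name>, binding the original function as pre_patch."""
    original = getattr(module, name)
    setattr(module, name, functools.partial(patched, pre_patch=original))
    return original  # handle for undoing the patch


if __name__ == "__main__":
    dml_patch(torch.nn.functional, "layer_norm", layer_norm)
    x = torch.randn(2, 4, 8)  # CPU tensor: takes the pass-through branch
    y = torch.nn.functional.layer_norm(x, (8,))
    print(y.shape)  # torch.Size([2, 4, 8])
```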