Commit

Merge pull request #799 from carson-katri/directml
Update DirectML
NullSenseStudio authored Jun 1, 2024
2 parents c1575e2 + eb53c78 commit db13378
Showing 2 changed files with 6 additions and 26 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/package-release.yml
@@ -24,8 +24,6 @@ jobs:
             filename_suffix: ''
           - python: '3.11'
             filename_suffix: '-4-1'
-        exclude:
-          - { platform: { requirements: win-dml.txt }, version: { python: '3.11' } }
     runs-on: ${{ matrix.platform.os }}
     steps:
       - name: Checkout repository
30 changes: 6 additions & 24 deletions generator_process/directml_patches.py
@@ -7,16 +7,6 @@
 active_dml_patches: list | None = None
 
 
-def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None, pre_patch):
-    if input.device.type == "dml" and beta == 0:
-        if out is not None:
-            torch.bmm(batch1, batch2, out=out)
-            out *= alpha
-            return out
-        return alpha * (batch1 @ batch2)
-    return pre_patch(input, batch1, batch2, beta=beta, alpha=alpha, out=out)
-
-
 def pad(input, pad, mode="constant", value=None, *, pre_patch):
     if input.device.type == "dml" and mode == "constant":
         pad_dims = torch.tensor(pad, dtype=torch.int32).view(-1, 2).flip(0)
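
Note on the removal above: torch.baddbmm computes beta * input + alpha * (batch1 @ batch2), so with beta=0 the input tensor should only supply the output shape and never contribute values. The deleted wrapper recomputed that case by hand because, as the comments removed further down explain, the DML backend did not ignore input at beta=0 and could propagate NaNs from torch.empty() buffers. A minimal CPU-only sketch of the identity the patch relied on (illustrative only, not part of the commit):

import torch

batch1 = torch.randn(2, 3, 4)
batch2 = torch.randn(2, 4, 5)
# With beta=0 this tensor should only determine the output shape;
# any garbage values (even NaNs) in it must not leak into the result.
scratch = torch.full((2, 3, 5), float("nan"))

out_baddbmm = torch.baddbmm(scratch, batch1, batch2, beta=0, alpha=0.5)
out_manual = 0.5 * torch.bmm(batch1, batch2)  # what the removed patch computed instead

assert torch.allclose(out_baddbmm, out_manual)  # holds on backends that honor beta=0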
@@ -39,10 +29,10 @@ def pad(input, pad, mode="constant", value=None, *, pre_patch):
     return pre_patch(input, pad, mode=mode, value=value)
 
 
-def getitem(self, key, *, pre_patch):
-    if isinstance(key, Tensor) and "dml" in [self.device.type, key.device.type] and key.numel() == 1:
-        return pre_patch(self, int(key))
-    return pre_patch(self, key)
+def layer_norm(input, normalized_shape, weight = None, bias = None, eps = 1e-05, *, pre_patch):
+    if input.device.type == "dml":
+        return pre_patch(input.contiguous(), normalized_shape, weight, bias, eps)
+    return pre_patch(input, normalized_shape, weight, bias, eps)
 
 
 def retry_OOM(module):
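
The new layer_norm wrapper added above does one thing: on DML devices it hands the original torch.nn.functional.layer_norm (received as pre_patch) a contiguous copy of the input. A rough sketch of how such a pre_patch-style wrapper behaves once installed; the standalone install lines below are a hypothetical stand-in for the dml_patch helper shown in the next hunk:

import functools
import torch
import torch.nn.functional as F

def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05, *, pre_patch):
    # On DML devices, give the original implementation a contiguous copy of the input.
    if input.device.type == "dml":
        return pre_patch(input.contiguous(), normalized_shape, weight, bias, eps)
    return pre_patch(input, normalized_shape, weight, bias, eps)

# Hypothetical stand-in for dml_patch: bind the original function as pre_patch.
original = F.layer_norm
F.layer_norm = functools.partial(layer_norm, pre_patch=original)

x = torch.randn(2, 8)  # CPU tensor, so the wrapper falls through to the original
print(F.layer_norm(x, (8,)).shape)  # torch.Size([2, 8])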
@@ -110,17 +100,9 @@ def dml_patch_method(object, name, patched):
         setattr(object, name, functools.partialmethod(patched, pre_patch=original))
         active_dml_patches.append({"object": object, "name": name, "original": original})
 
-    # Not all places where the patches have an effect are necessarily listed.
-
-    # diffusers.models.attention_processor.Attention.get_attention_scores()
-    # diffusers.models.attention.AttentionBlock.forward()
-    # Diffusers implementation gives torch.empty() tensors with beta=0 to baddbmm(), which may contain NaNs.
-    # DML implementation doesn't properly ignore input argument with beta=0 and causes NaN propagation.
-    dml_patch(torch, "baddbmm", baddbmm)
-
     dml_patch(torch.nn.functional, "pad", pad)
-    # DDIMScheduler.step(), PNDMScheduler.step(), No error messages or crashes, just may randomly freeze.
-    dml_patch_method(Tensor, "__getitem__", getitem)
 
+    dml_patch(torch.nn.functional, "layer_norm", layer_norm)
+
     def decorate_forward(name, module):
         """Helper function to better find which modules DML fails in as it often does
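
On the registration side, the dml_patch_method body visible above suggests each patch is installed with functools.partial/partialmethod so the original callable is handed back to the wrapper as pre_patch, and a record is appended to active_dml_patches so the patch can be reverted. A condensed sketch of that pattern, assuming simplified module-level helpers rather than the file's actual ones (undo_patches below is hypothetical):

import functools
import torch

active_dml_patches = []

def dml_patch(obj, name, patched):
    # Keep the original so the wrapper can delegate to it and the patch stays reversible.
    original = getattr(obj, name)
    setattr(obj, name, functools.partial(patched, pre_patch=original))
    active_dml_patches.append({"object": obj, "name": name, "original": original})

def undo_patches():
    # Restore every function that dml_patch replaced.
    for patch in active_dml_patches:
        setattr(patch["object"], patch["name"], patch["original"])
    active_dml_patches.clear()

def pad(input, pad, mode="constant", value=None, *, pre_patch):
    # Trivial pass-through wrapper, standing in for the real DML workaround.
    return pre_patch(input, pad, mode=mode, value=value)

dml_patch(torch.nn.functional, "pad", pad)
print(torch.nn.functional.pad(torch.ones(1, 2), (1, 1), value=0.0))  # routed through the wrapper
undo_patches()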
