DirectML Out of Memory Retry (#707)
* out of memory retry

* clear traceback
NullSenseStudio authored Oct 9, 2023
1 parent 4034e70 commit fb1d1d0
Showing 3 changed files with 57 additions and 6 deletions.
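
A note on the "clear traceback" part of this commit: when an exception propagates, its traceback keeps every frame's locals alive, including any large intermediate tensors, so a retry would otherwise start with that memory still pinned. Below is a minimal, hypothetical sketch of the idea (the allocation and error message are simulated); it also shows why the patch starts from __traceback__.tb_next, since the frame currently handling the exception is still executing and cannot be cleared.

import gc

def allocate():
    big = bytearray(100 * 1024 * 1024)  # stands in for a large intermediate tensor
    raise RuntimeError("simulated out-of-memory error")

try:
    allocate()
except RuntimeError as e:
    tb = e.__traceback__.tb_next  # skip the current frame, which is still executing
    while tb is not None:
        tb.tb_frame.clear()  # drop the locals (e.g. 'big') held alive by this frame
        tb = tb.tb_next
    gc.collect()
    # the 100 MB buffer is now collectable, so a retry has that memory back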
3 changes: 2 additions & 1 deletion generator_process/actions/choose_device.py
@@ -1,3 +1,4 @@
+import importlib.util
 import sys
 
 def choose_device(self, optimizations) -> str:
@@ -13,7 +14,7 @@ def choose_device(self, optimizations) -> str:
         return "cuda"
     elif torch.backends.mps.is_available():
         return "mps"
-    if 'torch_directml' in sys.modules:
+    elif importlib.util.find_spec("torch_directml"):
         import torch_directml
         if torch_directml.is_available():
             torch.utils.rename_privateuse1_backend("dml")
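
The switch to importlib.util.find_spec also fixes a subtle detection bug: sys.modules only contains packages that have already been imported in the current process, so the old check would miss an installed but not-yet-imported torch_directml. A small illustrative comparison, assuming a fresh interpreter with torch_directml installed:

import importlib.util
import sys

print('torch_directml' in sys.modules)             # False: installed, but nothing has imported it yet
print(importlib.util.find_spec("torch_directml"))  # a ModuleSpec: the package is found on disk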
50 changes: 50 additions & 0 deletions generator_process/directml_patches.py
@@ -1,4 +1,5 @@
 import functools
+import gc
 
 import torch
 from torch import Tensor
@@ -44,7 +45,56 @@ def getitem(self, key, *, pre_patch):
     return pre_patch(self, key)
 
 
+def retry_OOM(module):
+    if hasattr(module, "_retry_OOM"):
+        return
+    forward = module.forward
+
+    def is_OOM(e: RuntimeError):
+        if hasattr(e, "_retry_OOM"):
+            return False
+        if len(e.args) == 0:
+            return False
+        if not isinstance(e.args[0], str):
+            return False
+        return (
+            e.args[0].startswith("Could not allocate tensor with") and
+            e.args[0].endswith("bytes. There is not enough GPU video memory available!")
+        )
+
+    def wrapper(*args, **kwargs):
+        try:
+            try:
+                return forward(*args, **kwargs)
+            except RuntimeError as e:
+                if is_OOM(e):
+                    tb = e.__traceback__.tb_next
+                    while tb is not None:
+                        # clear locals from traceback so that intermediate tensors can be garbage collected
+                        # helps recover from Attention blocks more often
+                        tb.tb_frame.clear()
+                        tb = tb.tb_next
+                    # print("retrying!", type(module).__name__)
+                    gc.collect()
+                    return forward(*args, **kwargs)
+                raise
+        except RuntimeError as e:
+            if is_OOM(e):
+                # only retry leaf modules
+                e._retry_OOM = True
+            raise
+
+    module.forward = wrapper
+    module._retry_OOM = True
+
+
 def enable(pipe):
+    for comp in pipe.components.values():
+        if not isinstance(comp, torch.nn.Module):
+            continue
+        for module in comp.modules():
+            retry_OOM(module)
+
     global active_dml_patches
     if active_dml_patches is not None:
         return
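
To see how retry_OOM behaves end to end, here is a hypothetical toy module (FlakyModule and its error string are invented for illustration, matching the prefix and suffix that is_OOM checks for). The first forward call fails, the wrapper clears the traceback frames, collects garbage, and the retry succeeds:

import torch

class FlakyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.failed_once = False

    def forward(self, x):
        if not self.failed_once:
            self.failed_once = True  # fail only on the first call
            raise RuntimeError(
                "Could not allocate tensor with 1048576 "
                "bytes. There is not enough GPU video memory available!"
            )
        return x * 2

module = FlakyModule()
retry_OOM(module)             # patches module.forward with the retry wrapper defined above
print(module(torch.ones(2)))  # tensor([2., 2.]) -- the internal retry succeeded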
10 changes: 5 additions & 5 deletions generator_process/models/optimizations.py
@@ -24,10 +24,10 @@ class Optimizations:
     cudnn_benchmark: Annotated[bool, "cuda"] = False
     tf32: Annotated[bool, "cuda"] = False
     amp: Annotated[bool, "cuda"] = False
-    half_precision: Annotated[bool, {"cuda", "privateuseone"}] = True
-    cpu_offload: Annotated[str, {"cuda", "privateuseone"}] = CPUOffload.OFF
+    half_precision: Annotated[bool, {"cuda", "dml"}] = True
+    cpu_offload: Annotated[str, {"cuda", "dml"}] = CPUOffload.OFF
     channels_last_memory_format: bool = False
-    sdp_attention: Annotated[bool, {"cpu", "cuda", "mps"}] = True
+    sdp_attention: bool = True
     batch_size: int = 1
     vae_slicing: bool = True
     vae_tiling: str = "off"
@@ -43,7 +43,7 @@ def infer_device() -> str:
     if sys.platform == "darwin":
         return "mps"
     elif os.path.exists(absolute_path(".python_dependencies/torch_directml")):
-        return "privateuseone"
+        return "dml"
     else:
         return "cuda"
@@ -165,7 +165,7 @@ def apply(self, pipeline, device):
     except: pass
 
     from .. import directml_patches
-    if device == "privateuseone":
+    if device == "dml":
         directml_patches.enable(pipeline)
     else:
         directml_patches.disable(pipeline)
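
The privateuseone-to-dml renames here line up with the rename_privateuse1_backend("dml") call in choose_device.py: after the rename, torch's private backend is addressable as "dml", which is what infer_device() returns and what apply() compares against. A rough sketch of the effect, illustrative only:

import torch

torch.utils.rename_privateuse1_backend("dml")  # what choose_device.py does for DirectML
dev = torch.device("dml")  # the privateuse1 backend now answers to "dml"
print(dev.type)            # "dml" rather than "privateuseone"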
