Commit

improve debug print for chains
limiteinductive committed Oct 10, 2023
1 parent a663375 commit 0e1600f
Showing 4 changed files with 149 additions and 25 deletions.
114 changes: 95 additions & 19 deletions src/refiners/fluxion/layers/chain.py
@@ -1,10 +1,15 @@
from collections import defaultdict
import inspect
import re
import sys
import traceback
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
import torch
from torch import Tensor, cat, device as Device, dtype as DType
from refiners.fluxion.layers.basics import Identity
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
from refiners.fluxion.layers.module import Module, ContextModule, ModuleTree, WeightedModule
from refiners.fluxion.context import Contexts, ContextProvider
from refiners.fluxion.utils import summarize_tensor


T = TypeVar("T", bound=Module)
@@ -109,6 +114,13 @@ def structural_copy(m: T) -> T:
return m.structural_copy() if isinstance(m, ContextModule) else m


class ChainError(RuntimeError):
"""Exception raised when an error occurs during the execution of a Chain."""

def __init__(self, message: str, /) -> None:
super().__init__(message)


class Chain(ContextModule):
_modules: dict[str, Module]
_provider: ContextProvider
@@ -173,34 +185,98 @@ def set_context(self, context: str, value: Any) -> None:
self._provider.set_context(context, value)
self._register_provider()

def debug_repr(self, layer_name: str = "") -> str:
lines: list[str] = []
tab = " "
tab_length = 0
for i, parent in enumerate(self.get_parents()[::-1]):
lines.append(f"{tab*tab_length}{'└─ ' if i else ''}{parent.__class__.__name__}")
tab_length += 1
def _show_error_in_tree(self, name: str, /, max_lines: int = 20) -> str:
tree = ModuleTree(module=self)
classname_counter: dict[str, int] = defaultdict(int)
first_ancestor = self.get_parents()[-1] if self.get_parents() else self

def find_state_dict_key(module: Module, /) -> str | None:
for key, layer in module.named_modules():
if layer == self:
return ".".join((key, name))
return None

lines.append(f"{tab*tab_length}└─ {self.__class__.__name__}")
for child in tree:
classname, count = name.rsplit(sep="_", maxsplit=1) if "_" in name else (name, "1")
if child["class_name"] == classname:
classname_counter[classname] += 1
if classname_counter[classname] == int(count):
state_dict_key = find_state_dict_key(first_ancestor)
child["value"] = f">>> {child['value']} | {state_dict_key}"
break

for name, _ in self._modules.items():
error_arrow = "⚠️" if name == layer_name else ""
lines.append(f"{tab*tab_length} | {name} {error_arrow}")
tree_repr = tree._generate_tree_repr(tree.root, depth=3) # type: ignore[reportPrivateUsage]

return "\n".join(lines)
lines = tree_repr.split(sep="\n")
error_line_idx = next((idx for idx, line in enumerate(iterable=lines) if line.startswith(">>>")), 0)

return ModuleTree.shorten_tree_repr(tree_repr, line_index=error_line_idx, max_lines=max_lines)

@staticmethod
def _pretty_print_args(*args: Any) -> str:
"""
        Flatten nested tuples and print tensors with their shape and other information.
"""

def call_layer(self, layer: Module, layer_name: str, *args: Any):
def _flatten_tuple(t: Tensor | tuple[Any, ...], /) -> list[Any]:
if isinstance(t, tuple):
return [item for subtuple in t for item in _flatten_tuple(subtuple)]
else:
return [t]

flat_args = _flatten_tuple(args)

return "\n".join(
[
f"{idx}: {summarize_tensor(arg) if isinstance(arg, Tensor) else arg}"
for idx, arg in enumerate(iterable=flat_args)
]
)

def _filter_traceback(self, *frames: traceback.FrameSummary) -> list[traceback.FrameSummary]:
patterns_to_exclude = [
(r"torch/nn/modules/", r"^_call_impl$"),
(r"torch/nn/functional\.py", r""),
(r"refiners/fluxion/layers/", r"^_call_layer$"),
(r"refiners/fluxion/layers/", r"^forward$"),
(r"refiners/fluxion/layers/chain\.py", r""),
(r"", r"^_"),
]

def should_exclude(frame: traceback.FrameSummary, /) -> bool:
for filename_pattern, name_pattern in patterns_to_exclude:
if re.search(pattern=filename_pattern, string=frame.filename) and re.search(
pattern=name_pattern, string=frame.name
):
return True
return False

return [frame for frame in frames if not should_exclude(frame)]

def _call_layer(self, layer: Module, name: str, /, *args: Any) -> Any:
try:
return layer(*args)
except Exception as e:
pretty_print = self.debug_repr(layer_name)
raise ValueError(f"Error in layer {layer_name}, args:\n {args}\n \n{pretty_print}") from e
exc_type, _, exc_traceback = sys.exc_info()
assert exc_type
tb_list = traceback.extract_tb(tb=exc_traceback)
filtered_tb_list = self._filter_traceback(*tb_list)
formatted_tb = "".join(traceback.format_list(extracted_list=filtered_tb_list))
pretty_args = Chain._pretty_print_args(args)
error_tree = self._show_error_in_tree(name)

exception_str = re.sub(pattern=r"\n\s*\n", repl="\n", string=str(object=e))
message = f"{formatted_tb}\n{exception_str}\n---------------\n{error_tree}\n{pretty_args}"
if "Error" not in exception_str:
message = f"{exc_type.__name__}:\n {message}"

raise ChainError(message) from None

def forward(self, *args: Any) -> Any:
result: tuple[Any] | Any = None
intermediate_args: tuple[Any, ...] = args
for name, layer in self._modules.items():
result = self.call_layer(layer, name, *intermediate_args)
result = self._call_layer(layer, name, *intermediate_args)
intermediate_args = (result,) if not isinstance(result, tuple) else result

self._reset_context()
@@ -409,7 +485,7 @@ class Parallel(Chain):
_tag = "PAR"

def forward(self, *args: Any) -> tuple[Tensor, ...]:
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])

def _show_only_tag(self) -> bool:
return self.__class__ == Parallel
@@ -421,7 +497,7 @@ class Distribute(Chain):
def forward(self, *args: Any) -> tuple[Tensor, ...]:
n, m = len(args), len(self._modules)
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])

def _show_only_tag(self) -> bool:
return self.__class__ == Distribute
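
Taken together, the chain.py changes replace the old debug_repr/call_layer pair: any exception raised while a Chain runs one of its sub-modules is now re-raised as a ChainError whose message combines a filtered traceback, the module tree with the failing layer marked, and a summary of the arguments. A minimal sketch of how this surfaces, assuming ChainError is importable from refiners.fluxion.layers.chain and using the fl alias from the tests below:

import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.layers.chain import ChainError  # assumed import path

# The second Linear expects 2 input features but receives 1, so forward fails.
chain = fl.Chain(
    fl.Linear(in_features=1, out_features=1),
    fl.Linear(in_features=2, out_features=1),
)

try:
    chain(torch.randn(1, 1))
except ChainError as error:
    # The message bundles the filtered traceback, the module tree with the
    # failing layer highlighted, and a summary of each positional argument.
    print(error)
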
24 changes: 18 additions & 6 deletions src/refiners/fluxion/layers/module.py
@@ -59,7 +59,7 @@ def __repr__(self) -> str:

def pretty_print(self, depth: int = -1) -> None:
tree = ModuleTree(module=self)
print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth))
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]

def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
"""Return a dictionary of basic attributes of the module.
@@ -182,10 +182,22 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}(root={self.root['value']})"

def __repr__(self) -> str:
return self.generate_tree_repr(node=self.root, is_root=True, depth=7)

def generate_tree_repr(
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
return self._generate_tree_repr(self.root, is_root=True, depth=7)

def __iter__(self) -> Generator[TreeNode, None, None]:
for child in self.root["children"]:
yield child

@classmethod
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
"""Shorten the tree representation to a given number of lines around a given line index."""
lines = tree_repr.split(sep="\n")
start_idx = max(0, line_index - max_lines // 2)
end_idx = min(len(lines), line_index + max_lines // 2 + 1)
return "\n".join(lines[start_idx:end_idx])

def _generate_tree_repr(
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
) -> str:
if depth == 0 and node["children"]:
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
@@ -211,7 +223,7 @@ def generate_tree_repr(
else:
child_value = child["value"]

child_str = self.generate_tree_repr(
child_str = self._generate_tree_repr(
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
prefix=prefix + new_prefix,
is_last=i == len(node["children"]) - 1,
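
On the module.py side, the tree generator becomes private, a ModuleTree can now be iterated over its root's children, and the new shorten_tree_repr classmethod is a plain string helper that keeps roughly max_lines lines of a tree representation centered on line_index. A small illustrative sketch; the synthetic input string below is not produced by the library:

from refiners.fluxion.layers.module import ModuleTree

fake_repr = "\n".join(f"node {i}" for i in range(100))
# Keep roughly 20 lines centered on line 50 of the representation.
print(ModuleTree.shorten_tree_repr(fake_repr, line_index=50, max_lines=20))
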
20 changes: 20 additions & 0 deletions src/refiners/fluxion/utils.py
@@ -137,3 +137,23 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:

def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
_save_file(tensors, path, metadata) # type: ignore


def summarize_tensor(tensor: torch.Tensor, /) -> str:
return (
"Tensor("
+ ", ".join(
[
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
f"mean={tensor.mean():.2f}",
f"std={tensor.std():.2f}",
f"norm={norm(x=tensor):.2f}",
f"grad={tensor.requires_grad}",
]
)
+ ")"
)
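
summarize_tensor is what _pretty_print_args in chain.py uses to describe each tensor argument on a single line. A short usage sketch; the commented output is illustrative, not fixed values:

import torch
from refiners.fluxion.utils import summarize_tensor

x = torch.randn(2, 3, 4)
print(summarize_tensor(x))
# Tensor(shape=(2, 3, 4), dtype=float32, device=cpu, min=-2.31, max=2.05,
#        mean=0.01, std=0.98, norm=4.90, grad=False)
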
16 changes: 16 additions & 0 deletions tests/fluxion/layers/test_chain.py
@@ -226,3 +226,19 @@ def test_setattr_dont_register() -> None:
chain.foo = fl.Linear(in_features=1, out_features=1)

assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]


EXPECTED_TREE = (
"(CHAIN)\n ├── Linear(in_features=1, out_features=1) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1,"
" out_features=1) #1\n └── Linear(in_features=2, out_features=1) #2"
)


def test_debug_print() -> None:
chain = fl.Chain(
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(2, 1)),
)

assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
