diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 133cba0..42b9e16 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -1019,8 +1019,8 @@ def clean_jaxpr( params=params, ) - else: - assert all(outvar_is_dep_for_eqn) or not any(outvar_is_dep_for_eqn) + # else: + # assert all(outvar_is_dep_for_eqn) or not any(outvar_is_dep_for_eqn) check = False @@ -1054,6 +1054,50 @@ def clean_jaxpr( return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr) +# Prototype for clean_jaxpr using JAX's dce_jaxpr. Doesn't work because +# dce_jaxpr will remove any equations with no used outputs, regardless of the +# dce_rule for that equation's primitive. Adding an "effect" to loss/layer +# tags also won't work, because we sometimes actually do want to remove them +# from the graph (when preserve_tags is False). +# def clean_jaxpr( +# jaxpr: J, +# preserve_tags: bool = True, +# ) -> J: +# """Runs dead code elimination on a Jaxpr, retaining loss and layer tags.""" + +# def dce_jaxpr_tag_rule( +# used_outputs: list[bool], +# eqn: JaxprEqn +# ) -> tuple[list[bool], JaxprEqn | None]: + +# assert len(used_outputs) == len(eqn.outvars) + +# if any(used_outputs) or preserve_tags: +# return [True] * len(eqn.invars), eqn +# else: +# return [False] * len(eqn.invars), None + +# closed_jaxpr = to_closed_jaxpr(jaxpr) + +# pe.dce_rules[tags.LossTag] = dce_jaxpr_tag_rule +# pe.dce_rules[tags.LayerTag] = dce_jaxpr_tag_rule + +# cleaned_jaxpr, _ = pe.dce_jaxpr( +# closed_jaxpr.jaxpr, +# used_outputs=(True,) * len(closed_jaxpr.jaxpr.outvars), +# instantiate=True) + +# pe.dce_rules.pop(tags.LossTag) +# pe.dce_rules.pop(tags.LayerTag) + +# closed_jaxpr = ClosedJaxpr( +# jaxpr=cleaned_jaxpr, +# consts=closed_jaxpr.consts, +# ) + +# return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr) + + def merge_broadcasts_jaxpr(jaxpr: J) -> J: """Merges consecutive broadcasts in the given Jaxpr."""