Skip to content

Commit

Permalink
Removing overzealous check that broke QMC. It turns out there are non…
Browse files Browse the repository at this point in the history
…-higher-order equations with more than 1 output.

PiperOrigin-RevId: 695298337
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 11, 2024
1 parent 0ccdfdd commit 8a610fc
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,8 @@ def clean_jaxpr(
params=params,
)

else:
assert all(outvar_is_dep_for_eqn) or not any(outvar_is_dep_for_eqn)
# else:
# assert all(outvar_is_dep_for_eqn) or not any(outvar_is_dep_for_eqn)

check = False

Expand Down Expand Up @@ -1054,6 +1054,50 @@ def clean_jaxpr(
return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)


# Prototype for clean_jaxpr using JAX's dce_jaxpr. Doesn't work because
# dce_jaxpr will remove any equations with no used outputs, regardless of the
# dce_rule for that equation's primitive. Adding an "effect" to loss/layer
# tags also won't work, because we sometimes actually do want to remove them
# from the graph (when preserve_tags is False).
# def clean_jaxpr(
# jaxpr: J,
# preserve_tags: bool = True,
# ) -> J:
# """Runs dead code elimination on a Jaxpr, retaining loss and layer tags."""

# def dce_jaxpr_tag_rule(
# used_outputs: list[bool],
# eqn: JaxprEqn
# ) -> tuple[list[bool], JaxprEqn | None]:

# assert len(used_outputs) == len(eqn.outvars)

# if any(used_outputs) or preserve_tags:
# return [True] * len(eqn.invars), eqn
# else:
# return [False] * len(eqn.invars), None

# closed_jaxpr = to_closed_jaxpr(jaxpr)

# pe.dce_rules[tags.LossTag] = dce_jaxpr_tag_rule
# pe.dce_rules[tags.LayerTag] = dce_jaxpr_tag_rule

# cleaned_jaxpr, _ = pe.dce_jaxpr(
# closed_jaxpr.jaxpr,
# used_outputs=(True,) * len(closed_jaxpr.jaxpr.outvars),
# instantiate=True)

# pe.dce_rules.pop(tags.LossTag)
# pe.dce_rules.pop(tags.LayerTag)

# closed_jaxpr = ClosedJaxpr(
# jaxpr=cleaned_jaxpr,
# consts=closed_jaxpr.consts,
# )

# return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)


def merge_broadcasts_jaxpr(jaxpr: J) -> J:
"""Merges consecutive broadcasts in the given Jaxpr."""

Expand Down

0 comments on commit 8a610fc

Please sign in to comment.