add fit(method="tree") fix ALS for complex TNs
jcmgray committed Dec 13, 2024
1 parent 5b18c43 commit 464bb36
Showing 5 changed files with 343 additions and 149 deletions.
2 changes: 1 addition & 1 deletion quimb/tensor/contraction.py
@@ -6,7 +6,7 @@
import contextlib
import collections

-import cotengra as ctg
+import cotengra as ctg


_CONTRACT_STRATEGY = 'greedy'
238 changes: 194 additions & 44 deletions quimb/tensor/fitting.py
@@ -1,8 +1,9 @@
"""Tools for computing distances between and fitting tensor networks."""
-from autoray import dag, do
-
-from .contraction import contract_strategy
+from autoray import backend_like, dag, do
+
+from ..utils import check_opt
+from .contraction import contract_strategy


def tensor_network_distance(
@@ -49,10 +50,13 @@ def tensor_network_distance(
        directly formed and the norm computed, which can be quicker when the
        exterior dimensions are small. If ``'auto'``, the dense method will
        be used if the total operator (outer) size is ``<= 2**16``.
-    normalized : bool, optional
+    normalized : bool or str, optional
        If ``True``, then normalize the distance by the norm of the two
-        operators, i.e. ``2 * D(A, B) / (|A| + |B|)``. The resulting distance
+        operators, i.e. ``D(A, B) * 2 / (|A| + |B|)``. The resulting distance
        lies between 0 and 2 and is more useful for assessing convergence.
+        If ``'infidelity'``, compute the normalized infidelity
+        ``1 - |<A|B>|^2 / (|A| |B|)``, which can be faster to optimize, but
+        does not take into account normalization.
    contract_opts
        Supplied to :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`.
@@ -81,25 +85,39 @@ def tensor_network_distance(

    # directly from vectorizations of both
    if method == "dense":
-        tnA = tnA.contract(..., output_inds=oix, preserve_tensor=True)
-        tnB = tnB.contract(..., output_inds=oix, preserve_tensor=True)
+        tnA = tnA.contract(output_inds=oix, preserve_tensor=True)
+        tnB = tnB.contract(output_inds=oix, preserve_tensor=True)

    # overlap method
    if xAA is None:
-        xAA = (tnA | tnA.H).contract(..., **contract_opts)
+        xAA = (tnA | tnA.H).contract(**contract_opts)
    if xAB is None:
-        xAB = (tnA | tnB.H).contract(..., **contract_opts)
+        xAB = (tnA | tnB.H).contract(**contract_opts)
    if xBB is None:
-        xBB = (tnB | tnB.H).contract(..., **contract_opts)
+        xBB = (tnB | tnB.H).contract(**contract_opts)

-    dAB = do("abs", xAA - 2 * do("real", xAB) + xBB) ** 0.5
+    if normalized == "infidelity":
+        # compute normalized infidelity
+        return 1 - do("abs", xAB**2 / (xAA * xBB))

-    if normalized:
-        dAB *= 2 / (do("abs", xAA)**0.5 + do("abs", xBB)**0.5)
+    if normalized == "infidelity_sqrt":
+        # compute normalized sqrt infidelity
+        return 1 - do("abs", xAB / (xAA * xBB) ** 0.5)

-    return dAB
+    if normalized == "squared":
+        return (
+            do("abs", xAA + xBB - 2 * do("real", xAB))
+            # divide by average norm-squared of A and B
+            * 2 / (do("abs", xAA) + do("abs", xBB))
+        ) ** 0.5
+
+    dAB = do("abs", xAA + xBB - 2 * do("real", xAB)) ** 0.5
+
+    if normalized:
+        # divide by average norm of A and B
+        dAB = dAB * 2 / (do("abs", xAA) ** 0.5 + do("abs", xBB) ** 0.5)
+
+    return dAB
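
For reference, the normalization conventions above can be reproduced with plain scalars. A minimal sketch (the overlap values ``xAA``, ``xAB``, ``xBB`` are made up for illustration, standing in for <A|A>, <A|B> and <B|B>):

# hypothetical overlaps <A|A>, <A|B>, <B|B>
xAA, xAB, xBB = 4.0, 1.9 + 0.1j, 1.0

# unnormalized distance |A - B|
d = abs(xAA + xBB - 2 * xAB.real) ** 0.5

# normalized=True: divide by the average norm of A and B
d_norm = d * 2 / (abs(xAA) ** 0.5 + abs(xBB) ** 0.5)

# normalized="squared": normalize at the level of the squared distance
d_sq = (abs(xAA + xBB - 2 * xAB.real) * 2 / (abs(xAA) + abs(xBB))) ** 0.5

# normalized="infidelity": 1 - |<A|B>|^2 / (<A|A><B|B>)
infid = 1 - abs(xAB**2 / (xAA * xBB))

print(d, d_norm, d_sq, infid)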


def tensor_network_fit_autodiff(
@@ -110,6 +128,7 @@ def tensor_network_fit_autodiff(
    autodiff_backend="autograd",
    contract_optimize="auto-hq",
    distance_method="auto",
+    normalized="squared",
    inplace=False,
    progbar=False,
    **kwargs,
@@ -151,7 +170,6 @@ def tensor_network_fit_autodiff(
    from .tensor_core import tensor_network_distance

    xBB = (tn_target | tn_target.H).contract(
-        ...,
        output_inds=(),
        optimize=contract_optimize,
    )
@@ -160,7 +178,11 @@
        tn=tn,
        loss_fn=tensor_network_distance,
        loss_constants={"tnB": tn_target, "xBB": xBB},
-        loss_kwargs={"method": distance_method, "optimize": contract_optimize},
+        loss_kwargs={
+            "method": distance_method,
+            "optimize": contract_optimize,
+            "normalized": normalized,
+        },
        autodiff_backend=autodiff_backend,
        progbar=progbar,
        **kwargs,
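
As a usage sketch (assuming a standard quimb install; the network sizes are arbitrary), the autodiff route can be reached through ``TensorNetwork.fit``, which dispatches ``method="autodiff"`` to the function above, now with a ``normalized="squared"`` loss by default:

import quimb.tensor as qtn

# a random target MPS and a lower bond dimension ansatz to fit to it
psi_target = qtn.MPS_rand_state(L=10, bond_dim=8)
psi = qtn.MPS_rand_state(L=10, bond_dim=4)

# gradient based fitting of psi to psi_target
psi_fit = psi.fit(psi_target, method="autodiff", steps=100)

# the normalized distance lies between 0 and 2
print(psi_fit.distance(psi_target, normalized=True))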
@@ -192,8 +214,10 @@ def _tn_fit_als_core(
):
    from .tensor_core import group_inds

+    backend = next(iter(tnAA.tensors)).backend
+
    # shared intermediates + greedy = good reuse of contractions
-    with contract_strategy(contract_optimize):
+    with contract_strategy(contract_optimize), backend_like(backend):
        # prepare each of the contractions we are going to repeat
        env_contractions = []
        for tg in var_tags:
@@ -202,17 +226,17 @@
            tb = tnAA["__BRA__", tg]

            # get inds, and ensure any bonds come last, for linalg.solve
-            lix, bix, rix = group_inds(tb, tk)
-            tk.transpose_(*rix, *bix)
-            tb.transpose_(*lix, *bix)
+            lix, bix, rix = group_inds(tk, tb)
+            tk.transpose_(*lix, *bix)
+            tb.transpose_(*rix, *bix)

            # form TNs with 'holes', i.e. environment tensors networks
            A_tn = tnAA.select((tg,), "!all")
            y_tn = tnAB.select((tg,), "!all")

            env_contractions.append((tk, tb, lix, bix, rix, A_tn, y_tn))

-        if tol != 0.0:
+        if tol != 0.0 or progbar:
            old_d = float("inf")

        if progbar:
@@ -225,37 +249,40 @@
        # the main iterative sweep on each tensor, locally optimizing
        for _ in pbar:
            for tk, tb, lix, bix, rix, A_tn, y_tn in env_contractions:
-                Ni = A_tn.to_dense(lix, rix)
-                Wi = y_tn.to_dense(rix, bix)
+                # form local normalization and local overlap
+                Ni = A_tn.to_dense(rix, lix)
+                bi = y_tn.to_dense(rix, bix)

                if enforce_pos:
-                    el, ev = do("linalg.eigh", Ni)
-                    el = do("clip", el, el[-1] * pos_smudge, None)
-                    Ni_p = ev * do("reshape", el, (1, -1)) @ dag(ev)
+                    el, V = do("linalg.eigh", Ni)
+                    elmax = do("max", el)
+                    el = do("clip", el, elmax * pos_smudge, None)
+                    # can solve directly using eigendecomposition
+                    x = V @ ((dag(V) @ bi) / do("reshape", el, (-1, 1)))
                else:
                    Ni_p = Ni

-                if solver == "solve":
-                    x = do("linalg.solve", Ni_p, Wi)
-                elif solver == "lstsq":
-                    x = do("linalg.lstsq", Ni_p, Wi, rcond=pos_smudge)[0]
+                    if solver == "solve":
+                        x = do("linalg.solve", Ni_p, bi)
+                    elif solver == "lstsq":
+                        x = do("linalg.lstsq", Ni_p, bi, rcond=pos_smudge)[0]

                x_r = do("reshape", x, tk.shape)
                # n.b. because we are using virtual TNs -> updates propagate
                tk.modify(data=x_r)
                tb.modify(data=do("conj", x_r))

-            # assess | A - B | for convergence or printing
+            # assess | A - B | (normalized) for convergence or printing
            if (tol != 0.0) or progbar:
-                xAA = do("trace", dag(x) @ (Ni @ x))  # <A|A>
-                xAB = do("trace", do("real", dag(x) @ Wi))  # <A|B>
-                d = do("abs", (xAA - 2 * xAB + xBB)) ** 0.5
+                dagx = dag(x)
+                xAA = do("trace", do("real", dagx @ (Ni @ x)))  # <A|A>
+                xAB = do("trace", do("real", dagx @ bi))  # <A|B>
+                d = abs(xAA + xBB - 2 * xAB) ** 0.5 * 2 / (xAA**0.5 + xBB**0.5)
                if abs(d - old_d) < tol:
                    break
                old_d = d

                if progbar:
-                    pbar.set_description(str(d))
+                    pbar.set_description(f"{d:.3g}")
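
Each local update above solves the normal equations ``Ni @ x = bi``, with the bra tensor then set to the explicit conjugate of the ket (part of the complex-TN fix). A self-contained numpy sketch of the ``enforce_pos`` branch, with a made-up Hermitian environment ``N`` and right-hand side ``b``:

import numpy as np

rng = np.random.default_rng(0)

# hypothetical local environment: Hermitian positive semi-definite N = M @ M^H
M = rng.normal(size=(6, 6)) + 1j * rng.normal(size=(6, 6))
N = M @ M.conj().T
b = rng.normal(size=(6, 3)) + 1j * rng.normal(size=(6, 3))

pos_smudge = 1e-13

# eigendecompose, clip tiny or negative eigenvalues, solve in the eigenbasis
el, V = np.linalg.eigh(N)
el = np.clip(el, np.max(el) * pos_smudge, None)
x = V @ ((V.conj().T @ b) / el.reshape(-1, 1))

# agrees with a direct solve whenever N is well conditioned
print(np.allclose(x, np.linalg.solve(N, b)))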


def tensor_network_fit_als(
@@ -329,13 +356,13 @@ def tensor_network_fit_als(
    tensor_network_fit_autodiff, tensor_network_distance
    """
    # mark the tensors we are going to optimize
-    tna = tn.copy()
-    tna.add_tag("__KET__")
+    tn_fit = tn.copy()
+    tn_fit.add_tag("__KET__")

    if tags is None:
-        to_tag = tna
+        to_tag = tn_fit
    else:
-        to_tag = tna.select_tensors(tags, "any")
+        to_tag = tn_fit.select_tensors(tags, "any")

    var_tags = []
    for i, t in enumerate(to_tag):
@@ -344,15 +371,15 @@ def tensor_network_fit_als(
        var_tags.append(var_tag)

    # form the norm of the varying TN (A) and its overlap with the target (B)
+    tn_fit_conj = tn_fit.conj().retag_({"__KET__": "__BRA__"})
    if tnAA is None:
-        tnAA = tna | tna.H.retag_({"__KET__": "__BRA__"})
+        tnAA = tn_fit | tn_fit_conj
    if tnAB is None:
-        tnAB = tna | tn_target.H
+        tnAB = tn_target | tn_fit_conj

-    if (tol != 0.0) and (xBB is None):
+    if (tol != 0.0 or progbar) and (xBB is None):
        # <B|B>
        xBB = (tn_target | tn_target.H).contract(
-            ...,
            optimize=contract_optimize,
            output_inds=(),
        )
@@ -377,9 +404,132 @@
    if not inplace:
        tn = tn.copy()

-    for t1, t2 in zip(tn, tna):
+    for t1, t2 in zip(tn, tn_fit):
        # transpose so only thing changed in original TN is data
        t2.transpose_like_(t1)
        t1.modify(data=t2.data)

    return tn
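
Since the commit specifically fixes ALS for complex tensor networks, a plausible smoke test (a sketch only; sizes, dtype and tolerances are arbitrary) is to fit one complex MPS to another of equal bond dimension, which should converge to near zero distance:

import quimb.tensor as qtn

psi_target = qtn.MPS_rand_state(L=8, bond_dim=6, dtype="complex128")
psi = qtn.MPS_rand_state(L=8, bond_dim=6, dtype="complex128")

# the ALS environments now conjugate the fitted network explicitly,
# so this converges for complex data too
psi_fit = psi.fit(psi_target, method="als", steps=50, tol=1e-8)
print(psi_fit.distance(psi_target, normalized=True))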


+def tensor_network_fit_tree(
+    tn,
+    tn_target,
+    tags=None,
+    steps=100,
+    tol=1e-9,
+    ordering=None,
+    xBB=None,
+    istree=True,
+    contract_optimize="auto-hq",
+    inplace=False,
+    progbar=False,
+):
"""Fit `tn` to `tn_target`, assuming that `tn` has tree structure (i.e. a
single path between any two sites) and matching outer structure to
`tn_target`. The tree structure allows a canonical form that greatly
simplifies the normalization and least squares minimization. Note that no
structure is assumed about `tn_target`, and so for example no partial
contractions reused.
Parameters
----------
tn : TensorNetwork
The tensor network to fit, it should have a tree structure and outer
indices matching `tn_target`.
tn_target : TensorNetwork
The target tensor network to fit ``tn`` to.
tags : sequence of str, optional
If supplied, only optimize tensors matching any of given tags.
steps : int, optional
The maximum number of ALS steps.
tol : float, optional
The target norm distance.
ordering : sequence of int, optional
The order in which to optimize the tensors, if None will be computed
automatically using a hierarchical clustering.
xBB : float, optional
If you have already know, have computed ``tn_target.H @ tn_target``,
or it doesn't matter, you can supply the value here. It matters only
for the overall scale of the norm distance.
contract_optimize : str, optional
A contraction path strategy or optimizer for contracting the local
environments.
inplace : bool, optional
Fit ``tn`` in place.
progbar : bool, optional
Show a live progress bar of the fitting process.
Returns
-------
TensorNetwork
"""
+    if xBB is None:
+        xBB = (tn_target | tn_target.H).contract(
+            optimize=contract_optimize,
+            output_inds=(),
+        )
+
+    tn_fit = tn.conj(inplace=inplace)
+    tnAB = tn_fit | tn_target
+
+    if ordering is None:
+        if tags is not None:
+            tids = tn_fit._get_tids_from_tags(tags, "any")
+        else:
+            tids = None
+        ordering = tn_fit.compute_hierarchical_ordering(tids)
+
+    # prepare contractions
+    env_contractions = []
+    for i, tid in enumerate(ordering):
+        tn_hole = tnAB.copy(virtual=True)
+        ti = tn_hole.pop_tensor(tid)
+        # we'll need to canonicalize along path from the last tid to this one
+        tid_prev = ordering[(i - 1) % len(ordering)]
+        path = tn_fit.get_path_between_tids(tid_prev, tid)
+        canon_pairs = [
+            (path.tids[j], path.tids[j + 1]) for j in range(len(path))
+        ]
+        env_contractions.append((tid, tn_hole, ti, canon_pairs))
+
+    # initial canonicalization around first tensor
+    tn_fit._canonize_around_tids([ordering[0]])
+
+    if progbar:
+        import tqdm
+
+        pbar = tqdm.trange(steps)
+    else:
+        pbar = range(steps)

+    old_d = float("inf")

+    for _ in pbar:
+        for tid, tn_hole, ti, canon_pairs in env_contractions:
+            if istree:
+                # move canonical center to tid
+                for tidi, tidj in canon_pairs:
+                    tn_fit._canonize_between_tids(tidi, tidj)
+            else:
+                # pseudo canonicalization
+                tn_fit._canonize_around_tids([tid])

+            # get the new conjugate tensor
+            ti_new = tn_hole.contract(output_inds=ti.inds, optimize="auto-hq")
+            ti_new.conj_()
+            # modify the data
+            ti.modify(data=ti_new.data)
+
+        if tol != 0.0 or progbar:
+            # canonicalized form enables simpler distance computation
+            xAA = ti.norm() ** 2  # == xAB
+            d = 2 * abs(xBB - xAA) ** 0.5 / (xBB**0.5 + xAA**0.5)
+            if abs(d - old_d) < tol:
+                break
+            old_d = d
+
+            if progbar:
+                pbar.set_description(f"{d:.3g}")
+
+    return tn_fit.conj_()
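
Finally, a sketch of invoking the new tree-based path named in the commit title (an open boundary MPS is the simplest example of a tree tensor network; the parameters here are illustrative):

import quimb.tensor as qtn

psi_target = qtn.MPS_rand_state(L=12, bond_dim=16)
psi = qtn.MPS_rand_state(L=12, bond_dim=8)

# an MPS has a single path between any two tensors, so method="tree"
# can sweep the canonical center along the optimization ordering
psi_fit = psi.fit(psi_target, method="tree", steps=20, tol=1e-9)
print(psi_fit.distance(psi_target, normalized=True))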