From 464bb36e1210d8d877de9ce8cac9db1c47c5b13f Mon Sep 17 00:00:00 2001
From: Johnnie Gray
Date: Fri, 13 Dec 2024 15:56:49 -0800
Subject: [PATCH] add fit(method="tree") fix ALS for complex TNs

---
 quimb/tensor/contraction.py           |   2 +-
 quimb/tensor/fitting.py               | 238 +++++++++++++++++++++-----
 quimb/tensor/tensor_core.py           |  35 +++-
 tests/test_tensor/test_fitting.py     | 122 +++++++++++++
 tests/test_tensor/test_tensor_core.py |  95 ----------
 5 files changed, 343 insertions(+), 149 deletions(-)
 create mode 100644 tests/test_tensor/test_fitting.py

diff --git a/quimb/tensor/contraction.py b/quimb/tensor/contraction.py
index f7a445d0..028ae2d2 100644
--- a/quimb/tensor/contraction.py
+++ b/quimb/tensor/contraction.py
@@ -6,7 +6,7 @@
 import contextlib
 import collections
 
-import cotengra as ctg
+import cotengra as ctg
 
 _CONTRACT_STRATEGY = 'greedy'
 
diff --git a/quimb/tensor/fitting.py b/quimb/tensor/fitting.py
index b11e6539..b394fde6 100644
--- a/quimb/tensor/fitting.py
+++ b/quimb/tensor/fitting.py
@@ -1,8 +1,9 @@
 """Tools for computing distances between and fitting tensor networks."""
 
-from autoray import dag, do
-from .contraction import contract_strategy
+from autoray import backend_like, dag, do
+
 from ..utils import check_opt
+from .contraction import contract_strategy
 
 
 def tensor_network_distance(
@@ -49,10 +50,13 @@ def tensor_network_distance(
         directly formed and the norm computed, which can be quicker when the
         exterior dimensions are small. If ``'auto'``, the dense method will
         be used if the total operator (outer) size is ``<= 2**16``.
-    normalized : bool, optional
+    normalized : bool or str, optional
         If ``True``, then normalize the distance by the norm of the two
         operators, i.e. ``D(A, B) * 2 / (|A| + |B|)``. The resulting distance
         lies between 0 and 2 and is more useful for assessing convergence.
+        If ``'infidelity'``, compute the normalized infidelity
+        ``1 - |<A|B>|^2 / (<A|A> <B|B>)``, which can be faster to optimize,
+        but does not take overall normalization into account.
     contract_opts
         Supplied to :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`.
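A usage sketch of the new ``normalized`` options (not part of the patch; it
mirrors the parametrization of the tests added below and uses only quimb's
public API):

    import quimb.tensor as qtn

    A = qtn.MPS_rand_state(6, 3, dtype="complex128", seed=42)
    B = qtn.MPS_rand_state(6, 3, dtype="complex128", seed=43)

    d_raw = A.distance(B)                           # plain norm distance |A - B|
    d_nrm = A.distance(B, normalized=True)          # rescaled to lie in [0, 2]
    d_inf = A.distance(B, normalized="infidelity")  # 1 - |<A|B>|^2 / (<A|A> <B|B>)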
@@ -81,25 +85,39 @@ def tensor_network_distance(
     # directly from vectorizations of both
     if method == "dense":
-        tnA = tnA.contract(..., output_inds=oix, preserve_tensor=True)
-        tnB = tnB.contract(..., output_inds=oix, preserve_tensor=True)
+        tnA = tnA.contract(output_inds=oix, preserve_tensor=True)
+        tnB = tnB.contract(output_inds=oix, preserve_tensor=True)
 
     # overlap method
     if xAA is None:
-        xAA = (tnA | tnA.H).contract(..., **contract_opts)
+        xAA = (tnA | tnA.H).contract(**contract_opts)
     if xAB is None:
-        xAB = (tnA | tnB.H).contract(..., **contract_opts)
+        xAB = (tnA | tnB.H).contract(**contract_opts)
     if xBB is None:
-        xBB = (tnB | tnB.H).contract(..., **contract_opts)
+        xBB = (tnB | tnB.H).contract(**contract_opts)
 
-    dAB = do("abs", xAA - 2 * do("real", xAB) + xBB) ** 0.5
+    if normalized == "infidelity":
+        # compute normalized infidelity
+        return 1 - do("abs", xAB**2 / (xAA * xBB))
 
-    if normalized:
-        dAB *= 2 / (do("abs", xAA)**0.5 + do("abs", xBB)**0.5)
+    if normalized == "infidelity_sqrt":
+        # compute normalized sqrt infidelity
+        return 1 - do("abs", xAB / (xAA * xBB) ** 0.5)
 
-    return dAB
+    if normalized == "squared":
+        return (
+            do("abs", xAA + xBB - 2 * do("real", xAB))
+            # divide by average norm-squared of A and B
+            * 2 / (do("abs", xAA) + do("abs", xBB))
+        ) ** 0.5
+
+    dAB = do("abs", xAA + xBB - 2 * do("real", xAB)) ** 0.5
+    if normalized:
+        # divide by average norm of A and B
+        dAB = dAB * 2 / (do("abs", xAA) ** 0.5 + do("abs", xBB) ** 0.5)
+    return dAB
 
 
 def tensor_network_fit_autodiff(
@@ -110,6 +128,7 @@ def tensor_network_fit_autodiff(
     autodiff_backend="autograd",
     contract_optimize="auto-hq",
     distance_method="auto",
+    normalized="squared",
     inplace=False,
     progbar=False,
     **kwargs,
@@ -151,7 +170,6 @@ def tensor_network_fit_autodiff(
     from .tensor_core import tensor_network_distance
 
     xBB = (tn_target | tn_target.H).contract(
-        ...,
         output_inds=(),
         optimize=contract_optimize,
     )
@@ -160,7 +178,11 @@ def tensor_network_fit_autodiff(
         tn=tn,
         loss_fn=tensor_network_distance,
         loss_constants={"tnB": tn_target, "xBB": xBB},
-        loss_kwargs={"method": distance_method, "optimize": contract_optimize},
+        loss_kwargs={
+            "method": distance_method,
+            "optimize": contract_optimize,
+            "normalized": normalized,
+        },
         autodiff_backend=autodiff_backend,
         progbar=progbar,
         **kwargs,
@@ -192,8 +214,10 @@ def _tn_fit_als_core(
 ):
     from .tensor_core import group_inds
 
+    backend = next(iter(tnAA.tensors)).backend
+
     # shared intermediates + greedy = good reuse of contractions
-    with contract_strategy(contract_optimize):
+    with contract_strategy(contract_optimize), backend_like(backend):
         # prepare each of the contractions we are going to repeat
         env_contractions = []
         for tg in var_tags:
@@ -202,9 +226,9 @@ def _tn_fit_als_core(
             tb = tnAA["__BRA__", tg]
 
             # get inds, and ensure any bonds come last, for linalg.solve
-            lix, bix, rix = group_inds(tb, tk)
-            tk.transpose_(*rix, *bix)
-            tb.transpose_(*lix, *bix)
+            lix, bix, rix = group_inds(tk, tb)
+            tk.transpose_(*lix, *bix)
+            tb.transpose_(*rix, *bix)
 
             # form TNs with 'holes', i.e. environment tensor networks
             A_tn = tnAA.select((tg,), "!all")
@@ -212,7 +236,7 @@ def _tn_fit_als_core(
 
             env_contractions.append((tk, tb, lix, bix, rix, A_tn, y_tn))
 
-        if tol != 0.0:
+        if tol != 0.0 or progbar:
             old_d = float("inf")
 
         if progbar:
@@ -225,37 +249,40 @@ def _tn_fit_als_core(
         # the main iterative sweep on each tensor, locally optimizing
         for _ in pbar:
             for tk, tb, lix, bix, rix, A_tn, y_tn in env_contractions:
-                Ni = A_tn.to_dense(lix, rix)
-                Wi = y_tn.to_dense(rix, bix)
+                # form local normalization and local overlap
+                Ni = A_tn.to_dense(rix, lix)
+                bi = y_tn.to_dense(rix, bix)
 
                 if enforce_pos:
-                    el, ev = do("linalg.eigh", Ni)
-                    el = do("clip", el, el[-1] * pos_smudge, None)
-                    Ni_p = ev * do("reshape", el, (1, -1)) @ dag(ev)
+                    el, V = do("linalg.eigh", Ni)
+                    elmax = do("max", el)
+                    el = do("clip", el, elmax * pos_smudge, None)
+                    # can solve directly using eigendecomposition
+                    x = V @ ((dag(V) @ bi) / do("reshape", el, (-1, 1)))
                 else:
                     Ni_p = Ni
-
-                if solver == "solve":
-                    x = do("linalg.solve", Ni_p, Wi)
-                elif solver == "lstsq":
-                    x = do("linalg.lstsq", Ni_p, Wi, rcond=pos_smudge)[0]
+                    if solver == "solve":
+                        x = do("linalg.solve", Ni_p, bi)
+                    elif solver == "lstsq":
+                        x = do("linalg.lstsq", Ni_p, bi, rcond=pos_smudge)[0]
 
                 x_r = do("reshape", x, tk.shape)
                 # n.b. because we are using virtual TNs -> updates propagate
                 tk.modify(data=x_r)
                 tb.modify(data=do("conj", x_r))
 
-                # assess | A - B | for convergence or printing
+                # assess | A - B | (normalized) for convergence or printing
                 if (tol != 0.0) or progbar:
-                    xAA = do("trace", dag(x) @ (Ni @ x))  # <A|A>
-                    xAB = do("trace", do("real", dag(x) @ Wi))  # <A|B>
-                    d = do("abs", (xAA - 2 * xAB + xBB)) ** 0.5
+                    dagx = dag(x)
+                    xAA = do("trace", do("real", dagx @ (Ni @ x)))  # <A|A>
+                    xAB = do("trace", do("real", dagx @ bi))  # <A|B>
+                    d = abs(xAA + xBB - 2 * xAB) ** 0.5 * 2 / (xAA**0.5 + xBB**0.5)
 
                 if abs(d - old_d) < tol:
                     break
                 old_d = d
 
                 if progbar:
-                    pbar.set_description(str(d))
+                    pbar.set_description(f"{d:.3g}")
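For reference, the new ``enforce_pos`` branch above solves the local system
``N x = b`` directly from the clipped eigendecomposition, rather than
reforming a positive matrix and then calling a solver. A minimal plain-numpy
sketch of the same linear algebra (all names here are illustrative, not from
quimb):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(6, 6)) + 1j * rng.normal(size=(6, 6))
    N = A @ A.conj().T                         # hermitian positive (semi)definite
    b = rng.normal(size=(6, 2)) + 1j * rng.normal(size=(6, 2))

    el, V = np.linalg.eigh(N)                  # N == V @ np.diag(el) @ V.conj().T
    el = np.clip(el, el.max() * 1e-15, None)   # 'pos_smudge'-style regularization
    x = V @ ((V.conj().T @ b) / el[:, None])   # equivalent to solve(N, b) here

    assert np.allclose(N @ x, b, atol=1e-8)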
 
 
 def tensor_network_fit_als(
@@ -329,13 +356,13 @@ def tensor_network_fit_als(
         tensor_network_fit_autodiff, tensor_network_distance
     """
     # mark the tensors we are going to optimize
-    tna = tn.copy()
-    tna.add_tag("__KET__")
+    tn_fit = tn.copy()
+    tn_fit.add_tag("__KET__")
 
     if tags is None:
-        to_tag = tna
+        to_tag = tn_fit
     else:
-        to_tag = tna.select_tensors(tags, "any")
+        to_tag = tn_fit.select_tensors(tags, "any")
 
     var_tags = []
     for i, t in enumerate(to_tag):
@@ -344,15 +371,15 @@ def tensor_network_fit_als(
         var_tags.append(var_tag)
 
     # form the norm of the varying TN (A) and its overlap with the target (B)
+    tn_fit_conj = tn_fit.conj().retag_({"__KET__": "__BRA__"})
     if tnAA is None:
-        tnAA = tna | tna.H.retag_({"__KET__": "__BRA__"})
+        tnAA = tn_fit | tn_fit_conj
     if tnAB is None:
-        tnAB = tna | tn_target.H
+        tnAB = tn_target | tn_fit_conj
 
-    if (tol != 0.0) and (xBB is None):
+    if (tol != 0.0 or progbar) and (xBB is None):
         # <B|B>
         xBB = (tn_target | tn_target.H).contract(
-            ...,
             optimize=contract_optimize,
             output_inds=(),
         )
@@ -377,9 +404,132 @@ def tensor_network_fit_als(
     if not inplace:
         tn = tn.copy()
 
-    for t1, t2 in zip(tn, tna):
+    for t1, t2 in zip(tn, tn_fit):
         # transpose so only thing changed in original TN is data
         t2.transpose_like_(t1)
         t1.modify(data=t2.data)
 
     return tn
+
+
+def tensor_network_fit_tree(
+    tn,
+    tn_target,
+    tags=None,
+    steps=100,
+    tol=1e-9,
+    ordering=None,
+    xBB=None,
+    istree=True,
+    contract_optimize="auto-hq",
+    inplace=False,
+    progbar=False,
+):
+    """Fit `tn` to `tn_target`, assuming that `tn` has tree structure (i.e.
+    a single path between any two sites) and matching outer structure to
+    `tn_target`. The tree structure allows a canonical form that greatly
+    simplifies the normalization and least squares minimization. Note that
+    no structure is assumed about `tn_target`, so for example no partial
+    contractions are reused.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to fit, it should have a tree structure and outer
+        indices matching `tn_target`.
+    tn_target : TensorNetwork
+        The target tensor network to fit ``tn`` to.
+    tags : sequence of str, optional
+        If supplied, only optimize tensors matching any of the given tags.
+    steps : int, optional
+        The maximum number of ALS steps.
+    tol : float, optional
+        The target norm distance.
+    ordering : sequence of int, optional
+        The order in which to optimize the tensors, if None will be computed
+        automatically using a hierarchical clustering.
+    xBB : float, optional
+        If you have already computed ``tn_target.H @ tn_target``, or its
+        value doesn't matter, you can supply it here. It affects only the
+        overall scale of the norm distance.
+    contract_optimize : str, optional
+        A contraction path strategy or optimizer for contracting the local
+        environments.
+    inplace : bool, optional
+        Fit ``tn`` in place.
+    progbar : bool, optional
+        Show a live progress bar of the fitting process.
+
+    Returns
+    -------
+    TensorNetwork
+    """
+    if xBB is None:
+        xBB = (tn_target | tn_target.H).contract(
+            optimize=contract_optimize,
+            output_inds=(),
+        )
+
+    tn_fit = tn.conj(inplace=inplace)
+    tnAB = tn_fit | tn_target
+
+    if ordering is None:
+        if tags is not None:
+            tids = tn_fit._get_tids_from_tags(tags, "any")
+        else:
+            tids = None
+        ordering = tn_fit.compute_hierarchical_ordering(tids)
+
+    # prepare contractions
+    env_contractions = []
+    for i, tid in enumerate(ordering):
+        tn_hole = tnAB.copy(virtual=True)
+        ti = tn_hole.pop_tensor(tid)
+        # we'll need to canonicalize along path from the last tid to this one
+        tid_prev = ordering[(i - 1) % len(ordering)]
+        path = tn_fit.get_path_between_tids(tid_prev, tid)
+        canon_pairs = [
+            (path.tids[j], path.tids[j + 1]) for j in range(len(path))
+        ]
+        env_contractions.append((tid, tn_hole, ti, canon_pairs))
+
+    # initial canonicalization around first tensor
+    tn_fit._canonize_around_tids([ordering[0]])
+
+    if progbar:
+        import tqdm
+
+        pbar = tqdm.trange(steps)
+    else:
+        pbar = range(steps)
+
+    old_d = float("inf")
+
+    for _ in pbar:
+        for tid, tn_hole, ti, canon_pairs in env_contractions:
+            if istree:
+                # move canonical center to tid
+                for tidi, tidj in canon_pairs:
+                    tn_fit._canonize_between_tids(tidi, tidj)
+            else:
+                # pseudo canonicalization
+                tn_fit._canonize_around_tids([tid])
+
+            # get the new conjugate tensor
+            ti_new = tn_hole.contract(
+                output_inds=ti.inds, optimize=contract_optimize
+            )
+            ti_new.conj_()
+            # modify the data
+            ti.modify(data=ti_new.data)
+
+        if tol != 0.0 or progbar:
+            # canonicalized form enables simpler distance computation
+            xAA = ti.norm() ** 2  # == xAB
+            d = 2 * abs(xBB - xAA) ** 0.5 / (xBB**0.5 + xAA**0.5)
+            if abs(d - old_d) < tol:
+                break
+            old_d = d
+
+            if progbar:
+                pbar.set_description(f"{d:.3g}")
+
+    return tn_fit.conj_()
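A usage sketch of the new tree fitter (not part of the patch). An MPS has the
required tree structure; the target can be any tensor network with matching
outer indices. The import path assumes the module layout in this patch:

    import quimb.tensor as qtn
    from quimb.tensor.fitting import tensor_network_fit_tree

    target = qtn.MPS_rand_state(5, 7, seed=1, dtype="complex128")
    guess = qtn.MPS_rand_state(5, 3, seed=2, dtype="complex128")

    # fit the bond dimension 3 ansatz to the bond dimension 7 target
    fitted = tensor_network_fit_tree(guess, target, tol=1e-6)
    print(fitted.distance_normalized(target))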
a + single path between any two sites) and matching outer structure to + `tn_target`. The tree structure allows a canonical form that greatly + simplifies the normalization and least squares minimization. Note that no + structure is assumed about `tn_target`, and so for example no partial + contractions reused. + + Parameters + ---------- + tn : TensorNetwork + The tensor network to fit, it should have a tree structure and outer + indices matching `tn_target`. + tn_target : TensorNetwork + The target tensor network to fit ``tn`` to. + tags : sequence of str, optional + If supplied, only optimize tensors matching any of given tags. + steps : int, optional + The maximum number of ALS steps. + tol : float, optional + The target norm distance. + ordering : sequence of int, optional + The order in which to optimize the tensors, if None will be computed + automatically using a hierarchical clustering. + xBB : float, optional + If you have already know, have computed ``tn_target.H @ tn_target``, + or it doesn't matter, you can supply the value here. It matters only + for the overall scale of the norm distance. + contract_optimize : str, optional + A contraction path strategy or optimizer for contracting the local + environments. + inplace : bool, optional + Fit ``tn`` in place. + progbar : bool, optional + Show a live progress bar of the fitting process. + + Returns + ------- + TensorNetwork + """ + if xBB is None: + xBB = (tn_target | tn_target.H).contract( + optimize=contract_optimize, + output_inds=(), + ) + + tn_fit = tn.conj(inplace=inplace) + tnAB = tn_fit | tn_target + + if ordering is None: + if tags is not None: + tids = tn_fit._get_tids_from_tags(tags, "any") + else: + tids = None + ordering = tn_fit.compute_hierarchical_ordering(tids) + + # prepare contractions + env_contractions = [] + for i, tid in enumerate(ordering): + tn_hole = tnAB.copy(virtual=True) + ti = tn_hole.pop_tensor(tid) + # we'll need to canonicalize along path from the last tid to this one + tid_prev = ordering[(i - 1) % len(ordering)] + path = tn_fit.get_path_between_tids(tid_prev, tid) + canon_pairs = [ + (path.tids[j], path.tids[j + 1]) for j in range(len(path)) + ] + env_contractions.append((tid, tn_hole, ti, canon_pairs)) + + # initial canonicalization around first tensor + tn_fit._canonize_around_tids([ordering[0]]) + + if progbar: + import tqdm + + pbar = tqdm.trange(steps) + else: + pbar = range(steps) + + old_d = float("inf") + + for _ in pbar: + for tid, tn_hole, ti, canon_pairs in env_contractions: + if istree: + # move canonical center to tid + for tidi, tidj in canon_pairs: + tn_fit._canonize_between_tids(tidi, tidj) + else: + # pseudo canonicalization + tn_fit._canonize_around_tids([tid]) + + # get the new conjugate tensor + ti_new = tn_hole.contract(output_inds=ti.inds, optimize="auto-hq") + ti_new.conj_() + # modify the data + ti.modify(data=ti_new.data) + + if tol != 0.0 or progbar: + # canonicalized form enable simpler distance computation + xAA = ti.norm() ** 2 # == xAB + d = 2 * abs(xBB - xAA) ** 0.5 / (xBB**0.5 + xAA**0.5) + if abs(d - old_d) < tol: + break + old_d = d + + if progbar: + pbar.set_description(f"{d:.3g}") + + return tn_fit.conj_() diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index 3d45c756..878ba4a9 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -79,6 +79,7 @@ tensor_network_distance, tensor_network_fit_als, tensor_network_fit_autodiff, + tensor_network_fit_tree, ) from .networking import ( compute_centralities, @@ -9074,10 
@@ -9074,10 +9075,17 @@ def fit(
         ----------
         tn_target : TensorNetwork
             The target tensor network to try and fit the current one to.
-        method : {'als', 'autodiff'}, optional
-            Whether to use alternating least squares (ALS) or automatic
-            differentiation to perform the optimization. Generally ALS is
-            better for simple geometries, autodiff better for complex ones.
+        method : {'als', 'autodiff', 'tree'}, optional
+            How to perform the fitting. The options are:
+
+            - 'als': alternating least squares (ALS) optimization,
+            - 'autodiff': automatic differentiation optimization,
+            - 'tree': ALS where the fitted tensor network has a tree structure
+              and thus a canonical form can be utilized for much greater
+              efficiency and stability.
+
+            Generally ALS is better for simple geometries, autodiff better for
+            complex ones. 'tree' is best if the network has a tree structure.
         tol : float, optional
             The target norm distance.
         inplace : bool, optional
@@ -9086,8 +9094,9 @@ def fit(
             Show a live progress bar of the fitting process.
         fitting_opts
             Supplied to either
-            :func:`~quimb.tensor.tensor_core.tensor_network_fit_als` or
-            :func:`~quimb.tensor.tensor_core.tensor_network_fit_autodiff`.
+            :func:`~quimb.tensor.tensor_core.tensor_network_fit_als`,
+            :func:`~quimb.tensor.tensor_core.tensor_network_fit_autodiff`, or
+            :func:`~quimb.tensor.tensor_core.tensor_network_fit_tree`.
 
         Returns
         -------
@@ -9097,9 +9106,9 @@ def fit(
         See Also
        --------
         tensor_network_fit_als, tensor_network_fit_autodiff,
-        tensor_network_distance
+        tensor_network_fit_tree, tensor_network_distance,
+        tensor_network_1d_compress
         """
-        check_opt("method", method, ("als", "autodiff"))
         fitting_opts["tol"] = tol
         fitting_opts["inplace"] = inplace
         fitting_opts["progbar"] = progbar
@@ -9108,7 +9117,15 @@ def fit(
 
         if method == "autodiff":
             return tensor_network_fit_autodiff(self, tn_target, **fitting_opts)
-        return tensor_network_fit_als(self, tn_target, **fitting_opts)
+        elif method == "tree":
+            return tensor_network_fit_tree(self, tn_target, **fitting_opts)
+        elif method == "als":
+            return tensor_network_fit_als(self, tn_target, **fitting_opts)
+        else:
+            raise ValueError(
+                f"Unrecognized method {method}. Should be one of: "
+                "{'als', 'autodiff', 'tree'}."
+            )
 
     fit_ = functools.partialmethod(fit, inplace=True)
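A sketch of the dispatching ``fit`` method (not part of the patch), here
restricting the optimization to a subset of tensors via ``tags``, as in the
new tests below:

    import quimb.tensor as qtn

    k1 = qtn.MPS_rand_state(5, 3, seed=666)
    k2 = qtn.MPS_rand_state(5, 3, seed=667)

    # only the tensors tagged I0, I2 and I4 are varied
    k1f = k1.fit(k2, method="als", tags=["I0", "I2", "I4"], tol=1e-3)
    assert k1f.distance(k2) < k1.distance(k2)

    # in-place variant, dispatching to the new tree method
    k1.fit_(k2, method="tree")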
diff --git a/tests/test_tensor/test_fitting.py b/tests/test_tensor/test_fitting.py
new file mode 100644
index 00000000..12295cc7
--- /dev/null
+++ b/tests/test_tensor/test_fitting.py
@@ -0,0 +1,122 @@
+import importlib
+
+import numpy as np
+import pytest
+
+import quimb.tensor as qtn
+
+requires_autograd = pytest.mark.skipif(
+    importlib.util.find_spec("autograd") is None,
+    reason="autograd not installed",
+)
+
+
+@pytest.mark.parametrize("method", ("auto", "dense", "overlap"))
+@pytest.mark.parametrize("normalized", (
+    True,
+    False,
+    "squared",
+    "infidelity",
+    "infidelity_sqrt",
+))
+def test_tensor_network_distance(method, normalized):
+    n = 6
+    A = qtn.TN_rand_reg(n=n, reg=3, D=2, phys_dim=2, dtype=complex)
+    Ad = A.to_dense([f"k{i}" for i in range(n)])
+    B = qtn.TN_rand_reg(n=6, reg=3, D=2, phys_dim=2, dtype=complex)
+    Bd = B.to_dense([f"k{i}" for i in range(n)])
+    d1 = np.linalg.norm(Ad - Bd)
+    d2 = A.distance(B, method=method, normalized=normalized)
+    if normalized:
+        assert 0 <= d2 <= 2
+    else:
+        assert d1 == pytest.approx(d2)
+
+
+@pytest.mark.parametrize(
+    "method,opts",
+    (
+        ("als", (("enforce_pos", False), ("solver", "lstsq"))),
+        ("als", (("enforce_pos", True),)),
+        ("tree", ()),
+        pytest.param(
+            "autodiff",
+            (("distance_method", "dense"),),
+            marks=requires_autograd,
+        ),
+        pytest.param(
+            "autodiff",
+            (("distance_method", "overlap"),),
+            marks=requires_autograd,
+        ),
+    ),
+)
+@pytest.mark.parametrize("dtype", ("float64", "complex128"))
+def test_fit_mps(method, opts, dtype):
+    k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype=dtype)
+    k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype)
+    assert k1.distance_normalized(k2) > 1e-3
+    k1.fit_(k2, method=method, progbar=True, **dict(opts))
+    assert k1.distance_normalized(k2) < 1e-3
+
+
+@pytest.mark.parametrize(
+    "method,opts",
+    (
+        ("als", (("enforce_pos", False),)),
+        ("als", (("enforce_pos", True),)),
+        pytest.param(
+            "autodiff",
+            (("distance_method", "dense"),),
+            marks=requires_autograd,
+        ),
+        pytest.param(
+            "autodiff",
+            (("distance_method", "overlap"),),
+            marks=requires_autograd,
+        ),
+    ),
+)
+@pytest.mark.parametrize("dtype", ("float64", "complex128"))
+def test_fit_rand_reg(method, opts, dtype):
+    r1 = qtn.TN_rand_reg(5, 4, D=2, seed=666, phys_dim=2, dtype=dtype)
+    k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype)
+    assert r1.distance(k2) > 1e-3
+    r1.fit_(k2, method=method, progbar=True, **dict(opts))
+    assert r1.distance(k2) < 1e-3
+
+
+@pytest.mark.parametrize(
+    "method,opts",
+    (
+        ("als", (("enforce_pos", False),)),
+        ("als", (("enforce_pos", True),)),
+        ("tree", ()),
+        pytest.param(
+            "autodiff",
+            (("distance_method", "dense"),),
+            marks=requires_autograd,
+        ),
+        pytest.param(
+            "autodiff",
+            (("distance_method", "overlap"),),
+            marks=requires_autograd,
+        ),
+    ),
+)
+@pytest.mark.parametrize("dtype", ("float64", "complex128"))
+def test_fit_partial_tags(method, opts, dtype):
+    k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype=dtype)
+    k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype)
+    d0 = k1.distance(k2)
+    tags = ["I0", "I2", "I4"]
+    k1f = k1.fit(
+        k2, tol=1e-3, tags=tags, method=method, progbar=True, **dict(opts)
+    )
+    assert k1f.distance(k2) < d0
+    if method != "tree":
+        assert (k1f[0] - k1[0]).norm() > 1e-12
+        assert (k1f[1] - k1[1]).norm() < 1e-12
+        assert (k1f[2] - k1[2]).norm() > 1e-12
+        assert (k1f[3] - k1[3]).norm() < 1e-12
+        assert (k1f[4] - k1[4]).norm() > 1e-12
diff --git a/tests/test_tensor/test_tensor_core.py b/tests/test_tensor/test_tensor_core.py
index bf6950cb..c377ce4c 100644
--- a/tests/test_tensor/test_tensor_core.py
+++ b/tests/test_tensor/test_tensor_core.py
@@ -1034,101 +1034,6 @@ def test_contract_to_dense_reduced_factor(self):
         Ur = A @ Rinv
         assert_allclose(Ur @ Ur.T, np.eye(4), atol=1e-10)
 
-    @pytest.mark.parametrize("method", ("auto", "dense", "overlap"))
-    @pytest.mark.parametrize("normalized", (True, False))
-    def test_tensor_network_distance(self, method, normalized):
-        n = 6
-        A = qtn.TN_rand_reg(n=n, reg=3, D=2, phys_dim=2, dtype=complex)
-        Ad = A.to_dense([f"k{i}" for i in range(n)])
-        B = qtn.TN_rand_reg(n=6, reg=3, D=2, phys_dim=2, dtype=complex)
-        Bd = B.to_dense([f"k{i}" for i in range(n)])
-        d1 = np.linalg.norm(Ad - Bd)
-        d2 = A.distance(B, method=method, normalized=normalized)
-        if normalized:
-            assert 0 <= d2 <= 2
-        else:
-            assert d1 == pytest.approx(d2)
-
-    @pytest.mark.parametrize(
-        "method,opts",
-        (
-            ("als", (("enforce_pos", False), ("solver", "lstsq"))),
-            ("als", (("enforce_pos", True),)),
-            pytest.param(
-                "autodiff",
-                (("distance_method", "dense"),),
-                marks=requires_autograd,
-            ),
-            pytest.param(
-                "autodiff",
-                (("distance_method", "overlap"),),
-                marks=requires_autograd,
-            ),
-        ),
-    )
-    def test_fit_mps(self, method, opts):
-        k1 = qtn.MPS_rand_state(5, 3, seed=666)
-        k2 = qtn.MPS_rand_state(5, 3, seed=667)
-        assert k1.distance_normalized(k2) > 1e-3
-        k1.fit_(k2, method=method, progbar=True, **dict(opts))
-        assert k1.distance_normalized(k2) < 1e-3
-
-    @pytest.mark.parametrize(
-        "method,opts",
-        (
-            ("als", (("enforce_pos", False),)),
-            ("als", (("enforce_pos", True),)),
-            pytest.param(
-                "autodiff",
-                (("distance_method", "dense"),),
-                marks=requires_autograd,
-            ),
-            pytest.param(
-                "autodiff",
-                (("distance_method", "overlap"),),
-                marks=requires_autograd,
-            ),
-        ),
-    )
-    def test_fit_rand_reg(self, method, opts):
-        r1 = qtn.TN_rand_reg(5, 4, D=2, seed=666, phys_dim=2)
-        k2 = qtn.MPS_rand_state(5, 3, seed=667)
-        assert r1.distance(k2) > 1e-3
-        r1.fit_(k2, method=method, progbar=True, **dict(opts))
-        assert r1.distance(k2) < 1e-3
-
-    @pytest.mark.parametrize(
-        "method,opts",
-        (
-            ("als", (("enforce_pos", False),)),
-            ("als", (("enforce_pos", True),)),
-            pytest.param(
-                "autodiff",
-                (("distance_method", "dense"),),
-                marks=requires_autograd,
-            ),
-            pytest.param(
-                "autodiff",
-                (("distance_method", "overlap"),),
-                marks=requires_autograd,
-            ),
-        ),
-    )
-    def test_fit_partial_tags(self, method, opts):
-        k1 = qtn.MPS_rand_state(5, 3, seed=666)
-        k2 = qtn.MPS_rand_state(5, 3, seed=667)
-        d0 = k1.distance(k2)
-        tags = ["I0", "I2", "I4"]
-        k1f = k1.fit(
-            k2, tol=1e-3, tags=tags, method=method, progbar=True, **dict(opts)
-        )
-        assert k1f.distance(k2) < d0
-        assert (k1f[0] - k1[0]).norm() > 1e-12
-        assert (k1f[1] - k1[1]).norm() < 1e-12
-        assert (k1f[2] - k1[2]).norm() > 1e-12
-        assert (k1f[3] - k1[3]).norm() < 1e-12
-        assert (k1f[4] - k1[4]).norm() > 1e-12
-
     def test_reindex(self):
         a = Tensor(np.random.randn(2, 3, 4), inds=[0, 1, 2], tags="red")
         b = Tensor(np.random.randn(3, 4, 5), inds=[1, 2, 3], tags="blue")
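Finally, a sketch exercising the complex-dtype ALS path that this commit
fixes (not part of the patch), covering both solver branches parametrized in
the new tests:

    import quimb.tensor as qtn

    k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype="complex128")

    # eigendecomposition based positive solve
    k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype="complex128")
    k1.fit_(k2, method="als", enforce_pos=True)
    assert k1.distance_normalized(k2) < 1e-3

    # plain least squares solve
    k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype="complex128")
    k1.fit_(k2, method="als", enforce_pos=False, solver="lstsq")
    assert k1.distance_normalized(k2) < 1e-3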