
Commit 028e75f

L2BP: add update="sequential"
jcmgray committed Nov 7, 2023
1 parent c8ea241 commit 028e75f
Showing 4 changed files with 47 additions and 13 deletions.
3 changes: 3 additions & 0 deletions quimb/experimental/belief_propagation/l1bp.py
@@ -170,13 +170,16 @@ def _update_m(key, data):

         if self.update == "parallel":
             new_data = {}
+            # compute all new messages
             while self.touched:
                 key = self.touched.pop()
                 new_data[key] = _compute_m(key)
+            # insert all new messages
             for key, data in new_data.items():
                 _update_m(key, data)

         elif self.update == "sequential":
+            # compute each new message and immediately re-insert it
             while self.touched:
                 key = self.touched.pop()
                 data = _compute_m(key)
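For orientation, here is a minimal standalone sketch of the two schedules that the branches above implement. The names `run_schedule`, `messages`, `compute_message` and `neighbors` are hypothetical stand-ins, not part of quimb; only the control flow mirrors the diff: "parallel" defers insertion until every new message has been computed from the old values, while "sequential" re-inserts each message immediately so later computations in the same sweep see it.

# Hypothetical stand-ins, not quimb code: `messages` maps a key to its
# current value, `compute_message(key, messages)` returns the new value,
# and `neighbors(key)` yields the keys whose inputs depend on `key`.

def run_schedule(messages, compute_message, neighbors, update="parallel"):
    touched = set(messages)
    new_touched = set()

    def insert(key, data):
        # mark downstream messages for the next sweep if this one changed
        if data != messages[key]:
            new_touched.update(neighbors(key))
        messages[key] = data

    if update == "parallel":
        # compute every new message from the *old* values, then insert
        new_data = {}
        while touched:
            key = touched.pop()
            new_data[key] = compute_message(key, messages)
        for key, data in new_data.items():
            insert(key, data)

    elif update == "sequential":
        # compute and immediately re-insert, so later messages in the
        # same sweep already see the freshly updated values
        while touched:
            key = touched.pop()
            insert(key, compute_message(key, messages))

    return new_touched

As a general rule the sequential (Gauss-Seidel-like) schedule tends to converge in fewer sweeps because each sweep already uses partially updated messages, while the parallel (Jacobi-like) schedule is easier to batch.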
40 changes: 31 additions & 9 deletions quimb/experimental/belief_propagation/l2bp.py
@@ -19,12 +19,14 @@ def __init__(
         site_tags=None,
         damping=0.0,
         local_convergence=True,
+        update="parallel",
         optimize="auto-hq",
         **contract_opts,
     ):
         self.backend = next(t.backend for t in tn)
         self.damping = damping
         self.local_convergence = local_convergence
+        self.update = update
         self.optimize = optimize
         self.contract_opts = contract_opts

@@ -126,10 +128,12 @@ def iterate(self, tol=5e-6):
             )

         ncheck = len(self.touched)
+        nconv = 0
+        max_mdiff = -1.0
+        new_touched = set()

-        new_data = {}
-        while self.touched:
-            i, j = self.touched.pop()
+        def _compute_m(key):
+            i, j = key
             bix = self.edges[(i, j) if i < j else (j, i)]
             cix = tuple(ix + "**" for ix in bix)
             output_inds = cix + bix
@@ -145,12 +149,11 @@ def iterate(self, tol=5e-6):
             )
             tm_new.modify(apply=self._symmetrize)
             tm_new.modify(apply=self._normalize)
-            # defer setting the data to do a parallel update
-            new_data[i, j] = tm_new.data
+            return tm_new.data

-        nconv = 0
-        max_mdiff = -1.0
-        for key, data in new_data.items():
+        def _update_m(key, data):
+            nonlocal nconv, max_mdiff
+
             tm = self.messages[key]

             if self.damping > 0.0:
@@ -160,13 +163,32 @@ def iterate(self, tol=5e-6):

             if mdiff > tol:
                 # mark touching messages for update
-                self.touched.update(self.touch_map[key])
+                new_touched.update(self.touch_map[key])
             else:
                 nconv += 1

             max_mdiff = max(max_mdiff, mdiff)
             tm.modify(data=data)

+        if self.update == "parallel":
+            new_data = {}
+            # compute all new messages
+            while self.touched:
+                key = self.touched.pop()
+                new_data[key] = _compute_m(key)
+            # insert all new messages
+            for key, data in new_data.items():
+                _update_m(key, data)
+
+        elif self.update == "sequential":
+            # compute each new message and immediately re-insert it
+            while self.touched:
+                key = self.touched.pop()
+                data = _compute_m(key)
+                _update_m(key, data)
+
+        self.touched = new_touched
+
         return nconv, ncheck, max_mdiff

     def contract(self, strip_exponent=False):
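The refactored `iterate` above still returns `(nconv, ncheck, max_mdiff)`. As a rough illustration of how those return values might be consumed, here is a hypothetical outer loop named `run_bp`; it is not the driver quimb itself uses, and the convergence rule `nconv == ncheck` is an assumption inferred from the `mdiff > tol` test in the diff.

def run_bp(bp, tol=5e-6, max_iterations=1000, progbar=False):
    # hypothetical driver: sweep until every touched message has changed
    # by less than `tol`, or the iteration budget is exhausted
    converged = False
    for it in range(max_iterations):
        nconv, ncheck, max_mdiff = bp.iterate(tol=tol)
        if progbar:
            print(f"sweep {it}: {nconv}/{ncheck} converged, max diff {max_mdiff:.2e}")
        if nconv == ncheck:
            converged = True
            break
    return {"converged": converged, "sweeps": it + 1, "max_mdiff": max_mdiff}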
7 changes: 5 additions & 2 deletions tests/test_tensor/test_belief_propagation/test_l1bp.py
@@ -29,12 +29,15 @@ def test_contract_loopy_approx(dtype, damping):

 @pytest.mark.parametrize("dtype", ["float32", "complex64"])
 @pytest.mark.parametrize("damping", [0.0, 0.1])
-def test_contract_double_loopy_approx(dtype, damping):
+@pytest.mark.parametrize("update", ("parallel", "sequential"))
+def test_contract_double_loopy_approx(dtype, damping, update):
     peps = qtn.PEPS.rand(4, 3, 2, seed=42, dtype=dtype)
     tn = peps.H & peps
     Z_ex = tn.contract()
     info = {}
-    Z_bp1 = contract_l1bp(tn, damping=damping, info=info, progbar=True)
+    Z_bp1 = contract_l1bp(
+        tn, damping=damping, update=update, info=info, progbar=True
+    )
     assert info["converged"]
     assert Z_bp1 == pytest.approx(Z_ex, rel=0.3)
     # compare with 2-norm BP on the peps directly
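Outside of pytest, the parametrized call above corresponds to something like the following sketch. The import path is an assumption based on the module touched by this commit (quimb/experimental/belief_propagation/l1bp.py); the call signature itself is taken directly from the test.

import quimb.tensor as qtn
from quimb.experimental.belief_propagation.l1bp import contract_l1bp  # assumed path

peps = qtn.PEPS.rand(4, 3, 2, seed=42, dtype="float32")
tn = peps.H & peps  # double-layer norm network
Z_ex = tn.contract()

for update in ("parallel", "sequential"):
    info = {}
    Z_bp = contract_l1bp(tn, damping=0.1, update=update, info=info, progbar=True)
    print(update, info["converged"], Z_bp / Z_ex)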
10 changes: 8 additions & 2 deletions tests/test_tensor/test_belief_propagation/test_l2bp.py
@@ -69,7 +69,8 @@ def test_contract_double_layer_tree_exact(dtype):

 @pytest.mark.parametrize("dtype", ["float32", "complex64"])
 @pytest.mark.parametrize("damping", [0.0, 0.1])
-def test_compress_double_layer_loopy(dtype, damping):
+@pytest.mark.parametrize("update", ["parallel", "sequential"])
+def test_compress_double_layer_loopy(dtype, damping, update):
     peps = qtn.PEPS.rand(3, 4, bond_dim=3, seed=42, dtype=dtype)
     pepo = qtn.PEPO.rand(3, 4, bond_dim=2, seed=42, dtype=dtype)

@@ -85,7 +86,12 @@ def test_compress_double_layer_loopy(dtype, damping):
     # compress using BP
     info = {}
     tn_bp = compress_l2bp(
-        tn_lazy, max_bond=3, damping=damping, info=info, progbar=True
+        tn_lazy,
+        max_bond=3,
+        damping=damping,
+        update=update,
+        info=info,
+        progbar=True,
     )
     assert info["converged"]
     assert tn_bp.num_tensors == 12
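Similarly, a hedged sketch of calling `compress_l2bp` with the new keyword. The test's `tn_lazy` (a lazily applied PEPO-PEPS network) is constructed in lines elided from the diff, so this sketch instead truncates the bonds of a plain random PEPS, which is an assumption about what `compress_l2bp` accepts; the import path is likewise assumed from the module location, while the keyword arguments follow the test's call.

import quimb.tensor as qtn
from quimb.experimental.belief_propagation.l2bp import compress_l2bp  # assumed path

# a random PEPS with bond dimension 4, truncated to 3 using 2-norm BP
peps = qtn.PEPS.rand(3, 4, bond_dim=4, seed=42)
info = {}
peps_c = compress_l2bp(
    peps,
    max_bond=3,
    damping=0.1,
    update="sequential",  # the schedule added by this commit
    info=info,
    progbar=True,
)
assert info["converged"]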
