Skip to content

Commit

Permalink
update SU plot method, changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Nov 19, 2024
1 parent b90dbf1 commit 335860f
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 24 deletions.
3 changes: 2 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ Release notes for `quimb`.
- specialize [`CircuitMPS.local_expectation`](quimb.tensor.circuit.CircuitMPS.local_expectation) to make use of the MPS form.
- add [`PEPS.product_state`](quimb.tensor.tensor_2d.PEPS.product_state) for constructing a PEPS representing a product state.
- add [`PEPS.vacuum`](quimb.tensor.tensor_2d.PEPS.vacuum) for constructing a PEPS representing the vacuum state $|000\ldots0\rangle$.
- [tn.gauge_all_simple](quimb.tensor.tensor_core.TensorNetwork.gauge_all_simple): improve scheduling and add `damping` and `touched_tids` options.
- [`tn.gauge_all_simple`](quimb.tensor.tensor_core.TensorNetwork.gauge_all_simple): improve scheduling and add `damping` and `touched_tids` options.
- [`qtn.SimpleUpdateGen`](quimb.tensor.tensor_arbgeom_tebd.SimpleUpdateGen): add gauge difference update checking and `tol` and `equilibrate` settings. Update `.plot()` method. Default to a small `cutoff`.

---

Expand Down
130 changes: 107 additions & 23 deletions quimb/tensor/tensor_arbgeom_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,42 +789,126 @@ def compute_energy(self):

@default_to_neutral_style
def plot(
self, zoom="auto", xscale="symlog", xscale_linthresh=20, hlines=()
self,
zoom="auto",
xscale="symlog",
xscale_linthresh=20,
color_energy=(0.0, 0.5, 1.0),
color_gauge_diff=(1.0, 0.5, 0.0),
hlines=(),
figsize=(8, 4),
):
"""Plot an overview of the evolution of the energy and gauge diffs.
Parameters
----------
zoom : int or 'auto', optional
The number of iterations to zoom in on, or 'auto' to automatically
choose a reasonable zoom level.
xscale : {'linear', 'log', 'symlog'}, optional
The x-axis scale, for the upper plot of the entire evolution.
xscale_linthresh : float, optional
The linear threshold for the upper symlog scale.
color_energy : str or tuple, optional
The color to use for the energy plot.
color_gauge_diff : str or tuple, optional
The color to use for the gauge diff plot.
hlines : dict, optional
Add horizontal lines to the plot, with keys as labels and values
as the y-values.
figsize : tuple, optional
The size of the figure.
Returns
-------
fig, axs : matplotlib.Figure, tuple[matplotlib.Axes]
"""
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
from matplotlib.colors import hsv_to_rgb

fig, ax = plt.subplots()
def set_axis_color(ax, which, color):
ax.spines[which].set_visible(True)
ax.spines[which].set_color(color)
ax.yaxis.label.set_color(color)
ax.tick_params(axis="y", colors=color, which="both")

xs = np.array(self.its)
ys = np.array(self.energies)
x_en = np.array(self.its)
y_en = np.array(self.energies)
x_gd = np.arange(1, len(self._gauge_diffs) + 1)
y_gd = np.array(self._gauge_diffs)

ax.plot(xs, ys, ".-")
ax.set_xlabel("Iteration")
ax.set_ylabel("Energy")
if zoom is not None:
if zoom == "auto":
zoom = min(200, self.n // 2)
nz = self.n - zoom

fig, axs = plt.subplots(nrows=2, figsize=figsize)

# plotted zoomed out
# energy
axl = axs[0]
axl.plot(x_en, y_en, marker="|", color=color_energy)
axl.set_xscale(xscale, linthresh=xscale_linthresh)
axl.set_ylabel("Energy")
axl.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
set_axis_color(axl, "left", color_energy)
# gauge diff
axr = axl.twinx()
axr.plot(
x_gd,
y_gd,
linestyle="--",
color=color_gauge_diff,
)
axr.set_ylabel("Max gauge diff")
axr.set_yscale("log")
set_axis_color(axr, "right", color_gauge_diff)

axl.axvline(
nz,
color=(0.5, 0.5, 0.5, 0.5),
linestyle="-",
linewidth=1,
)

if xscale == "symlog":
ax.set_xscale(xscale, linthresh=xscale_linthresh)
ax.axvline(xscale_linthresh, color=(0.5, 0.5, 0.5), ls="-", lw=0.5)
else:
ax.set_xscale(xscale)
# plotted zoomed in
# energy
iz = min(range(len(x_en)), key=lambda i: x_en[i] < nz)
axl = axs[1]
axl.plot(x_en[iz:], y_en[iz:], marker="|", color=color_energy)
axl.set_ylabel("Energy")
axl.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
set_axis_color(axl, "left", color_energy)
axl.set_xlabel("Iteration")
# gauge diff
iz = min(range(len(x_gd)), key=lambda i: x_gd[i] < nz)
axr = axl.twinx()
axr.plot(
x_gd[iz:],
y_gd[iz:],
linestyle="--",
color=color_gauge_diff,
)
axr.set_ylabel("Max gauge diff")
axr.set_yscale("log")
set_axis_color(axr, "right", color_gauge_diff)

if hlines:
hlines = dict(hlines)
for i, (label, value) in enumerate(hlines.items()):
color = hsv_to_rgb([(0.1 * i) % 1.0, 0.9, 0.9])
ax.axhline(value, color=color, ls="--", label=label)
ax.text(1, value, label, color=color, va="bottom", ha="left")

if zoom is not None:
if zoom == "auto":
zoom = min(50, ys.size // 2)

iax = ax.inset_axes([0.5, 0.5, 0.5, 0.5])
iax.plot(xs[-zoom:], ys[-zoom:], ".-")
color = hsv_to_rgb([(0.45 - (0.08 * i)) % 1.0, 0.7, 0.6])
axs[0].axhline(value, color=color, ls=":", label=label)
axs[1].axhline(value, color=color, ls=":", label=label)
axs[0].text(
1, value, label, color=color, va="bottom", ha="left"
)
axs[1].text(
nz, value, label, color=color, va="bottom", ha="left"
)

return fig, ax
return fig, axs

def __repr__(self):
s = "<{}(n={}, tau={}, D={})>"
Expand Down

0 comments on commit 335860f

Please sign in to comment.