# SPDX-FileCopyrightText: 2025-2026 Julian Peil <julian.peil@tuwien.ac.at>
# SPDX-License-Identifier: MIT
#
# DGAmore — Multi-Orbital Ladder Dynamical Vertex Approximation (LDGA) &
# Eliashberg Equation Solver for Strongly Correlated Electron Systems
"""
All matplotlib plotting helpers. These functions produce the diagnostic and result figures of a run — local
self-energy / susceptibility checks, frequency-resolved four-point maps, momentum-space two-point maps (with
optional Fermi-surface markers), the superconducting gap function, and the analytically continued spectral
function along a high-symmetry path. Each routine saves and/or shows its figure. All plotting is gated behind
``config.output.do_plotting`` and ``comm.rank == 0`` by the callers.
"""
import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.interpolate import RegularGridInterpolator
from dgamore.brillouin_zone import KGrid
from dgamore.gap_function import GapFunction
from dgamore.local_four_point import LocalFourPoint
from dgamore.local_n_point import LocalNPoint
from dgamore.matsubara_frequencies import MFHelper
from dgamore.n_point_base import IAmNonLocal
[docs]
def add_afzb(ax=None, kx=None, ky=None, lw=1.0, marker=""):
    r"""
    Draws the antiferromagnetic zone-boundary lines (and the BZ axes) onto an axis, and sets its limits and labels.

    The layout is chosen from ``kx``. If any value is negative, the zone is taken to be centred at the origin. The
    dashed diamond then runs through :math:`(\pm\pi, 0)` and :math:`(0, \pm\pi)`, with solid axis lines at
    :math:`k = 0`. Otherwise the zone is taken to span :math:`[0, 2\pi]`. The diamond then runs through
    :math:`(\pi, 0)`, :math:`(2\pi, \pi)`, :math:`(\pi, 2\pi)` and :math:`(0, \pi)`, with solid axis lines at
    :math:`k = \pi`.

    :param ax: The matplotlib axis to draw on. If ``None``, the current pyplot axis (``plt.gca()``) is used.
    :param kx: The kx grid values (array-like, required). Its first/last entries set the x-limits.
    :param ky: The ky grid values (array-like, required). Its first/last entries set the y-limits.
    :param lw: Line width of the drawn lines.
    :param marker: Marker style for the drawn lines.
    :return: None.
    :raises ValueError: If ``kx`` or ``ky`` is not given.
    """
    if kx is None or ky is None:
        raise ValueError("'kx' and 'ky' need to be given.")
    if ax is None:
        ax = plt.gca()
    # Accept any array-like (lists, tuples): the comparison and ``np.full`` below need an ndarray.
    kx = np.asarray(kx)
    ky = np.asarray(ky)

    # Each entry is ((x_start, x_end), (y_start, y_end)) of one dashed edge of the zone-boundary diamond.
    if np.any(kx < 0):
        # Zone centred at the origin.
        diamond = (
            ((-np.pi, 0), (0, np.pi)),
            ((-np.pi, 0), (0, -np.pi)),
            ((0, np.pi), (-np.pi, 0)),
            ((0, np.pi), (np.pi, 0)),
        )
        axis_pos = 0.0
    else:
        # Zone spanning [0, 2 pi].
        diamond = (
            ((0, np.pi), (np.pi, 2 * np.pi)),
            ((np.pi, 0), (0, np.pi)),
            ((np.pi, 2 * np.pi), (0, np.pi)),
            ((np.pi, 2 * np.pi), (2 * np.pi, np.pi)),
        )
        axis_pos = np.pi
    for (x0, x1), (y0, y1) in diamond:
        ax.plot(np.linspace(x0, x1, 101), np.linspace(y0, y1, 101), "--k", lw=lw, marker=marker)
    # Solid horizontal and vertical lines through the zone centre.
    ax.plot(kx, np.full(kx.shape, axis_pos), "-k", lw=lw, marker=marker)
    ax.plot(np.full(ky.shape, axis_pos), ky, "-k", lw=lw, marker=marker)
    ax.set_xlim(kx[0], kx[-1])
    ax.set_ylim(ky[0], ky[-1])
    ax.set_xlabel("$k_x$")
    ax.set_ylabel("$k_y$")
[docs]
def find_zeros(mat: np.ndarray) -> np.ndarray:
    """
    Finds the zero crossings (zero contour) of a real 2D field via a matplotlib contour at level 0.

    The contour is drawn on a private, throw-away figure that is always closed again. The caller's current pyplot
    figure is therefore neither drawn on nor closed.

    :param mat: The 2D array whose zero contour is sought (the real part is used).
    :return: The integer index coordinates of the zero-contour vertices as an ``[N, 2]`` array. The first column
        indexes axis 0 of ``mat`` and the second column indexes axis 1. The interpolated vertex coordinates are
        truncated to integers, not rounded. If no zero contour is found, an empty ``[0, 2]`` array is returned.
    """
    mat = np.asarray(mat)
    ind_x = np.arange(mat.shape[0])
    ind_y = np.arange(mat.shape[1])
    fig, ax = plt.subplots()
    try:
        # contour expects Z[y, x]; x runs over axis 0 of ``mat``, hence the transpose.
        contour_set = ax.contour(ind_x, ind_y, mat.T.real, levels=[0])
        paths = list(contour_set.get_paths())
    finally:
        plt.close(fig)
    vertices = [path.vertices for path in paths if len(path.vertices) > 0]
    if not vertices:
        return np.empty((0, 2), dtype=int)
    return np.concatenate(vertices, axis=0).astype(int)
[docs]
def sigma_loc_checks(
    siw_arr: list[np.ndarray],
    labels: list[str],
    beta: float,
    output_dir: str = "./",
    show: bool = False,
    save: bool = True,
    name: str = "",
    xmax: float = 0,
) -> None:
    r"""
    Draws a 2x2 grid of diagnostic panels for a set of local self-energies.

    The top row shows :math:`\Re\Sigma(i\nu_n)` and :math:`\Im\Sigma(i\nu_n)` on linear axes. The bottom row shows
    :math:`\Re\Sigma(i\nu_n)` and :math:`|\Im\Sigma(i\nu_n)|` on log-log axes.

    :param siw_arr: Local self-energy arrays to compare (one fermionic frequency axis each).
    :param labels: Legend labels, matched to ``siw_arr`` by position.
    :param beta: Inverse temperature :math:`\beta`; only enters the default x-axis range.
    :param output_dir: Directory the PDF is written to.
    :param show: If ``True``, display the figure; otherwise close it.
    :param save: If ``True``, save the figure as ``sde_{name}_check.pdf``.
    :param name: Name tag used in the output filename.
    :param xmax: Upper x-limit of all panels; ``0`` means ``5 + 2*beta``.
    :return: None.
    """
    x_upper = 5 + 2 * beta if xmax == 0 else xmax
    _, axes = plt.subplots(ncols=2, nrows=2, figsize=(8, 5))
    ax_re, ax_im, ax_re_log, ax_im_log = axes.flatten()

    for idx, siw in enumerate(siw_arr):
        freqs = MFHelper.vn(np.size(siw) // 2)
        lbl = labels[idx]
        ax_re.plot(freqs, siw.real, label=lbl)
        ax_im.plot(freqs, siw.imag, label=lbl)
        ax_re_log.loglog(freqs, siw.real, label=lbl)
        ax_im_log.loglog(freqs, np.abs(siw.imag), label=lbl)

    y_titles = (
        r"$\Re \Sigma(i\nu_n)$",
        r"$\Im \Sigma(i\nu_n)$",
        r"$\Re \Sigma(i\nu_n)$",
        r"$|\Im \Sigma(i\nu_n)|$",
    )
    for panel, y_title in zip((ax_re, ax_im, ax_re_log, ax_im_log), y_titles):
        panel.set_xlabel(r"$\nu_n$")
        panel.set_ylabel(y_title)

    # Linear panels start at 0; the log-log panels keep their automatic lower limit.
    ax_re.set_xlim(0, x_upper)
    ax_im.set_xlim(0, x_upper)
    ax_re_log.set_xlim(None, x_upper)
    ax_im_log.set_xlim(None, x_upper)
    # plt.legend() acts on pyplot's current axis (presumably the last subplot created, bottom-right).
    plt.legend()
    ax_im.set_ylim(None, 0)
    plt.tight_layout()

    if save:
        plt.savefig(os.path.join(output_dir, f"sde_{name}_check.pdf"), bbox_inches="tight", pad_inches=0.05)
    if show:
        plt.show()
    else:
        plt.close()
[docs]
def chi_checks(
chi_dens_list: list[np.ndarray],
chi_magn_list: list[np.ndarray],
beta: float,
labels: list[str],
e_kin: float,
output_dir: str = "./",
orbs=[0, 0, 0, 0],
show: bool = False,
save: bool = True,
name: str = "",
):
r"""
Produces the routine diagnostic plots for the density and magnetic susceptibilities (linear and log-log, with the
:math:`1/\omega^2` kinetic-energy asymptote overlaid).
:param chi_dens_list: List of density susceptibility arrays.
:param chi_magn_list: List of magnetic susceptibility arrays.
:param beta: Inverse temperature :math:`\beta`.
:param labels: Plot labels, one per susceptibility.
:param e_kin: Kinetic energy, used to draw the high-frequency asymptote.
:param output_dir: Directory to save the figure to.
:param orbs: The four orbital indices to plot.
:param show: Whether to display the figure.
:param save: Whether to save the figure.
:param name: Name tag used in the output filename.
:return: None.
"""
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(8, 5), dpi=500)
axes = axes.flatten()
niw_chi_input = np.size(chi_dens_list[0][*orbs, :])
for i, cd in enumerate(chi_dens_list):
axes[0].plot(MFHelper.wn(len(cd[*orbs, :]) // 2), cd[*orbs, :].real, label=labels[i])
axes[0].set_ylabel(r"$\Re \chi(i\omega_n)_{dens}$")
axes[0].legend()
for i, cd in enumerate(chi_magn_list):
axes[1].plot(MFHelper.wn(len(cd[*orbs, :]) // 2), cd[*orbs, :].real, label=labels[i])
axes[1].set_ylabel(r"$\Re \chi(i\omega_n)_{magn}$")
axes[1].legend()
for i, cd in enumerate(chi_dens_list):
axes[2].loglog(MFHelper.wn(len(cd[*orbs, :]) // 2), cd[*orbs, :].real, label=labels[i], ms=0)
axes[2].loglog(
MFHelper.wn(niw_chi_input),
np.real(1 / (1j * MFHelper.wn(niw_chi_input, beta) + 0.000001) ** 2 * e_kin) * 2,
ls="--",
label="Asympt",
ms=0,
)
axes[2].set_ylabel(r"$\Re \chi(i\omega_n)_{dens}$")
axes[2].legend()
for i, cd in enumerate(chi_magn_list):
axes[3].loglog(MFHelper.wn(len(cd[*orbs, :]) // 2), cd[*orbs, :].real, label=labels[i], ms=0)
axes[3].loglog(
MFHelper.wn(niw_chi_input),
np.real(1 / (1j * MFHelper.wn(niw_chi_input, beta) + 0.000001) ** 2 * e_kin) * 2,
"--",
label="Asympt",
ms=0,
)
axes[3].set_ylabel(r"$\Re \chi(i\omega_n)_{magn}$")
axes[3].legend()
axes[0].set_xlim(-1, 10)
axes[1].set_xlim(-1, 10)
plt.tight_layout()
if save:
plt.savefig(os.path.join(output_dir, f"chi_dens_magn_{name}.pdf"), bbox_inches="tight", pad_inches=0.05)
if show:
plt.show()
else:
plt.close()
[docs]
def plot_nu_nup(
    obj: LocalFourPoint,
    orbs: np.ndarray | list | tuple = (0, 0, 0, 0),
    omega: int = 0,
    do_save: bool = True,
    output_dir: str = "./",
    name: str = "Name",
    colormap: str = "RdBu",
    show: bool = False,
) -> None:
    r"""
    Shows one :math:`(\nu, \nu')` slice of a local four-point object at fixed orbitals and bosonic frequency
    :math:`\omega`. The real part is drawn in the left panel and the imaginary part in the right panel.

    :param obj: The :class:`LocalFourPoint` whose ``mat`` is sliced.
    :param orbs: Four orbital indices selecting the slice.
    :param omega: Bosonic frequency index; must satisfy ``|omega| <= obj.niw``.
    :param do_save: If ``True``, write ``{name}_w{omega}.pdf`` to ``output_dir``.
    :param output_dir: Target directory of the saved figure.
    :param name: Figure super-title and filename tag.
    :param colormap: Matplotlib colormap used for both panels.
    :param show: If ``True``, display the figure; otherwise close it.
    :return: None.
    :raises ValueError: If ``omega`` is out of range or ``orbs`` does not have four entries.
    """
    if abs(omega) > obj.niw:
        raise ValueError(f"Omega {omega} out of range.")
    if len(orbs) != 4:
        raise ValueError("'orbs' needs to be of size 4.")

    fig, (ax_re, ax_im) = plt.subplots(ncols=2, figsize=(7, 3), dpi=251)
    o1, o2, o3, o4 = orbs
    # Position of the requested bosonic frequency on the MFHelper.wn grid.
    omega_idx = np.argmax(MFHelper.wn(obj.niw) == omega)
    slab = obj.mat[o1, o2, o3, o4, omega_idx, ...]
    nu = MFHelper.vn(obj.niv)

    mesh_re = ax_re.pcolormesh(nu, nu, slab.real, cmap=colormap)
    mesh_im = ax_im.pcolormesh(nu, nu, slab.imag, cmap=colormap)
    ax_re.set_title(r"$\Re$")
    ax_im.set_title(r"$\Im$")
    for panel in (ax_re, ax_im):
        panel.set_xlabel(r"$\nu_p$")
        panel.set_ylabel(r"$\nu$")
        panel.set_aspect("equal")
    fig.suptitle(name)
    for mesh, panel in ((mesh_re, ax_re), (mesh_im, ax_im)):
        fig.colorbar(mesh, ax=panel, aspect=15, fraction=0.08, location="right", pad=0.05)
    plt.tight_layout()

    if do_save:
        plt.savefig(os.path.join(output_dir, f"{name}_w{omega}.pdf"), bbox_inches="tight", pad_inches=0.05)
    if show:
        plt.show()
    else:
        plt.close()
[docs]
def plot_two_point_kx_ky(
    obj: LocalNPoint | IAmNonLocal,
    kx: np.ndarray,
    ky: np.ndarray,
    pi_shift: bool = True,
    title: str = "",
    name: str = "",
    orbs: np.ndarray | list | tuple = (0, 0),
    output_dir="./",
    cmap="magma",
    scatter=None,
    save: bool = True,
    show: bool = False,
):
    r"""
    Plots the real and imaginary parts of a two-point function in the :math:`(k_x, k_y)` plane for fixed orbitals.

    The slice is taken at :math:`k_z = 0` and the first positive Matsubara frequency. The antiferromagnetic zone
    boundary and optional scatter points are overlaid.

    :param obj: The two-point object to plot (a :class:`LocalNPoint` / :class:`IAmNonLocal`).
    :param kx: The kx grid values for the plot axes.
    :param ky: The ky grid values for the plot axes.
    :param pi_shift: Whether to shift the momentum grid by :math:`\pi` before plotting.
    :param title: Title suffix for the subplots.
    :param name: Output filename tag.
    :param orbs: The two orbital indices to select.
    :param output_dir: Directory to save the figure to.
    :param cmap: The matplotlib colormap (name or Colormap).
    :param scatter: Optional ``[N, 2]`` array-like of points to scatter on the plot (e.g. Fermi-surface points).
        Empty input is ignored.
    :param save: Whether to save the figure.
    :param show: Whether to display the figure.
    :return: None.
    :raises ValueError: If ``orbs`` does not have two entries.
    """
    if len(orbs) != 2:
        raise ValueError("'orbs' needs to be of size 2.")
    mat = obj.shift_k_by_pi().mat if pi_shift else obj.mat
    niv = mat.shape[-1] // 2
    # kz index 0, selected orbitals, fermionic index niv (first positive frequency on a symmetric grid).
    mat = mat[:, :, 0, orbs[0], orbs[1], niv]
    # Periodic wrap: append the first row/column, so each in-plane axis gains one point.
    mat = np.concatenate([mat, mat[0:1, :, ...]], axis=0)
    mat = np.concatenate([mat, mat[:, 0:1, ...]], axis=1)
    fig, axes = plt.subplots(ncols=2, figsize=(7, 3), dpi=500)
    axes = axes.flatten()
    im1 = axes[0].pcolormesh(kx, ky, mat.T.real, cmap=cmap)
    im2 = axes[1].pcolormesh(kx, ky, mat.T.imag, cmap=cmap)
    axes[0].set_title(r"$\Re$" + title)
    axes[1].set_title(r"$\Im$" + title)
    tick_vals = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    tick_labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]
    for ax in axes:
        ax.set_xlabel(r"$k_x$")
        ax.set_aspect("equal")
        add_afzb(ax=ax, kx=kx, ky=ky, lw=1.0, marker="")
        ax.set_xticks(tick_vals)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(tick_vals)
    axes[0].set_ylabel(r"$k_y$")
    axes[1].set_ylabel("")
    axes[0].set_yticklabels(tick_labels)
    axes[1].set_yticklabels([""] * len(tick_labels))
    # Colorbar axes: 5% of the plot width, separated by a 0.1 pad.
    cax1 = make_axes_locatable(axes[0]).append_axes("right", size="5%", pad=0.1)
    cax2 = make_axes_locatable(axes[1]).append_axes("right", size="5%", pad=0.1)
    vmin1, vmax1 = im1.get_clim()
    vmin2, vmax2 = im2.get_clim()
    cb1 = fig.colorbar(im1, cax=cax1, ticks=np.linspace(vmin1, vmax1, 5))
    cb1.locator = ticker.MaxNLocator(nbins=5)
    cb1.update_ticks()
    cb2 = fig.colorbar(im2, cax=cax2, ticks=np.linspace(vmin2, vmax2, 5))
    cb2.locator = ticker.MaxNLocator(nbins=5)
    cb2.update_ticks()
    if scatter is not None and np.size(scatter) > 0:
        scatter = np.asarray(scatter)
        # ``plt.cm.get_cmap`` was removed in matplotlib 3.9; ``plt.get_cmap`` is the supported accessor.
        colours = plt.get_cmap(cmap)(np.linspace(0, 1, scatter.shape[0]))
        for ax in axes:
            ax.scatter(scatter[:, 0], scatter[:, 1], marker="o", c=colours)
    plt.tight_layout()
    if save:
        plt.savefig(os.path.join(output_dir, f"{name}.pdf"), bbox_inches="tight", pad_inches=0.05)
    if show:
        plt.show()
    else:
        plt.close()
[docs]
def plot_two_point_kx_ky_real_and_imag(
    obj: LocalNPoint | IAmNonLocal,
    kx: np.ndarray,
    ky: np.ndarray,
    pi_shift: bool = True,
    title: str = "",
    name: str = "",
    orbs: np.ndarray | list | tuple = (0, 0),
    output_dir="./",
    cmap="magma",
    save: bool = True,
    show: bool = False,
):
    r"""
    Draws the real and the imaginary part of a two-point function in the :math:`(k_x, k_y)` plane as two separate
    single-panel figures. The slice is taken at kz index 0, fixed orbitals, and the middle index of the last
    (frequency) axis.

    :param obj: The two-point object to plot (a :class:`LocalNPoint` / :class:`IAmNonLocal`).
    :param kx: The kx grid values for the plot axes.
    :param ky: The ky grid values for the plot axes.
    :param pi_shift: If ``True``, plot ``obj.shift_k_by_pi()`` instead of ``obj``.
    :param title: Text placed after ``\Re`` / ``\Im`` inside the math-mode panel titles.
    :param name: Output filename stem; ``_real`` / ``_imag`` is appended.
    :param orbs: The two orbital indices to select.
    :param output_dir: Directory the PDFs are written to.
    :param cmap: The matplotlib colormap.
    :param save: If ``True``, save both figures.
    :param show: If ``True``, display each figure; otherwise close it.
    :return: None.
    :raises ValueError: If ``orbs`` does not have two entries.
    """
    inch_per_cm = 1.0 / 2.54
    if len(orbs) != 2:
        raise ValueError("'orbs' needs to be of size 2.")
    source = obj.shift_k_by_pi() if pi_shift else obj
    full_mat = source.mat
    plane = full_mat[:, :, 0, orbs[0], orbs[1], full_mat.shape[-1] // 2]
    tick_vals = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    tick_labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]

    for part, prefix, suffix in ((plane.real, r"\Re", "real"), (plane.imag, r"\Im", "imag")):
        # Periodic wrap: duplicate the first row and then the first column at the end.
        closed = np.vstack([part, part[:1, :]])
        closed = np.hstack([closed, closed[:, :1]])

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9 * inch_per_cm, 9 * inch_per_cm), dpi=500)
        mesh = ax.pcolormesh(kx, ky, closed.T, cmap=cmap)
        ax.set_title(rf"${prefix} {title}$")
        ax.set_xlabel(r"$k_x$")
        ax.set_ylabel(r"$k_y$")
        ax.set_aspect("equal")
        ax.set_xticks(tick_vals)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(tick_vals)
        ax.set_yticklabels(tick_labels)

        # Colorbar axis: 5% of the plot width, separated by a 0.1 pad.
        cbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)
        colorbar = fig.colorbar(mesh, cax=cbar_ax)
        colorbar.locator = ticker.MaxNLocator(nbins=5)
        colorbar.update_ticks()
        plt.tight_layout()

        if save:
            plt.savefig(os.path.join(output_dir, f"{name}_{suffix}.pdf"), bbox_inches="tight", pad_inches=0.05)
        if show:
            plt.show()
        else:
            plt.close()
[docs]
def plot_two_point_kx_ky_with_fs_points(
    obj: LocalNPoint | IAmNonLocal,
    k_grid: KGrid,
    kx: np.ndarray,
    ky: np.ndarray,
    pi_shift: bool = True,
    title: str = "",
    name: str = "",
    orbs: np.ndarray | list | tuple = (0, 0),
    output_dir="./",
    cmap="magma",
    do_save: bool = True,
    show: bool = False,
):
    r"""
    Plots a two-point function in the :math:`(k_x, k_y)` plane for fixed orbitals, with Fermi-surface points
    scattered on top (see :func:`plot_two_point_kx_ky`).

    The Fermi-surface points are the zero crossings of the quantity in the reduced BZ quadrant. That quadrant is the
    first half of each in-plane k axis of the unshifted grid, at kz index 0 and fermionic index ``obj.niv``.

    :param obj: The two-point object to plot (a :class:`LocalNPoint` / :class:`IAmNonLocal`).
    :param k_grid: The :class:`KGrid` providing the k-axis values for the Fermi-surface points.
    :param kx: The kx grid values for the plot axes.
    :param ky: The ky grid values for the plot axes.
    :param pi_shift: Whether to shift the momentum grid by :math:`\pi` before plotting.
    :param title: Title suffix for the subplots.
    :param name: Output filename tag.
    :param orbs: The two orbital indices to select.
    :param output_dir: Directory to save the figure to.
    :param cmap: The matplotlib colormap.
    :param do_save: Whether to save the figure.
    :param show: Whether to display the figure.
    :return: None.
    :raises ValueError: If ``orbs`` does not have two entries.
    """
    if len(orbs) != 2:
        raise ValueError("'orbs' needs to be of size 2.")
    quadrant = obj.mat[..., 0, orbs[0], orbs[1], obj.niv][: obj.nq[0] // 2, : obj.nq[1] // 2]
    # reshape(-1, 2) keeps an [N, 2] layout even when no zero crossing is found (N = 0).
    fs_ind = np.asarray(find_zeros(quadrant), dtype=int).reshape(-1, 2)
    # Only the first half of the contour vertices is kept. NOTE(review): intent of this halving is not visible here.
    fs_ind = fs_ind[: fs_ind.shape[0] // 2]
    fs_points = np.stack((k_grid.kx[fs_ind[:, 0]], k_grid.ky[fs_ind[:, 1]]), axis=1)
    plot_two_point_kx_ky(
        obj,
        kx,
        ky,
        pi_shift=pi_shift,
        title=title,
        name=name,
        orbs=orbs,
        output_dir=output_dir,
        cmap=cmap,
        scatter=fs_points,
        save=do_save,
        show=show,
    )
[docs]
def plot_gap_function(
    obj: GapFunction,
    kx: np.ndarray,
    ky: np.ndarray,
    name: str = "",
    orbs: np.ndarray | list | tuple = (0, 0),
    output_dir="./",
    cmap="magma",
    scatter=None,
    do_save: bool = True,
    show: bool = False,
):
    r"""
    Plots the gap function in the :math:`(k_x, k_y)` plane for fixed orbitals.

    Rather than the real/imaginary parts, it shows the real part at the smallest positive (left panel) and smallest
    negative (right panel) fermionic Matsubara frequency. This makes the frequency parity, and hence the gap
    symmetry, visible. Index ``n // 2`` of the last axis of ``obj.mat`` is taken as :math:`+\pi/\beta` and
    ``n // 2 - 1`` as :math:`-\pi/\beta`. This assumes the symmetric fermionic grid used elsewhere in this module;
    TODO confirm against ``MFHelper.vn``.

    :param obj: The :class:`GapFunction` to plot.
    :param kx: The kx grid values for the plot axes.
    :param ky: The ky grid values for the plot axes.
    :param name: Output filename tag.
    :param orbs: The two orbital indices to select.
    :param output_dir: Directory to save the figure to.
    :param cmap: The matplotlib colormap (name or Colormap).
    :param scatter: Optional ``[N, 2]`` array-like of points to scatter on the plot. Empty input is ignored.
    :param do_save: Whether to save the figure.
    :param show: Whether to display the figure.
    :return: None.
    :raises ValueError: If ``orbs`` does not have two entries.
    """
    if len(orbs) != 2:
        raise ValueError("'orbs' needs to be of size 2.")
    gap = obj.shift_k_by_pi().mat
    niv_pp = gap.shape[-1] // 2
    plane = gap[:, :, 0, orbs[0], orbs[1], :]
    # Last axis: [nu = +pi/beta, nu = -pi/beta], matching the panel titles below.
    gap = np.stack((plane[..., niv_pp], plane[..., niv_pp - 1]), axis=-1)
    # Periodic wrap: append the first row/column, so each in-plane axis gains one point.
    gap = np.concatenate([gap, gap[0:1, :, ...]], axis=0)
    gap = np.concatenate([gap, gap[:, 0:1, ...]], axis=1)
    fig, axes = plt.subplots(ncols=2, figsize=(7, 3), dpi=500)
    axes = axes.flatten()
    im1 = axes[0].pcolormesh(kx, ky, gap[..., 0].T.real, cmap=cmap)
    im2 = axes[1].pcolormesh(kx, ky, gap[..., 1].T.real, cmap=cmap)
    axes[0].set_title(f"$\\Delta^{{k_x k_y k_z=0;\\nu=\\frac{{\\pi}}{{\\beta}}}}_{{{obj.channel.value[0]}}}$")
    axes[1].set_title(f"$\\Delta^{{k_x k_y k_z=0;\\nu=-\\frac{{\\pi}}{{\\beta}}}}_{{{obj.channel.value[0]}}}$")
    tick_vals = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    tick_labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]
    for ax in axes:
        ax.set_xlabel(r"$k_x$")
        ax.set_aspect("equal")
        add_afzb(ax=ax, kx=kx, ky=ky, lw=1.0, marker="")
        ax.set_xticks(tick_vals)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(tick_vals)
    axes[0].set_ylabel(r"$k_y$")
    axes[1].set_ylabel("")
    axes[0].set_yticklabels(tick_labels)
    axes[1].set_yticklabels([""] * len(tick_labels))
    # Colorbar axes: 5% of the plot width, separated by a 0.1 pad.
    cax1 = make_axes_locatable(axes[0]).append_axes("right", size="5%", pad=0.1)
    cax2 = make_axes_locatable(axes[1]).append_axes("right", size="5%", pad=0.1)
    vmin1, vmax1 = im1.get_clim()
    vmin2, vmax2 = im2.get_clim()
    fig.colorbar(im1, cax=cax1, ticks=np.linspace(vmin1, vmax1, 5))
    fig.colorbar(im2, cax=cax2, ticks=np.linspace(vmin2, vmax2, 5))
    if scatter is not None and np.size(scatter) > 0:
        scatter = np.asarray(scatter)
        # ``plt.cm.get_cmap`` was removed in matplotlib 3.9; ``plt.get_cmap`` is the supported accessor.
        colours = plt.get_cmap(cmap)(np.linspace(0, 1, scatter.shape[0]))
        for ax in axes:
            ax.scatter(scatter[:, 0], scatter[:, 1], marker="o", c=colours)
    plt.tight_layout()
    if do_save:
        plt.savefig(os.path.join(output_dir, f"{name}.pdf"), bbox_inches="tight", pad_inches=0.05)
    if show:
        plt.show()
    else:
        plt.close()
[docs]
def plot_spectrum(
    a_w: np.ndarray,
    kx: np.ndarray,
    ky: np.ndarray,
    kz: np.ndarray,
    high_sym_points: list[tuple[float, float, float, str]],
    energy_window: tuple[float, float],
    beta: float,
    title: str,
    fermi_energy: float = 0,
    output_dir="./",
    name: str = "",
    cmap="magma",
    do_save: bool = True,
    show: bool = False,
):
    r"""
    Plots the total (band-summed) spectral function :math:`A(\mathbf{k}, \omega)` along a high-symmetry k-path,
    interpolating the BZ-gridded data onto the path. The spectral function is expected in the band-diagonal basis.

    :param a_w: The spectral function of shape ``[kx, ky, kz, n_bands, w]``.
    :param kx: The kx grid values (ascending).
    :param ky: The ky grid values (ascending).
    :param kz: The kz grid values (ascending).
    :param high_sym_points: The path corner points as ``(kx, ky, kz, label)`` tuples (fractional coordinates).
        Integer coordinates are accepted.
    :param energy_window: The real-frequency window ``(w_min, w_max)`` for the y-axis.
    :param beta: Inverse temperature :math:`\beta` (sets the real-frequency axis mapping).
    :param title: The plot title.
    :param fermi_energy: Energy offset subtracted so the Fermi level sits at zero.
    :param output_dir: Directory to save the figure to.
    :param name: Output filename tag.
    :param cmap: The matplotlib colormap.
    :param do_save: Whether to save the figure.
    :param show: Whether to display the figure.
    :return: None.
    :raises ValueError: If fewer than two high-symmetry points are given.
    """
    if len(high_sym_points) < 2:
        raise ValueError("'high_sym_points' needs at least two points to define a path.")
    n_per_seg = 200
    k_axes = (kx, ky, kz)
    # Sum over bands to get the total spectral function, shape [kx, ky, kz, w].
    a_w = np.sum(a_w, axis=-2)

    # Period of each k axis (span plus one grid step) for wrapping; a single-point axis is given 2*pi.
    periods = []
    for ax in k_axes:
        if len(ax) > 1:
            step = ax[1] - ax[0]
            periods.append(ax.max() - ax.min() + step)
        else:
            periods.append(2 * np.pi)
    periods = np.array(periods)
    k_mins = np.array([ax.min() for ax in k_axes])

    labels = [rf"$\Gamma$" if "gamma" in p[3].lower() else rf"${p[3]}$" for p in high_sym_points]
    # dtype=float: the in-place scaling below fails on an integer array (all-integer coordinates).
    points = np.array([p[:3] for p in high_sym_points], dtype=float)
    points *= periods  # fractional -> absolute coordinates

    # Straight segments between consecutive corners; only the last one includes its end point.
    n_seg = len(points) - 1
    full_path = np.vstack(
        [np.linspace(points[i], points[i + 1], n_per_seg, endpoint=(i == n_seg - 1)) for i in range(n_seg)]
    )

    # Map every path point back into the sampled window [k_min, k_min + period).
    path_wrapped = k_mins + (full_path - k_mins) % periods
    nw = a_w.shape[-1]
    # One interpolator for all energies: RegularGridInterpolator broadcasts over trailing value dimensions and
    # returns an [n_path, nw] array. fill_value=None extrapolates outside the grid.
    interp = RegularGridInterpolator(k_axes, a_w, bounds_error=False, fill_value=None)
    a_on_path = np.nan_to_num(np.asarray(np.real(interp(path_wrapped)), dtype=float))

    # X-axis: cumulative arc length along the (unwrapped) path.
    dists = np.sqrt(np.sum(np.diff(full_path, axis=0) ** 2, axis=1))
    arc = np.concatenate(([0], np.cumsum(dists)))
    tick_indices = [min(i * n_per_seg, len(arc) - 1) for i in range(len(labels))]
    tick_x = arc[tick_indices]

    fig, ax = plt.subplots(figsize=(8, 5))
    v_max = np.percentile(a_on_path, 98)
    # Tangent-spaced real-frequency axis of nw points. NOTE(review): presumably mirrors the grid of the analytic
    # continuation that produced a_w; confirm.
    w_axis = 5 * beta * np.tan(np.linspace(-np.pi / 2.1, np.pi / 2.1, num=nw, endpoint=True)) / np.tan(np.pi / 2.1)
    pm = ax.pcolormesh(
        arc, w_axis - fermi_energy, a_on_path.T, cmap=cmap, shading="gouraud", rasterized=True, vmin=0, vmax=v_max
    )
    ax.set_xticks(tick_x)
    ax.set_xticklabels(labels)
    ax.set_xlim(arc[0], arc[-1])
    ax.set_ylim(*energy_window)
    ax.set_ylabel(r"$\omega - \varepsilon_{\mathrm{F}}$ [eV]")
    ax.set_title(title)
    # Guidelines at the high-symmetry points and at the Fermi level.
    for tx in tick_x:
        ax.axvline(tx, color="white", alpha=0.3, lw=0.8)
    ax.axhline(0, color="white", lw=1.2, ls="--", alpha=0.6)
    cbar = fig.colorbar(pm, pad=0.02)
    cbar.set_label(r"$A(\mathbf{k}, \omega)$")
    plt.tight_layout()
    if do_save:
        plt.savefig(os.path.join(output_dir, f"spectrum_{name}.png"), dpi=400)
    if show:
        plt.show()