# coding: utf-8
#
# This code is part of lattpy.
#
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
# be included in all copies or substantial portions of the Software.
"""Tools for dispersion computation and plotting."""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as tck
from .utils import chain
from .spatial import distance
from .plotting import draw_lines
from .atom import Atom
# Public names of this module: the plotting helpers and the DispersionPath class.
# Underscore-prefixed helpers defined below are intentionally not listed.
__all__ = [
"bandpath_subplots",
"plot_dispersion",
"disp_dos_subplots",
"plot_disp_dos",
"plot_bands",
"DispersionPath",
]
def _color_list(color, num_bands):
    """Expand a color specification into one entry per band.

    Parameters
    ----------
    color : None, str, object or sequence
        ``None`` produces the strings ``"C0"``, ``"C1"``, ... (one per band).
        A string, or any object without ``__len__``, is repeated for every band.
        Any other object (i.e. a sized non-string) is returned unchanged.
    num_bands : int
        Number of bands that need a color.

    Returns
    -------
    colors : list or sequence
        The per-band colors. When a sized non-string object is passed it is
        handed back as-is, without checking its length against ``num_bands``.
    """
    if color is None:
        return [f"C{band}" for band in range(num_bands)]

    is_sized_sequence = hasattr(color, "__len__") and not isinstance(color, str)
    if is_sized_sequence:
        return color

    # A single color (string or scalar-like object) is shared by all bands.
    return [color] * num_bands
def _scale_xaxis(num_points, disp, scales=None):
    """Compute scaled x-values and tick positions for a sectioned band path.

    The ``len(disp)`` samples are split into ``num_points - 1`` sections of
    equal size. Within each section the x-values advance in steps of the
    section's scale, and every section starts at the last x-value of the
    previous one.

    Parameters
    ----------
    num_points : int
        Number of points delimiting the sections.
    disp : sized
        Dispersion data; only its length is used here.
    scales : iterable of float, optional
        Step width of each section. Defaults to ``1`` for every section.

    Returns
    -------
    k : list
        The x-value of every sample.
    ticks : list
        The x-positions of the section boundaries, starting with ``0``.
    """
    num_sections = num_points - 1
    sect_size = len(disp) / num_sections
    if scales is None:
        scales = np.ones(num_sections)

    # Unscaled sample offsets inside one section; identical for every section.
    steps = np.arange(sect_size)

    offset = 0
    kvals = []
    ticks = [0]
    for scale in scales:
        kvals.extend(offset + steps * scale)
        # The next section begins where this one ended.
        offset = kvals[-1]
        ticks.append(offset)
    return kvals, ticks
def _set_piticks(axis, num_ticks=2, frmt=".1f"):
    """Install evenly spaced major ticks whose labels carry a pi suffix.

    Parameters
    ----------
    axis : matplotlib axis object
        The axis (e.g. ``ax.xaxis``) whose major formatter and locator are set.
    num_ticks : int, optional
        The locator is created with ``2 * num_ticks + 1`` tick positions.
    frmt : str, optional
        printf-style float format (without the leading ``%``) used for the
        tick value.

    Notes
    -----
    The tick value itself is formatted unchanged and followed by the pi symbol;
    no division by pi is performed here.
    """
    label_format = rf"%{frmt} $\pi$"
    formatter = tck.FormatStrFormatter(label_format)
    locator = tck.LinearLocator(2 * num_ticks + 1)
    axis.set_major_formatter(formatter)
    axis.set_major_locator(locator)
def bandpath_subplots(ticks, labels, xlabel="$k$", ylabel="$E(k)$", grid="both"):
    """Create a figure whose x-axis is set up for a band path.

    Parameters
    ----------
    ticks : sequence of float
        Positions of the x-ticks. The last entry is used as the upper x-limit.
    labels : sequence of str
        Tick labels, one per entry of ``ticks``.
    xlabel : str, optional
        Label of the x-axis. Nothing is set if empty.
    ylabel : str, optional
        Label of the y-axis. Nothing is set if empty.
    grid : str or bool, optional
        Axis argument of the major grid (``"both"``, ``"x"`` or ``"y"``).
        Any other truthy value is treated as ``"both"``; a falsy value
        disables the grid.

    Returns
    -------
    fig : matplotlib figure
    ax : matplotlib axes
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0, ticks[-1])
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if grid:
        if not isinstance(grid, str):
            grid = "both"
        ax.set_axisbelow(True)
        # The visibility flag is passed positionally: its keyword was renamed
        # from ``b`` to ``visible`` in Matplotlib 3.5, so ``b=True`` fails on
        # recent versions while the positional form works on all of them.
        ax.grid(True, which="major", axis=grid)
    return fig, ax
def _draw_dispersion(ax, k, disp, color=None, fill=False, alpha=0.2, lw=1.0):
    """Draw every band of a dispersion onto an existing axes.

    Parameters
    ----------
    ax : matplotlib axes
        Axes the bands are plotted on.
    k : sequence of float
        x-values of the samples.
    disp : np.ndarray
        Dispersion data; the bands are iterated along ``disp.T`` and
        ``disp.shape[1]`` is used as the number of bands.
    color : optional
        Color specification, expanded via ``_color_list``. Entries that are
        ``Atom`` instances are replaced by their ``color`` attribute.
    fill : bool, optional
        If True, the range between the minimum and maximum of each band is
        shaded across the full x-range.
    alpha : float, optional
        Opacity of the shaded range.
    lw : float, optional
        Line width of the bands.
    """
    x_span = [0, np.max(k)]
    band_colors = _color_list(color, disp.shape[1])
    for index, band in enumerate(disp.T):
        entry = band_colors[index]
        line_color = entry.color if isinstance(entry, Atom) else entry
        ax.plot(k, band, lw=lw, color=line_color)
        if not fill:
            continue
        ax.fill_between(
            x_span, min(band), max(band), color=line_color, alpha=alpha
        )
def plot_dispersion(
    disp,
    labels,
    xlabel="$k$",
    ylabel="$E(k)$",
    grid="both",
    color=None,
    alpha=0.2,
    lw=1.0,
    scales=None,
    fill=False,
    ax=None,
    show=True,
):
    """Plot a dispersion along a path of labelled points.

    Parameters
    ----------
    disp : np.ndarray
        Dispersion data with the bands along the second axis.
    labels : sequence of str
        Labels of the path points; their count defines the number of sections.
    xlabel, ylabel : str, optional
        Axis labels, only used when a new figure is created.
    grid : str or bool, optional
        Grid setting, only used when a new figure is created.
    color : optional
        Color specification of the bands.
    alpha : float, optional
        Opacity of the shaded band ranges.
    lw : float, optional
        Line width of the bands.
    scales : iterable of float, optional
        Relative step width of each path section.
    fill : bool, optional
        If True, the value range of each band is shaded.
    ax : matplotlib axes, optional
        Existing axes to draw on. A new figure is created if omitted.
    show : bool, optional
        If True, ``plt.show()`` is called after drawing.

    Returns
    -------
    ax : matplotlib axes
    """
    k, ticks = _scale_xaxis(len(labels), disp, scales)

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = bandpath_subplots(ticks, labels, xlabel, ylabel, grid)

    # The band drawing is shared with the module-level helper.
    _draw_dispersion(ax, k, disp, color=color, fill=fill, alpha=alpha, lw=lw)

    fig.tight_layout()
    if show:
        plt.show()
    return ax
def disp_dos_subplots(
    ticks,
    labels,
    xlabel="$k$",
    ylabel="$E(k)$",
    doslabel="$n(E)$",
    wratio=(3, 1),
    grid="both",
):
    """Create a two-panel figure for a dispersion and a density of states.

    The two axes are placed side by side and share their y-axis. The left
    axes is configured for the band path, the right one for the DOS.

    Parameters
    ----------
    ticks : sequence of float
        Positions of the x-ticks of the dispersion axes. The last entry is
        used as the upper x-limit.
    labels : sequence of str
        Tick labels, one per entry of ``ticks``.
    xlabel : str, optional
        x-label of the dispersion axes. Nothing is set if empty.
    ylabel : str, optional
        y-label of the dispersion axes. Nothing is set if empty.
    doslabel : str, optional
        x-label of the DOS axes. Nothing is set if empty.
    wratio : tuple of float, optional
        Width ratio of the two panels.
    grid : str or bool, optional
        Axis argument of the major grid (``"both"``, ``"x"`` or ``"y"``).
        Any other truthy value is treated as ``"both"``; a falsy value
        disables the grid.

    Returns
    -------
    fig : matplotlib figure
    axs : array of matplotlib axes
        The dispersion axes followed by the DOS axes.
    """
    fig, axs = plt.subplots(1, 2, gridspec_kw={"width_ratios": wratio}, sharey="all")
    ax1, ax2 = axs
    ax1.set_xlim(0, ticks[-1])
    if xlabel:
        ax1.set_xlabel(xlabel)
    if ylabel:
        ax1.set_ylabel(ylabel)
    if doslabel:
        ax2.set_xlabel(doslabel)
    ax1.set_xticks(ticks)
    ax1.set_xticklabels(labels)
    ax2.set_xticks([0])
    if grid:
        # Same normalization as in ``bandpath_subplots``: a bare ``True``
        # must not be forwarded as the ``axis`` argument.
        if not isinstance(grid, str):
            grid = "both"
        for axes in (ax1, ax2):
            axes.set_axisbelow(True)
            # The visibility flag is passed positionally: its keyword was
            # renamed from ``b`` to ``visible`` in Matplotlib 3.5, so
            # ``b=True`` fails on recent versions.
            axes.grid(True, which="major", axis=grid)
    return fig, axs
def plot_disp_dos(
    disp,
    dos_data,
    labels,
    xlabel="k",
    ylabel="E(k)",
    doslabel="n(E)",
    wratio=(3, 1),
    grid="both",
    color=None,
    fill=True,
    disp_alpha=0.2,
    dos_alpha=0.2,
    lw=1.0,
    scales=None,
    axs=None,
    show=True,
):
    """Plot a dispersion next to the matching density of states.

    Parameters
    ----------
    disp : np.ndarray
        Dispersion data with the bands along the second axis.
    dos_data : iterable
        One ``(bins, dos)`` pair per band.
    labels : sequence of str
        Labels of the path points; their count defines the number of sections.
    xlabel, ylabel, doslabel : str, optional
        Axis labels, only used when a new figure is created.
    wratio : tuple of float, optional
        Width ratio of the panels, only used when a new figure is created.
    grid : str or bool, optional
        Grid setting, only used when a new figure is created.
    color : optional
        Color specification, shared by the bands and the DOS curves.
    fill : bool, optional
        If True, the value range of each band is shaded in the dispersion plot.
    disp_alpha : float, optional
        Opacity of the shaded band ranges.
    dos_alpha : float, optional
        Opacity of the filled area under each DOS curve.
    lw : float, optional
        Line width of the bands and DOS curves.
    scales : iterable of float, optional
        Relative step width of each path section.
    axs : pair of matplotlib axes, optional
        Existing ``(dispersion, dos)`` axes. A new figure is created if omitted.
    show : bool, optional
        If True, ``plt.show()`` is called after drawing.

    Returns
    -------
    axs : pair of matplotlib axes
    """
    k, ticks = _scale_xaxis(len(labels), disp, scales)

    if axs is not None:
        ax1, ax2 = axs
        fig = ax1.get_figure()
    else:
        fig, axs = disp_dos_subplots(
            ticks, labels, xlabel, ylabel, doslabel, wratio, grid
        )
        ax1, ax2 = axs

    # Left panel: the bands, drawn by the shared helper.
    _draw_dispersion(
        ax1, k, disp, color=color, fill=fill, alpha=disp_alpha, lw=lw
    )

    # Right panel: the DOS curves, rotated so the energy runs along y.
    dos_colors = _color_list(color, disp.shape[1])
    for index, (bins, dos) in enumerate(dos_data):
        entry = dos_colors[index]
        curve_color = entry.color if isinstance(entry, Atom) else entry
        ax2.plot(dos, bins, lw=lw, color=curve_color)
        ax2.fill_betweenx(bins, 0, dos, alpha=dos_alpha, color=curve_color)

    # Pin the lower DOS limit to zero while keeping the automatic upper limit.
    upper = ax2.get_xlim()[1]
    ax2.set_xlim(0, upper)

    fig.tight_layout()
    if show:
        plt.show()
    return axs
def plot_bands(
    kgrid,
    bands,
    k_label="k",
    disp_label="E(k)",
    grid="both",
    contour_grid=False,
    bz=None,
    pi_ticks=True,
    ax=None,
    show=True,
):
    """Plot bands on a one- or two-dimensional k-grid.

    For a 1D grid every band is drawn as a line. For a 2D grid a filled
    contour plot is drawn: a single band is shown directly, several bands
    are reduced to the sum of their absolute values.

    Parameters
    ----------
    kgrid : sequence of array_like
        The k-values along each dimension (``(k,)`` in 1D, ``(kx, ky)`` in 2D).
    bands : np.ndarray
        Band data with the band index along the first axis, followed by one
        axis per grid dimension.
    k_label : str, optional
        Base label of the k-axes. Nothing is set if empty.
    disp_label : str, optional
        Label of the dispersion (y-axis in 1D, colorbar in 2D). In 2D no
        colorbar is drawn if this is empty.
    grid : str, optional
        Axis argument of the grid. In 2D the grid is only drawn if
        ``contour_grid`` is also set.
    contour_grid : bool, optional
        Enables the grid on top of the 2D contour plot.
    bz : optional
        Brillouin-zone markers: x-positions of vertical lines in 1D, lines
        passed to ``draw_lines`` in 2D.
    pi_ticks : bool, optional
        If True, the k-axes use pi-suffixed tick labels.
    ax : matplotlib axes, optional
        Existing axes to draw on. A new figure is created if omitted.
    show : bool, optional
        If True, ``plt.show()`` is called after drawing.

    Returns
    -------
    ax : matplotlib axes

    Raises
    ------
    NotImplementedError
        If the grid is neither one- nor two-dimensional.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    dim = len(bands.shape) - 1
    if dim == 1:
        k = kgrid[0]
        ax.plot(k, bands.T)
        if k_label:
            ax.set_xlabel(f"${k_label}$")
        if disp_label:
            ax.set_ylabel(f"${disp_label}$")
        if grid:
            ax.grid(axis=grid)
        if pi_ticks:
            _set_piticks(ax.xaxis, num_ticks=2)
        ax.set_xlim(np.min(k), np.max(k))
        if bz is not None:
            for x in bz:
                ax.axvline(x=x, color="k")
    elif dim == 2:
        kx, ky = kgrid
        kxx, kyy = np.meshgrid(kx, ky)
        # Remember the band count before reducing to a 2D surface: the
        # colorbar label below depends on it, and the length of the reduced
        # surface is the number of grid rows, not the number of bands.
        single_band = len(bands) == 1
        if single_band:
            surface = bands[0]
        else:
            surface = np.sum(np.abs(bands), axis=0)
        im = ax.contourf(kxx, kyy, surface)
        ax.set_aspect("equal")
        if k_label:
            ax.set_xlabel(f"{k_label}$_x$")
            ax.set_ylabel(f"{k_label}$_y$")
        if disp_label:
            label = disp_label if single_band else f"|{disp_label}|"
            fig.colorbar(im, ax=ax, label=label)
        if grid and contour_grid:
            ax.grid(axis=grid)
        if pi_ticks:
            _set_piticks(ax.xaxis, num_ticks=2)
            _set_piticks(ax.yaxis, num_ticks=2)
        if bz is not None:
            draw_lines(ax, bz, color="k")
    else:
        raise NotImplementedError()

    fig.tight_layout()
    if show:
        plt.show()
    return ax
class DispersionPath:
    """Defines a dispersion path between high symmetry (HS) points.

    Examples
    --------
    Define a path using the add-method or preset points. To get the actual
    points the 'build'-method is called:

    >>> path = DispersionPath(dim=3).add([0, 0, 0], 'Gamma').x(a=1.0).cycle()
    >>> vectors = path.build(n_sect=1000)

    Attributes
    ----------
    dim : int
        Dimension of the path. If ``0``, it is taken from the first added point.
    labels : list of str
        Names of the HS points.
    points : list of array_like
        Coordinates of the HS points.
    n_sect : int
        Number of points per section used by the last call to ``build``.
    """

    def __init__(self, dim=0):
        self.dim = dim
        self.labels = list()
        self.points = list()
        self.n_sect = 0

    @classmethod
    def chain_path(cls, a=1.0):
        """Create the one-dimensional path X - Gamma - X.

        Parameters
        ----------
        a : float, optional
            Divisor applied to pi in the coordinates of the X point.

        Returns
        -------
        path : DispersionPath
        """
        return cls(dim=1).x(a).gamma().cycle()

    @classmethod
    def square_path(cls, a=1.0):
        """Create the two-dimensional path Gamma - X - M - Gamma.

        Parameters
        ----------
        a : float, optional
            Divisor applied to pi in the coordinates of the X and M points.

        Returns
        -------
        path : DispersionPath
        """
        return cls(dim=2).gamma().x(a).m(a).cycle()

    @classmethod
    def cubic_path(cls, a=1.0):
        """Create the three-dimensional path Gamma - X - M - Gamma - R.

        Parameters
        ----------
        a : float, optional
            Divisor applied to pi in the coordinates of the X, M and R points.

        Returns
        -------
        path : DispersionPath
        """
        return cls(dim=3).gamma().x(a).m(a).gamma().r(a)

    @property
    def num_points(self):
        """int: Number of HS points in the path"""
        return len(self.points)

    def add(self, point, name=""):
        """Adds a new HS point to the path

        This method returns the instance for easier path definitions.

        Parameters
        ----------
        point: array_like
            The coordinates of the HS point. If the dimension of the point is
            higher than the set dimension the point will be clipped.
        name: str, optional
            Optional name of the point. If not specified the number of the point
            is used.

        Returns
        -------
        self: DispersionPath
        """
        if not name:
            name = str(len(self.points))
        point = np.asarray(point)
        if self.dim:
            point = point[: self.dim]
        else:
            # No dimension configured yet: adopt the one of the first point.
            self.dim = len(point)
        self.points.append(point)
        self.labels.append(name)
        return self

    def add_points(self, points, names=None):
        """Adds multiple HS points to the path

        Parameters
        ----------
        points: array_like
            The coordinates of the HS points.
        names: list of str, optional
            Optional names of the points. If not specified the number of the point
            is used.

        Returns
        -------
        self: DispersionPath
        """
        if names is None:
            names = [""] * len(points)
        for point, name in zip(points, names):
            self.add(point, name)
        return self

    def cycle(self):
        """Adds the first point of the path.

        This method returns the instance for easier path definitions.

        Returns
        -------
        self: DispersionPath
        """
        self.points.append(self.points[0])
        self.labels.append(self.labels[0])
        return self

    def gamma(self):
        r"""DispersionPath: Adds the :math:`\Gamma=(0, 0, 0)` point to the path"""
        return self.add([0, 0, 0], r"$\Gamma$")

    def x(self, a=1.0):
        r"""DispersionPath: Adds the :math:`X=(\pi/a, 0, 0)` point to the path"""
        return self.add([np.pi / a, 0, 0], r"$X$")

    def m(self, a=1.0):
        r"""DispersionPath: Adds the :math:`M=(\pi/a, \pi/a, 0)` point to the path"""
        return self.add([np.pi / a, np.pi / a, 0], r"$M$")

    def r(self, a=1.0):
        r"""DispersionPath: Adds the :math:`R=(\pi/a, \pi/a, \pi/a)` point to the path"""
        return self.add([np.pi / a, np.pi / a, np.pi / a], r"$R$")

    def build(self, n_sect=1000):
        """Builds the vectors defining the path between the set HS points.

        Parameters
        ----------
        n_sect: int, optional
            Number of points between each pair of HS points.

        Returns
        -------
        path: (N, D) np.ndarray
        """
        self.n_sect = n_sect
        # Start from an empty (0, D) array so a path without any edge still
        # yields an empty array of the right width, then join all sections in
        # one concatenation instead of re-copying the result for every edge.
        sections = [np.zeros((0, self.dim))]
        sections.extend(
            np.linspace(p0, p1, num=n_sect) for p0, p1 in chain(self.points)
        )
        return np.concatenate(sections, axis=0)

    def get_ticks(self):
        """Get the positions of the points of the last built path.

        Mainly used for setting ticks in plot.

        Returns
        -------
        ticks: (N) np.ndarray
        labels: (N) list
        """
        return np.arange(self.num_points) * self.n_sect, self.labels

    def edges(self):
        """Constructs the edges of the path."""
        return list(chain(self.points))

    def distances(self):
        """Computes the distances between the edges of the path."""
        return np.array([distance(p0, p1) for p0, p1 in self.edges()])

    def scales(self):
        """Computes the scales of the edges of the path.

        The scales are the edge lengths relative to the length of the first edge.
        """
        dists = self.distances()
        return dists / dists[0]

    def draw(self, ax, color=None, lw=1.0, **kwargs):
        """Draws the edges of the path on the given axes.

        Parameters
        ----------
        ax : matplotlib axes
        color : optional
        lw : float, optional
        **kwargs
            Additional keyword arguments passed on to ``draw_lines``.

        Returns
        -------
        lines
            The return value of ``draw_lines``.
        """
        lines = draw_lines(ax, self.edges(), color=color, lw=lw, **kwargs)
        return lines

    def subplots(self, xlabel="k", ylabel="E(k)", grid="both"):
        """Creates an empty matplotlib plot with configured axes for the path.

        Parameters
        ----------
        xlabel: str, optional
        ylabel: str, optional
        grid: str, optional

        Returns
        -------
        fig: plt.Figure
        ax: plt.Axis
        """
        ticks, labels = self.get_ticks()
        return bandpath_subplots(ticks, labels, xlabel, ylabel, grid)

    def plot_dispersion(self, disp, ax=None, show=True, **kwargs):
        """Plots a dispersion along this path.

        The labels and edge scales of the path are passed to the module-level
        ``plot_dispersion`` function together with any extra keyword arguments.

        Parameters
        ----------
        disp : np.ndarray
        ax : matplotlib axes, optional
        show : bool, optional
        **kwargs
            Additional keyword arguments for ``plot_dispersion``.

        Returns
        -------
        ax : matplotlib axes
        """
        scales = self.scales()
        return plot_dispersion(
            disp, self.labels, scales=scales, ax=ax, show=show, **kwargs
        )

    def plot_disp_dos(self, disp, dos, axs=None, show=True, **kwargs):
        """Plots a dispersion along this path next to a density of states.

        The labels and edge scales of the path are passed to the module-level
        ``plot_disp_dos`` function together with any extra keyword arguments.

        Parameters
        ----------
        disp : np.ndarray
        dos : iterable
        axs : pair of matplotlib axes, optional
        show : bool, optional
        **kwargs
            Additional keyword arguments for ``plot_disp_dos``.

        Returns
        -------
        axs : pair of matplotlib axes
        """
        scales = self.scales()
        return plot_disp_dos(
            disp, dos, self.labels, scales=scales, axs=axs, show=show, **kwargs
        )