# coding: utf-8
#
# This code is part of lattpy.
#
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
# be included in all copies or substantial portions of the Software.
"""Contains plotting tools for the lattice and other related objects."""
import itertools
import numpy as np
from collections.abc import Iterable
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import matplotlib.style as mpl_style
from matplotlib.lines import Line2D
from matplotlib.collections import LineCollection, Collection
from mpl_toolkits.mplot3d.art3d import Line3DCollection, Line3D, Poly3DCollection
from matplotlib.artist import allow_rasterization
from matplotlib import path, transforms
from typing import List
import colorcet as cc
# Public API of this module: the names exported by a star-import.
# Helpers defined in this file but not listed here are not part of it.
__all__ = [
"subplot",
"draw_line",
"draw_lines",
"hide_box",
"draw_arrows",
"draw_vectors",
"draw_points",
"draw_indices",
"draw_unit_cell",
"draw_surfaces",
"interpolate_to_grid",
"draw_sites",
"connection_color_array",
]
# Golden ratio conjugate, (sqrt(5) - 1) / 2 ~ 0.618, intended as the standard
# aspect ratio for plot figures.
GOLDEN_RATIO = (np.sqrt(5) - 1.0) / 2.0
# ======================================================================================
# Formatting / Styling
# ======================================================================================
def set_color_cycler(color_cycle=cc.glasbey_category10):
    """Install a color sequence as the pyplot property cycle.

    Parameters
    ----------
    color_cycle : Sequence
        The colors the prop cycler iterates over. Defaults to
        ``cc.glasbey_category10``.
    """
    cycler = plt.cycler("color", color_cycle)
    plt.rcParams["axes.prop_cycle"] = cycler
def use_mplstyle(style, color_cycle=None):
    """Apply a matplotlib style and optionally replace the color cycle.

    Parameters
    ----------
    style : str or dict or Path or Iterable
        The style specification, passed on to ``matplotlib.style.use``.
    color_cycle : Sequence, optional
        The colors for the prop cycler. If omitted, the color cycle defined
        by the style is left untouched.
    """
    mpl_style.use(style)
    if color_cycle is None:
        return
    plt.rcParams["axes.prop_cycle"] = plt.cycler("color", color_cycle)
def set_equal_aspect(ax=None, adjustable="box"):
    """Give a 2D plot an equal aspect ratio.

    Parameters
    ----------
    ax : Axes, optional
        The axes to format. The current axes is used if nothing is passed.
    adjustable : None or {'box', 'datalim'}, optional
        Which quantity is adjusted to reach the equal aspect ratio: 'box'
        changes the physical dimensions of the Axes, 'datalim' changes the
        x or y data limits. Passed on to ``Axes.set_aspect``.

    Notes
    -----
    Axes whose ``name`` is ``"3d"`` are skipped: nothing is changed for them.
    """
    axes = plt.gca() if ax is None else ax
    if axes.name != "3d":
        axes.set_aspect("equal", adjustable)
def hide_box(ax, axis=False):
    """Hide the top/right frame of a plot and, optionally, the axis itself.

    Parameters
    ----------
    ax : Axes
        The axes to format. Axes whose ``name`` is ``"3d"`` are left as is.
    axis : bool, optional
        If True, the remaining spines, the ticks and the axis labels are
        hidden as well.
    """
    if ax.name == "3d":
        return

    spines = ax.spines
    spines["top"].set_visible(False)
    spines["right"].set_visible(False)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    if not axis:
        return

    # Also strip everything that belongs to the axis itself
    spines["left"].set_visible(False)
    spines["bottom"].set_visible(False)
    ax.xaxis.set_ticks_position("none")
    ax.yaxis.set_ticks_position("none")
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])
# ======================================================================================
# General Plotting
# ======================================================================================
def subplot(dim, adjustable="box", ax=None):
    """Create (or format) a 2D or 3D subplot with an equal aspect ratio.

    Parameters
    ----------
    dim : int
        The dimension of the plot. A 3D projection is used for ``dim == 3``.
    adjustable : None or {'box', 'datalim'}, optional
        Which quantity is adjusted to reach the equal aspect ratio, see
        ``set_equal_aspect``. Only applied to 2D plots.
    ax : Axes, optional
        An existing axes to format. If given, no new figure is created.

    Returns
    -------
    fig : Figure
        The figure the axes belongs to.
    ax : Axes
        The newly created or the formatted axes.

    Raises
    ------
    ValueError
        If ``dim`` is larger than 3.
    """
    if dim > 3:
        raise ValueError(f"Plotting in {dim} dimensions is not supported!")

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure()
        projection = "3d" if dim == 3 else None
        ax = fig.add_subplot(111, projection=projection)

    set_equal_aspect(ax, adjustable)
    return fig, ax
# noinspection PyShadowingNames
def draw_line(ax, points, **kwargs):
    """Draw a line connecting multiple points.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the line.
    points : (N, D) array_like
        The points the line passes through. One-dimensional points are
        drawn on the x-axis (``y = 0``).
    **kwargs
        Additional keyword arguments for drawing the line.

    Returns
    -------
    line : Line2D or Line3D
        The created line.

    Raises
    ------
    ValueError
        If the points have more than three components.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> points = np.array([[1, 0], [0.7, 0.7], [0, 1], [-0.7, 0.7], [-1, 0]])
    >>> _ = plotting.draw_line(ax, points)
    >>> ax.margins(0.1, 0.1)
    >>> plt.show()
    """
    # Accept any array_like, not only ndarrays (``.T`` is needed below)
    points = np.atleast_2d(points)
    dim = points.shape[1]

    # Fix 1D case: Line2D needs both x- and y-data
    if dim == 1:
        points = np.hstack((points, np.zeros((points.shape[0], 1))))

    if dim < 3:
        line = Line2D(*points.T, **kwargs)
    elif dim == 3:
        line = Line3D(*points.T, **kwargs)
    else:
        raise ValueError(f"Can't draw line with dimension {dim}")

    ax.add_line(line)
    return line
# noinspection PyShadowingNames
def draw_lines(ax, segments, **kwargs):
    """Draw several independent line segments as one collection.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the lines.
    segments : array_like of (2, D) np.ndarray
        The point pairs that are connected by a line each.
    **kwargs
        Additional keyword arguments for drawing the lines.

    Returns
    -------
    coll : LineCollection or Line3DCollection
        The created line collection.

    Raises
    ------
    ValueError
        If the points have more than three components.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> segments = np.array([
    ...     [[0, 0], [1, 0]],
    ...     [[0, 1], [1, 1]],
    ...     [[0, 2], [1, 2]]
    ... ])
    >>> _ = plotting.draw_lines(ax, segments)
    >>> ax.margins(0.1, 0.1)
    >>> plt.show()
    """
    # The dimension is read off the first point of the first segment
    dim = len(segments[0][0])
    if dim > 3:
        raise ValueError(f"Can't draw lines with dimension {dim}")

    collection_cls = Line3DCollection if dim == 3 else LineCollection
    coll = collection_cls(segments, **kwargs)
    ax.add_collection(coll)
    return coll
# noinspection PyShadowingNames
def draw_vectors(ax, vectors, pos=None, **kwargs):
    """Draw lines from a common starting point in the given directions.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the lines.
    vectors : (N, D) or (D, ) array_like
        The vectors to draw. A single vector may be passed as a flat array.
    pos : (D, ) array_like, optional
        The starting position of the vectors. The default is the origin.
    **kwargs
        Additional keyword arguments for drawing the lines.

    Returns
    -------
    coll : LineCollection or Line3DCollection
        The created line collection.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> vectors = np.array([[1, 0], [0.7, 0.7], [0, 1], [-0.7, 0.7], [-1, 0]])
    >>> _ = plotting.draw_vectors(ax, vectors, [1, 0])
    >>> ax.margins(0.1, 0.1)
    >>> plt.show()
    """
    # Normalize the vectors first: the default position depends on their
    # dimension, which is only well-defined once the array is 2D.
    vectors = np.atleast_2d(vectors)
    dim = vectors.shape[1]
    pos = np.zeros(dim) if pos is None else np.atleast_1d(pos)

    # Fix 1D case: draw on the x-axis
    if dim == 1:
        vectors = np.hstack((vectors, np.zeros((vectors.shape[0], 1))))
        pos = np.array([pos[0], 0])

    # Every segment starts at `pos` and ends at `pos + v`
    segments = [[pos, pos + v] for v in vectors]
    return draw_lines(ax, segments, **kwargs)
# noinspection PyShadowingNames
def draw_arrows(ax, vectors, pos=None, **kwargs):
    """Draw arrows from one or multiple starting points as a quiver plot.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the arrows.
    vectors : (N, D) np.ndarray
        The vectors to draw.
    pos : (D, ) or (N, D) np.ndarray, optional
        The starting position(s) of the vectors. A single position is
        repeated for every vector. The default is the origin.
    **kwargs
        Additional keyword arguments, passed on to ``ax.quiver``.

    Returns
    -------
    quiver
        The object returned by ``ax.quiver``.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> vectors = np.array([[1, 0], [0.7, 0.7], [0, 1], [-0.7, 0.7], [-1, 0]])
    >>> _ = plotting.draw_arrows(ax, vectors)
    >>> ax.margins(0.1, 0.1)
    >>> plt.show()
    """
    vectors = np.atleast_2d(vectors)
    num_vecs, dim = vectors.shape

    if pos is not None:
        pos = np.atleast_2d(pos)
        if pos.shape[0] == 1:
            # One shared starting point: repeat it for every vector
            pos = np.tile(pos, (num_vecs, 1))
    else:
        pos = np.zeros((num_vecs, dim))
    assert len(pos) == len(vectors)

    # Component-wise layout (D, N), ready for unpacking into the plot calls
    starts, deltas, tips = pos.T, vectors.T, (pos + vectors).T
    if dim == 1:
        # Add a zero y-component to all three arrays
        starts, deltas, tips = (
            np.append(arr, np.zeros_like(arr), axis=0)
            for arr in (starts, deltas, tips)
        )

    # Invisible markers at the arrow tips (zero marker size)
    ax.scatter(*tips, s=0)

    # Draw arrows as quiver plot
    if dim == 3:
        kwargs.update({"normalize": False})
    else:
        kwargs.update({"angles": "xy", "scale_units": "xy", "scale": 1})
    return ax.quiver(*starts, *deltas, **kwargs)
# noinspection PyShadowingNames
def draw_points(ax, points, size=10, **kwargs):
    """Draw points as a scatter plot.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the points.
    points : (N, D) np.ndarray
        The positions of the points. One-dimensional points are drawn on
        the x-axis.
    size : float, optional
        The size of the markers. The square of this value is passed to
        ``ax.scatter`` as ``s``.
    **kwargs
        Additional keyword arguments for drawing the points.

    Returns
    -------
    scat : PathCollection
        The scatter plot item.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> points = np.array([[1, 0], [0.7, 0.7], [0, 1], [-0.7, 0.7], [-1, 0]])
    >>> _ = plotting.draw_points(ax, points)
    >>> ax.margins(0.1, 0.1)
    >>> plt.show()
    """
    coords = np.atleast_2d(points)
    if coords.shape[1] == 1:
        # 1D input: append a zero y-column
        y_column = np.zeros((coords.shape[0], 1))
        coords = np.hstack((coords, y_column))

    scat = ax.scatter(*coords.T, s=size**2, **kwargs)
    # Push the data limits of the scatter item to the axes explicitly
    ax.update_datalim(scat.get_datalim(ax.transData))
    return scat
# noinspection PyShadowingNames
def draw_surfaces(ax, vertices, **kwargs):
    """Draw one or multiple 3D surfaces defined by their vertices.

    Parameters
    ----------
    ax : Axes3D
        The axes for drawing the surfaces.
    vertices : array_like
        Either the vertices of a single surface or a sequence of such
        vertex lists. A single surface is detected by ``vertices[0][0]``
        not being iterable.
    **kwargs
        Additional keyword arguments passed to ``Poly3DCollection``.

    Returns
    -------
    surf : Poly3DCollection
        The surface object.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import matplotlib.pyplot as plt
    >>> vertices = [[0, 0, 0], [1, 1, 0], [0.5, 0.5, 1]]
    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(111, projection="3d")
    >>> _ = plotting.draw_surfaces(ax, vertices, alpha=0.5)
    >>> plt.show()
    """
    first_entry = vertices[0][0]
    is_single_surface = not isinstance(first_entry, Iterable)
    # The collection expects a list of polygons, so wrap a lone polygon
    polygons = [list(vertices)] if is_single_surface else vertices

    poly = Poly3DCollection(polygons, **kwargs)
    ax.add_collection3d(poly)
    return poly
# noinspection PyShadowingNames
def text(ax, strings, positions, offset=None, **kwargs):
    """Add multiple strings to a plot.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the text.
    strings : str or sequence of str
        The text to render. A single string passed together with a single
        position is rendered as a whole.
    positions : (..., D) array_like
        The positions of the texts. One-dimensional positions are placed on
        the x-axis (``y = 0``).
    offset : float or (D, ) array_like, optional
        The offset added to every text position. A scalar is applied to all
        components. The default is no offset.
    **kwargs
        Additional keyword arguments for drawing the text.

    Returns
    -------
    texts : list
        The text items.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> points = np.array([[-1, 0], [-0.7, 0.7], [0, 1], [0.7, 0.7], [1, 0]])
    >>> strings = ["A", "B", "C", "D", "E"]
    >>> _ = plotting.text(ax, strings, points)
    >>> _ = ax.set_xlim(-1.5, +1.5)
    >>> _ = ax.set_ylim(-0.5, +1.5)
    >>> plt.show()
    """
    positions = np.atleast_2d(positions)

    # A lone string for a lone position must not be split into characters
    if isinstance(strings, str) and len(positions) == 1:
        strings = [strings]

    # Text is always placed in at least two dimensions
    num_dim = max(2, positions.shape[1])
    if offset is None:
        offset = np.zeros(num_dim)
    elif np.ndim(offset) == 0:
        # Any scalar (int or float) is applied to every component
        offset = offset * np.ones(num_dim)

    texts = list()
    for s, pos in zip(strings, positions):
        if len(pos) == 1:
            # 1D position: use the scalar coordinate, not the length-1 array
            pos = [pos[0], 0]
        tpos = np.asarray(pos) + offset
        texts.append(ax.text(*tpos, s=s, **kwargs))
    return texts
# ======================================================================================
# Lattice plotting
# ======================================================================================
# noinspection PyAbstractClass
class CircleCollection(Collection):
    """Collection of circles whose radius is given in data units.

    The default matplotlib `CircleCollection` creates circles based on their
    area in screen units. This class scales a unit circle by the radius in
    data units instead, which makes it behave like a much faster version of
    a `PatchCollection` of `Circle`.

    The implementation is similar to `EllipseCollection`.
    """

    def __init__(self, radius, **kwargs):
        super().__init__(**kwargs)
        # Radii are stored as an array with at least one dimension
        self.radius = np.atleast_1d(radius)
        # All circles share one unit-circle path; the per-circle transforms
        # are filled in right before drawing
        self._paths = [path.Path.unit_circle()]
        self.set_transform(transforms.IdentityTransform())
        self._transforms = np.empty((0, 3, 3))

    def _set_transforms(self):
        """Recompute the scaling matrices from the current axes geometry."""
        axes = self.axes
        matrices = np.zeros((self.radius.size, 3, 3))
        # Radius scaled by the ratio of axes bbox size to view-limit size
        matrices[:, 0, 0] = self.radius * axes.bbox.width / axes.viewLim.width
        matrices[:, 1, 1] = self.radius * axes.bbox.height / axes.viewLim.height
        matrices[:, 2, 2] = 1
        self._transforms = matrices

    @allow_rasterization
    def draw(self, renderer):
        # Refresh the transforms on every draw before delegating
        self._set_transforms()
        super().draw(renderer)
# noinspection PyShadowingNames
def draw_sites(ax, points, radius=0.2, **kwargs):
    """Draw circles with a radius given in data units.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the points.
    points : (N, D) np.ndarray
        The positions of the points. One-dimensional points are drawn on
        the x-axis.
    radius : float
        The radius of the points. Scaling is only supported for 2D plots!
        Three-dimensional points are drawn as a regular scatter plot with a
        marker size derived from the radius.
    **kwargs
        Additional keyword arguments for drawing the points.

    Returns
    -------
    point_coll : CircleCollection or PathCollection
        The circle or path collection.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> points = np.array([[1, 0], [0.7, 0.7], [0, 1], [-0.7, 0.7], [-1, 0]])
    >>> _ = plotting.draw_sites(ax, points, radius=0.2)
    >>> _ = ax.set_xlim(-1.5, +1.5)
    >>> _ = ax.set_ylim(-0.5, +1.5)
    >>> plotting.set_equal_aspect(ax)
    >>> plt.show()
    """
    points = np.atleast_2d(points)
    if points.shape[1] == 1:
        # 1D input: append a zero y-column
        points = np.hstack((points, np.zeros((points.shape[0], 1))))

    if points.shape[1] >= 3:
        # 3D (or higher): plain scatter plot, no radius scaling
        size = radius * 50
        scat = ax.scatter(*points.T, s=size**2, **kwargs)
        ax.update_datalim(scat.get_datalim(ax.transData))
        return scat

    col = CircleCollection(
        radius, offsets=points, transOffset=ax.transData, **kwargs
    )
    ax.add_collection(col)

    label = kwargs.get("label", "")
    if label:
        # Empty line plot with a round marker carrying the label
        ax.plot(
            [],
            [],
            marker="o",
            lw=0,
            color=kwargs.get("color"),
            label=label,
            markersize=10,
        )

    # Grow the data limits by the radius on every side
    datalim = col.get_datalim(ax.transData)
    datalim.x0 -= radius
    datalim.x1 += radius
    datalim.y0 -= radius
    datalim.y1 += radius
    ax.update_datalim(datalim)
    return col
def connection_color_array(num_base, default="k", colors=None) -> List[List]:
    """Build the symmetric color table for the connections between atoms.

    Parameters
    ----------
    num_base : int
        The number of atoms in the unit cell of a lattice.
    default : str or int or float or tuple
        The color used for every connection that is not overridden.
    colors : Sequence[tuple], optional
        Overrides of the default connection color. Each element is a tuple
        of the two atom indices of the pair followed by the color, for
        example ``[(0, 0, 'r')]``. Each override is applied to both
        ``[a1][a2]`` and ``[a2][a1]``.

    Returns
    -------
    color_array : List of List
        The ``num_base x num_base`` connection color array.

    Examples
    --------
    >>> connection_color_array(2, "k", colors=[(0, 1, "r")])
    [['k', 'r'], ['r', 'k']]
    """
    # One independent row per atom, pre-filled with the default color
    hop_colors = [[default] * num_base for _ in range(num_base)]
    for a1, a2, col in (colors or ()):
        hop_colors[a1][a2] = col
        hop_colors[a2][a1] = col
    return hop_colors
# noinspection PyShadowingNames
def draw_indices(ax, positions, offset=0.05, **kwargs):
    """Label every position with its index in the ``positions`` sequence.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the text.
    positions : (..., D) array_like
        The positions of the texts.
    offset : float or (D, ) array_like
        The offset of the positions of the texts.
    **kwargs
        Additional keyword arguments for drawing the text.

    Returns
    -------
    texts : list
        The text items.

    Examples
    --------
    >>> from lattpy import plotting
    >>> import matplotlib.pyplot as plt
    >>> points = np.array([[-1, 0], [-0.7, 0.7], [0, 1], [0.7, 0.7], [1, 0]])
    >>> fig, ax = plt.subplots()
    >>> _ = plotting.draw_points(ax, points)
    >>> _ = plotting.draw_indices(ax, points)
    >>> ax.margins(0.1, 0.1)
    >>> plt.show()
    """
    labels = list(map(str, range(len(positions))))
    # Labels are anchored with ha="left", va="bottom"
    return text(ax, labels, positions, offset, ha="left", va="bottom", **kwargs)
# noinspection PyShadowingNames
def draw_unit_cell(ax, vectors, outlines=True, **kwargs):
    """Draw the basis vectors and, optionally, the outline of the unit cell.

    Parameters
    ----------
    ax : Axes
        The axes for drawing the unit cell.
    vectors : float or (D, D) array_like
        The vectors defining the basis.
    outlines : bool, optional
        If True the box defined by the basis vectors (unit cell) is drawn.
    **kwargs
        Additional keyword arguments, passed both to the arrows and to the
        outline lines. ``color`` defaults to ``"k"``.

    Returns
    -------
    arrows
        The object returned by ``draw_arrows`` for the basis vectors.
    lines : list
        The plotted outline lines (empty if no outline is drawn).

    Examples
    --------
    >>> from lattpy import plotting
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> vectors = np.array([[1, 0], [0, 1]])
    >>> fig, ax = plt.subplots()
    >>> _ = plotting.draw_unit_cell(ax, vectors)
    >>> plt.show()
    """
    # Convert up front: with plain lists `pos + v` below would concatenate
    # instead of adding, and a scalar has no length.
    vectors = np.atleast_2d(vectors)
    dim = len(vectors)
    color = kwargs.pop("color", "k")

    arrows = draw_arrows(ax, vectors, color=color, **kwargs)

    lines = list()
    if outlines and dim > 1:
        # Edges that start at the tip of one basis vector
        for v, pos in itertools.permutations(vectors, r=2):
            data = np.asarray([pos, pos + v]).T
            line = ax.plot(*data, color=color, **kwargs)[0]
            lines.append(line)
        if dim == 3:
            # Edges that start at the sum of two basis vectors
            for vecs in itertools.permutations(vectors, r=3):
                v, pos = vecs[0], np.sum(vecs[1:], axis=0)
                data = np.asarray([pos, pos + v]).T
                line = ax.plot(*data, color=color, **kwargs)[0]
                lines.append(line)
    return arrows, lines
def interpolate_to_grid(
    positions,
    values,
    num=(100, 100),
    offset=(0.0, 0.0),
    method="linear",
    fill_value=np.nan,
):
    """Interpolate scattered 2D data onto a regular rectangular grid.

    Parameters
    ----------
    positions : (N, 2) array_like
        The x- and y-coordinates of the data points.
    values : (N, ) array_like
        The data values at the positions.
    num : tuple of int, optional
        The number of grid points in x- and y-direction.
    offset : tuple of float, optional
        The margin added on both sides of the data range in x- and
        y-direction before the grid is created.
    method : str, optional
        The interpolation method, passed on to ``scipy.interpolate.griddata``.
    fill_value : float, optional
        The fill value, passed on to ``scipy.interpolate.griddata``.

    Returns
    -------
    xx : (num[1], num[0]) np.ndarray
        The x-coordinates of the grid points.
    yy : (num[1], num[0]) np.ndarray
        The y-coordinates of the grid points.
    zz : (num[1], num[0]) np.ndarray
        The interpolated values on the grid.
    """
    # Accept any array_like, not only ndarrays
    positions = np.asarray(positions)
    values = np.asarray(values)
    x, y = positions.T

    # Create regular grid
    xi = np.linspace(x.min() - offset[0], x.max() + offset[0], num[0])
    yi = np.linspace(y.min() - offset[1], y.max() + offset[1], num[1])
    xx, yy = np.meshgrid(xi, yi)

    # Interpolate data to grid
    zz = griddata((x, y), values, (xx, yy), method=method, fill_value=fill_value)
    return xx, yy, zz