# coding: utf-8
#
# This code is part of lattpy.
#
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
# be included in all copies or substantial portions of the Software.
"""Basis object for defining the coordinate system and unit cell of a lattice."""
import itertools
import logging
from typing import Union, Sequence, Callable, Tuple

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D

from .spatial import cell_size, cell_volume, WignerSeitzCell
from .plotting import subplot, draw_unit_cell
# Logger named after this module.
logger = logging.getLogger(__name__)

# Input accepted as "vectors": a scalar, a single vector or a sequence of vectors.
vecs_t = Union[float, Sequence[float], Sequence[Sequence[float]]]

# Input accepted as a basis: everything ``vecs_t`` allows plus an existing basis
# object. The class is referenced by name since it is defined further below;
# ``Union`` flattens the nested alias into a single four-member union.
basis_t = Union[vecs_t, "LatticeBasis"]

__all__ = ["basis_t", "LatticeBasis"]
class LatticeBasis:
    """Lattice basis for representing the coordinate system and unit cell of a lattice.

    The ``LatticeBasis`` object is the core of any lattice model. It defines the
    basis vectors and subsequently the coordinate system of the lattice and provides
    the necessary basis transformations between the world and lattice coordinate
    system.

    Attributes
    ----------
    dim
    vectors
    vectors3d
    norms
    cell_size
    cell_volume

    Methods
    -------
    chain
    square
    rectangular
    oblique
    hexagonal
    hexagonal3d
    sc
    fcc
    bcc
    hypercubic
    itransform
    transform
    itranslate
    translate
    is_reciprocal
    reciprocal_vectors
    reciprocal_lattice
    get_neighbor_cells
    wigner_seitz_cell
    brillouin_zone
    get_cell_superindex
    get_cell_index
    plot_basis

    Parameters
    ----------
    basis: array_like or float or LatticeBasis
        The primitive basis vectors that define the unit cell of the lattice. If a
        ``LatticeBasis`` instance is passed it is copied and used as the new basis.
    **kwargs
        Key-word arguments. Used only when subclassing ``LatticeBasis``.

    Examples
    --------
    >>> import lattpy as lp
    >>> import matplotlib.pyplot as plt
    >>> basis = lp.LatticeBasis.square()
    >>> _ = basis.plot_basis()
    >>> plt.show()
    """

    # Tolerance for reciprocal vectors/lattice
    RVEC_TOLERANCE: float = 1e-6

    # noinspection PyUnusedLocal
    def __init__(self, basis: "basis_t", **kwargs):
        if isinstance(basis, LatticeBasis):
            basis = basis.vectors
        # ``np.array`` copies the input, so later in-place changes of the caller's
        # array (or of the basis that was passed in) cannot alter this instance.
        # The basis vectors are stored as the *columns* of ``_vectors``.
        self._vectors = np.atleast_2d(np.array(basis)).T
        self._vectors_inv = np.linalg.inv(self._vectors)
        self._dim = len(self._vectors)
        self._cell_size = cell_size(self.vectors)
        self._cell_volume = cell_volume(self.vectors)

    @classmethod
    def chain(cls, a: float = 1.0, **kwargs):
        """Initializes a one-dimensional lattice."""
        return cls(a, **kwargs)

    @classmethod
    def square(cls, a: float = 1.0, **kwargs):
        """Initializes a 2D lattice with square basis vectors."""
        return cls(a * np.eye(2), **kwargs)

    @classmethod
    def rectangular(cls, a1: float = 1.0, a2: float = 1.0, **kwargs):
        """Initializes a 2D lattice with rectangular basis vectors."""
        return cls(np.array([[a1, 0], [0, a2]]), **kwargs)

    @classmethod
    def oblique(cls, alpha: float, a1: float = 1.0, a2: float = 1.0, **kwargs):
        """Initializes a 2D lattice with oblique basis vectors.

        ``alpha`` is passed to ``np.cos``/``np.sin`` and is therefore in radians.
        """
        vectors = np.array([[a1, 0], [a2 * np.cos(alpha), a2 * np.sin(alpha)]])
        return cls(vectors, **kwargs)

    @classmethod
    def hexagonal(cls, a: float = 1.0, **kwargs):
        """Initializes a 2D lattice with hexagonal basis vectors."""
        vectors = a / 2 * np.array([[3, np.sqrt(3)], [3, -np.sqrt(3)]])
        return cls(vectors, **kwargs)

    @classmethod
    def hexagonal3d(cls, a: float = 1.0, az: float = 1.0, **kwargs):
        """Initializes a 3D lattice with hexagonal basis vectors."""
        # NOTE(review): the prefactor ``a / 2`` also multiplies the third vector,
        # so its length is ``a * az / 2`` rather than ``az`` - confirm intended.
        vectors = (
            a / 2 * np.array([[3, np.sqrt(3), 0], [3, -np.sqrt(3), 0], [0, 0, az]])
        )
        return cls(vectors, **kwargs)

    @classmethod
    def sc(cls, a: float = 1.0, **kwargs):
        """Initializes a 3D simple cubic lattice."""
        return cls(a * np.eye(3), **kwargs)

    @classmethod
    def fcc(cls, a: float = 1.0, **kwargs):
        """Initializes a 3D face centered cubic lattice."""
        vectors = a / 2 * np.array([[1, 1, 0], [1, 0, 1], [0, 1, 1]])
        return cls(vectors, **kwargs)

    @classmethod
    def bcc(cls, a: float = 1.0, **kwargs):
        """Initializes a 3D body centered cubic lattice."""
        vectors = a / 2 * np.array([[1, 1, 1], [1, -1, 1], [-1, 1, 1]])
        return cls(vectors, **kwargs)

    @classmethod
    def hypercubic(cls, dim, a: float = 1.0, **kwargs):
        """Creates a d-dimensional cubic lattice."""
        vectors = a * np.eye(dim)
        return cls(vectors, **kwargs)

    # ==================================================================================

    @property
    def dim(self) -> int:
        """int: The dimension of the lattice."""
        return self._dim

    @property
    def vectors(self) -> np.ndarray:
        """np.ndarray: Array containing the basis vectors as rows."""
        return self._vectors.T

    @property
    def vectors3d(self) -> np.ndarray:
        """np.ndarray: The basis vectors expanded to three dimensions.

        The basis is embedded into a 3x3 identity matrix, so missing directions
        are filled with unit vectors. Only usable for ``dim <= 3``.
        """
        vectors = np.eye(3)
        vectors[: self.dim, : self.dim] = self._vectors
        return vectors.T

    @property
    def norms(self) -> np.ndarray:
        """np.ndarray: Lengths of the basis vectors."""
        return np.linalg.norm(self._vectors, axis=0)

    @property
    def cell_size(self) -> np.ndarray:
        """np.ndarray: The shape of the box spawned by the basis vectors."""
        return self._cell_size

    @property
    def cell_volume(self) -> float:
        """float: The volume of the unit cell defined by the basis vectors."""
        return self._cell_volume

    def itransform(
        self, world_coords: Union[Sequence[float], Sequence[Sequence[float]]]
    ) -> np.ndarray:
        """Transforms world (cartesian) coordinates into lattice coordinates.

        This is the inverse of `transform` and uses the inverse of the basis
        matrix computed in the constructor.

        Parameters
        ----------
        world_coords : (..., N) array_like
            One or multiple positions in the world coordinate system.

        Returns
        -------
        basis_coords : (..., N) np.ndarray
            The (in general non-integer) coordinates in the lattice basis.

        Examples
        --------
        >>> latt = LatticeBasis([[2, 0], [0, 1]])
        >>> latt.itransform([2, 0])
        [1. 0.]
        """
        world_coords = np.atleast_1d(world_coords)
        return np.inner(world_coords, self._vectors_inv)

    def transform(
        self, lattice_coords: Union[Sequence[float], Sequence[Sequence[float]]]
    ) -> np.ndarray:
        """Transforms lattice coordinates into world (cartesian) coordinates.

        Parameters
        ----------
        lattice_coords : (..., N) array_like
            One or multiple positions in the lattice coordinate system.

        Returns
        -------
        world_coords : (..., N) np.ndarray
            The corresponding positions in the world coordinate system.

        Examples
        --------
        >>> latt = LatticeBasis([[2, 0], [0, 1]])
        >>> latt.transform([1, 0])
        [2 0]
        """
        lattice_coords = np.atleast_1d(lattice_coords)
        return np.inner(lattice_coords, self._vectors)

    def translate(
        self,
        nvec: Union[int, Sequence[int], Sequence[Sequence[int]]],
        r: Union[float, Sequence[float]] = 0.0,
    ) -> np.ndarray:
        r"""Translates the given position vector ``r`` by the translation vector ``n``.

        The position is calculated using the translation vector :math:`n` and the
        atom position in the unitcell :math:`r`:

        .. math::
            R = \sum_i n_i v_i + r

        Parameters
        ----------
        nvec : (..., N) array_like
            Translation vector in the lattice coordinate system.
        r : (N) array_like, optional
            The position in cartesian coordinates. If no vector is passed only
            the translation is returned.

        Returns
        -------
        r_trans : (..., N) np.ndarray
            The translated position.

        Examples
        --------
        Construct a lattice with basis vectors :math:`a_1 = (2, 0)` and
        :math:`a_2 = (0, 1)`:

        >>> latt = LatticeBasis([[2, 0], [0, 1]])

        Translate the origin:

        >>> n = [1, 0]
        >>> latt.translate(n)
        [2. 0.]

        Translate a point:

        >>> p = [0.5, 0.5]
        >>> latt.translate(n, p)
        [2.5 0.5]

        Translate a point by multiple translation vectors:

        >>> p = [0.5, 0.5]
        >>> nvecs = [[0, 0], [1, 0], [2, 0]]
        >>> latt.translate(nvecs, p)
        [[0.5 0.5]
         [2.5 0.5]
         [4.5 0.5]]
        """
        r = np.atleast_1d(r)
        nvec = np.atleast_1d(nvec)
        return r + np.inner(nvec, self._vectors)

    def itranslate(
        self, x: Union[float, Sequence[float]]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Returns the translation vector and atom position of the given position.

        Parameters
        ----------
        x : (..., N) array_like or float
            Position vector in cartesian coordinates.

        Returns
        -------
        nvec : (..., N) np.ndarray
            Translation vector in the lattice basis (integer valued).
        r : (..., N) np.ndarray
            The remaining position relative to the translated cell origin.

        Examples
        --------
        Construct a lattice with basis vectors :math:`a_1 = (2, 0)` and
        :math:`a_2 = (0, 1)`:

        >>> latt = LatticeBasis([[2, 0], [0, 1]])
        >>> latt.itranslate([2, 0])
        (array([1., 0.]), array([0., 0.]))

        >>> latt.itranslate([2.5, 0.5])
        (array([1., 0.]), array([0.5, 0.5]))
        """
        x = np.atleast_1d(x)
        itrans = self.itransform(x)
        # The integer part of the lattice coordinates is the cell index ...
        nvec = np.floor(itrans).astype(np.int64)
        # ... and whatever is left over is the position inside that cell.
        r = x - self.translate(nvec)
        return nvec, r

    def is_reciprocal(self, vecs: "vecs_t", tol: float = RVEC_TOLERANCE) -> bool:
        r"""Checks if the given vectors are reciprocal to the lattice vectors.

        The lattice- and reciprocal vectors :math:`a_i` and :math:`b_j` must satisfy
        the relation

        .. math::
            a_i \cdot b_j = 2 \pi \delta_{ij}

        To check the given vectors, the dot-product :math:`a_i \cdot b_i` of each
        pair of equally indexed vectors is compared to :math:`2\pi` with the given
        tolerance. The products of differently indexed vectors are not checked.

        Parameters
        ----------
        vecs : array_like or float
            The vectors to check. Must have the same dimension as the lattice.
        tol : float, optional
            The tolerance used for checking the result of the dot-products.

        Returns
        -------
        is_reciprocal : bool
            Flag if the vectors are reciprocal to the lattice basis vectors.
        """
        vecs = np.atleast_2d(vecs)
        two_pi = 2 * np.pi
        for a, b in zip(self.vectors, vecs):
            if abs(np.dot(a, b) - two_pi) > tol:
                return False
        return True

    def reciprocal_vectors(
        self, tol: float = RVEC_TOLERANCE, check: bool = False
    ) -> np.ndarray:
        r"""Computes the reciprocal basis vectors of the bravais lattice.

        The lattice- and reciprocal vectors :math:`a_i` and :math:`b_j` must satisfy
        the relation

        .. math::
            a_i \cdot b_j = 2 \pi \delta_{ij}

        Parameters
        ----------
        tol : float, optional
            The tolerance used for checking the result of the dot-products.
        check : bool, optional
            Check the result and raise an exception if it does not satisfy
            the definition.

        Returns
        -------
        v_rec : np.ndarray
            The reciprocal basis vectors of the lattice.

        Raises
        ------
        ValueError
            If ``check`` is True and a dot-product :math:`a_i \cdot b_i` differs
            from :math:`2\pi` by more than ``tol``.

        Examples
        --------
        Reciprocal vectors of the square lattice:

        >>> latt = LatticeBasis(np.eye(2))
        >>> latt.reciprocal_vectors()
        [[6.28318531 0.        ]
         [0.         6.28318531]]
        """
        two_pi = 2 * np.pi
        if self.dim == 1:
            return np.array([[two_pi / self.vectors[0, 0]]])

        if self.dim > 3:
            # No cross product beyond 3D. With the basis vectors as rows of ``A``
            # the defining relation reads ``A @ B.T = 2 pi * 1``, i.e.
            # ``B = 2 pi * inv(A).T``, and ``_vectors_inv`` is exactly ``inv(A).T``.
            rvecs = two_pi * self._vectors_inv
        else:
            # Convert basis vectors of the bravais lattice to 3D, compute
            # reciprocal vectors and convert back to actual dimension.
            a1, a2, a3 = self.vectors3d
            factor = 2 * np.pi / self.cell_volume
            b1 = np.cross(a2, a3)
            b2 = np.cross(a3, a1)
            b3 = np.cross(a1, a2)
            rvecs = factor * np.asarray([b1, b2, b3])
            rvecs = rvecs[: self.dim, : self.dim]

        # Fix the sign so that the dot-products are all positive
        # and raise an exception if anything went wrong
        vecs = self.vectors
        for i in range(self.dim):
            dot = np.dot(vecs[i], rvecs[i])
            # Check if dot product is - 2 pi
            if abs(dot + two_pi) <= tol:
                rvecs[i] *= -1
            # Raise an exception if checks are enabled and
            # dot product results in anything other than +2 pi
            elif check and abs(dot - two_pi) > tol:
                raise ValueError(f"{rvecs[i]} not a reciprocal vector to {vecs[i]}")

        return rvecs

    # noinspection PyShadowingNames
    def reciprocal_lattice(self, min_negative: bool = False):
        """Creates the lattice in reciprocal space.

        Parameters
        ----------
        min_negative : bool, optional
            Accepted for backwards compatibility. The flag has no effect on the
            result: the reciprocal vectors are always computed with the default
            tolerance of `reciprocal_vectors`.

        Returns
        -------
        rlatt : LatticeBasis
            The lattice in reciprocal space

        See Also
        --------
        reciprocal_vectors : Constructs the reciprocal vectors used for the
            reciprocal lattice

        Examples
        --------
        Reciprocal lattice of the square lattice:

        >>> latt = LatticeBasis(np.eye(2))
        >>> rlatt = latt.reciprocal_lattice()
        >>> rlatt.vectors
        [[6.28318531 0.        ]
         [0.         6.28318531]]
        """
        # ``min_negative`` must not be forwarded: the first positional parameter
        # of ``reciprocal_vectors`` is the tolerance ``tol``, not a flag.
        rvecs = self.reciprocal_vectors()
        rlatt = self.__class__(rvecs)
        return rlatt

    def get_neighbor_cells(
        self,
        distidx: int = 0,
        include_origin: bool = True,
        comparison: Callable = np.isclose,
    ) -> np.ndarray:
        """Find all neighboring unit cells of the unit cell at the origin.

        Parameters
        ----------
        distidx : int, optional
            Index of distance to neighboring cells, default is 0 (nearest neighbors).
        include_origin : bool, optional
            If True the origin is included in the set.
        comparison : callable, optional
            The method used for comparing distances. It is called as
            ``comparison(distances, maxdist)`` and must return a boolean mask.

        Returns
        -------
        indices : np.ndarray
            The integer lattice indices of the neighboring unit cells.

        Examples
        --------
        >>> latt = LatticeBasis(np.eye(2))
        >>> latt.get_neighbor_cells(distidx=0, include_origin=False)
        [[-1  0]
         [ 0 -1]
         [ 0  1]
         [ 1  0]]
        """
        # Build cell points
        max_factor = distidx + 1
        axis_factors = np.arange(-max_factor, max_factor + 1)
        factors = np.array(list(itertools.product(axis_factors, repeat=self.dim)))
        # The dot product with the extra leading axis yields shape (M, 1, D);
        # dropping that axis leaves the cartesian positions of all candidate cells.
        points = np.dot(factors, self.vectors[np.newaxis, :, :])[:, 0, :]

        # Compute distances to origin for all points
        distances = np.linalg.norm(points, axis=1)

        # Set maximum distances value to number of neighbors
        # + number of unique vector lengths
        max_distidx = distidx + len(np.unique(np.linalg.norm(self.vectors, axis=1)))

        # Select the points whose distance matches the reference distance
        # according to ``comparison``
        maxdist = np.sort(np.unique(distances))[max_distidx]
        indices = np.where(comparison(distances, maxdist))[0]
        factors = factors[indices]

        # Use the dtype of the indices for the origin so that prepending it
        # does not silently turn the integer cell indices into floats.
        origin = np.zeros(self.dim, dtype=factors.dtype)
        idx = np.where((factors == origin).all(axis=1))[0]
        if include_origin and not len(idx):
            factors = np.append(origin[np.newaxis, :], factors, axis=0)
        elif not include_origin and len(idx):
            factors = np.delete(factors, idx, axis=0)
        return factors

    def wigner_seitz_cell(self) -> "WignerSeitzCell":
        """Computes the Wigner-Seitz cell of the lattice structure.

        Returns
        -------
        ws_cell : WignerSeitzCell
            The Wigner-Seitz cell of the lattice.
        """
        nvecs = self.get_neighbor_cells(include_origin=True)
        # Cartesian positions of the neighboring cells (see get_neighbor_cells)
        positions = np.dot(nvecs, self.vectors[np.newaxis, :, :])[:, 0, :]
        return WignerSeitzCell(positions)

    def brillouin_zone(self, min_negative: bool = False) -> "WignerSeitzCell":
        """Computes the first Brillouin-zone of the lattice structure.

        Constructs the Wigner-Seitz cell of the reciprocal lattice

        Parameters
        ----------
        min_negative : bool, optional
            Accepted for backwards compatibility. The flag has no effect on the
            result: the reciprocal vectors are always computed with the default
            tolerance of `reciprocal_vectors`.

        Returns
        -------
        ws_cell : WignerSeitzCell
            The Wigner-Seitz cell of the reciprocal lattice.
        """
        # ``min_negative`` must not be forwarded: the first positional parameter
        # of ``reciprocal_vectors`` is the tolerance ``tol``, not a flag.
        rvecs = self.reciprocal_vectors()
        rlatt = self.__class__(rvecs)
        return rlatt.wigner_seitz_cell()

    def get_cell_superindex(self, index, shape):
        """Converts a cell index to a super-index for the given shape.

        The index of a unit cell is the translation vector.

        Parameters
        ----------
        index : array_like or (N, D) np.ndarray
            One or multiple cell indices. If a numpy array is passed the result will
            be an array of super-indices, otherwise an integer is returned.
        shape : (D, ) array_like
            The shape for converting the index.

        Returns
        -------
        super_index : int or (N, ) int np.ndarray
            The super index. If `index` is a numpy array, `super_index` is an array
            of super indices, otherwise it is an integer.
        """
        single = not isinstance(index, np.ndarray)
        index = np.atleast_2d(index)
        # Row-major strides: one step along axis ``d`` skips all trailing axes.
        # They only depend on ``shape``, so compute them once for all indices.
        strides = [np.prod(shape[d + 1 :]) for d in range(self.dim)]
        super_indices = np.zeros((len(index)), dtype=np.int64)
        for i, ind in enumerate(index):
            super_indices[i] = sum(ind[d] * strides[d] for d in range(self.dim))
        return super_indices[0] if single else super_indices

    def get_cell_index(self, super_index, shape):
        """Converts a super index to the corresponding cell index for the given shape.

        The index of a unit cell is the translation vector.

        Parameters
        ----------
        super_index : int or (N,) int np.ndarray
            One or multiple super indices.
        shape : (D, ) array_like
            The shape for converting the index.

        Returns
        -------
        index : np.ndarray
            The cell index. Of shape (D,) if `super_index` is an integer, otherwise
            of shape (N, D).
        """
        # NumPy integer scalars (e.g. the result of `get_cell_superindex`) are
        # not instances of ``int`` but are single indices all the same.
        single = isinstance(super_index, (int, np.integer))
        super_index = np.atleast_1d(super_index)
        # Row-major strides of all but the last axis, computed once up front.
        strides = [np.prod(shape[d + 1 :]) for d in range(self.dim - 1)]
        indices = np.zeros((len(super_index), self.dim), dtype=np.int64)
        for i, si in enumerate(super_index):
            rest = si
            for d, stride in enumerate(strides):
                val, rest = divmod(rest, stride)
                indices[i, d] = val
            # What remains after peeling off the leading axes is the last index.
            indices[i, self.dim - 1] = rest
        return indices[0] if single else indices

    def plot_basis(
        self,
        lw: float = None,
        ls: str = "--",
        margins: Union[Sequence[float], float] = 0.1,
        grid: bool = False,
        show_cell: bool = True,
        show_vecs: bool = True,
        adjustable: str = "box",
        ax: "Union[plt.Axes, Axes3D]" = None,
        show: bool = False,
    ) -> "Union[plt.Axes, Axes3D]":  # pragma: no cover
        """Plot the lattice basis.

        Parameters
        ----------
        lw : float, optional
            The line width used for plotting the unit cell outlines.
        ls : str, optional
            The line style used for plotting the unit cell outlines.
        margins : Sequence[float] or float, optional
            The margins of the plot. A single number is used for every axis.
        grid : bool, optional
            If True, draw a grid in the plot. Only applied to 1D and 2D plots.
        show_cell : bool, optional
            If True the outlines of the unit cell are plotted. Only used if
            ``show_vecs`` is True as well.
        show_vecs : bool, optional
            If True the basis vectors are drawn. If False nothing of the unit
            cell is drawn, regardless of ``show_cell``.
        adjustable : None or {'box', 'datalim'}, optional
            If not None, this defines which parameter will be adjusted to meet
            the equal aspect ratio. If 'box', change the physical dimensions of
            the Axes. If 'datalim', change the x or y data limits.
            Only applied to 2D plots.
        ax : plt.Axes or plt.Axes3D or None, optional
            Parent plot. If None, a new plot is initialized.
        show : bool, optional
            If True, show the resulting plot.

        Returns
        -------
        ax : plt.Axes or plt.Axes3D
            The axes the basis was drawn on.

        Raises
        ------
        ValueError
            If the lattice has more than three dimensions.
        """
        logger.debug("Plotting unit cell")

        if self.dim > 3:
            raise ValueError(f"Plotting in {self.dim} dimensions is not supported!")

        zorder = 0
        fig, ax = subplot(self.dim, adjustable, ax)

        # Draw unit vectors and the cell they spawn.
        if show_vecs:
            vectors = self.vectors
            draw_unit_cell(
                ax, vectors, show_cell, lw=lw, color="k", ls=ls, zorder=zorder
            )

        # Configure grid
        if grid and self.dim < 3:
            ax.set_axisbelow(True)
            # Pass the flag positionally: its keyword was renamed from ``b`` to
            # ``visible`` between Matplotlib versions.
            ax.grid(True, which="major")

        # Adjust margin. Any plain number (int as well as float) is broadcast
        # to all axes.
        if isinstance(margins, (int, float)):
            margins = [margins] * self.dim
        ax.margins(*margins)

        fig.tight_layout()
        if show:
            plt.show()
        return ax

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dim: {self.dim})"