# coding: utf-8
#
# This code is part of lattpy.
#
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
# be included in all copies or substantial portions of the Software.
"""This module contains the main `Lattice` object."""
import pickle
import logging
import warnings
import itertools
import numpy as np
from copy import deepcopy
from scipy.sparse import csr_matrix
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
from typing import Union, Optional, Tuple, Iterator, Sequence, Callable
from .utils import ArrayLike, frmt_num, NotBuiltError
from .plotting import (
subplot,
draw_sites,
draw_vectors,
draw_indices,
connection_color_array,
)
from .spatial import KDTree, distances
from .atom import Atom
from .data import LatticeData, DataMap
from .shape import AbstractShape, Shape
from .basis import basis_t
from .structure import LatticeStructure
# Public API of this module: only the `Lattice` class is exported.
__all__ = ["Lattice"]
# Module-level logger; `build` and `set_periodic` emit debug messages through it.
logger = logging.getLogger(__name__)
def _filter_dangling(indices, positions, neighbors, dists, min_neighbors):
num_neighbors = np.count_nonzero(np.isfinite(dists), axis=1)
sites = np.where(num_neighbors < min_neighbors)[0]
if len(sites) == 0:
return indices, positions, neighbors, dists
elif len(sites) == indices.shape[0]:
raise ValueError("Filtering min_neighbors would result in no sites!")
# store current invalid index
invalid_idx = indices.shape[0]
# Remove data from arrays
indices = np.delete(indices, sites, axis=0)
positions = np.delete(positions, sites, axis=0)
neighbors = np.delete(neighbors, sites, axis=0)
dists = np.delete(dists, sites, axis=0)
# Update neighbor indices and distances:
# For each removed site below the neighbor index has to be decremented once
mask = np.isin(neighbors, sites)
neighbors[mask] = invalid_idx
dists[mask] = np.inf
for count, i in enumerate(sorted(sites)):
neighbors[neighbors > (i - count)] -= 1
# Update invalid indices in neighbor array since number of sites changed
num_sites = indices.shape[0]
neighbors[neighbors == invalid_idx] = num_sites
return indices, positions, neighbors, dists
[docs]class Lattice(LatticeStructure):
"""Main lattice object representing a Bravais lattice model.
Combines the ``LatticeBasis`` and the ``LatticeStructure`` class and adds
the ability to construct finite lattice models.
.. rubric:: Inheritance
.. inheritance-diagram:: Lattice
:parts: 1
Parameters
----------
basis: array_like or float or LatticeBasis
The primitive basis vectors that define the unit cell of the lattice. If a
``LatticeBasis`` instance is passed it is copied and used as the new basis
of the lattice.
**kwargs
Key-word arguments. Used for quickly configuring a ``Lattice`` instance.
Allowed keywords are:
Properties:
atoms: Dictionary containing the atoms to add to the lattice.
cons: Dictionary containing the connections to add to the lattice.
shape: int or tuple defining the shape of the finite size lattice to build.
periodic: int or list defining the periodic axes to set up.
Examples
--------
Two dimensional lattice with one atom in the unit cell and nearest neighbors
>>> import lattpy as lp
>>> latt = lp.Lattice(np.eye(2))
>>> latt.add_atom()
>>> latt.add_connections(1)
>>> _ = latt.build((5, 3))
>>> latt
Lattice(dim: 2, num_base: 1, num_neighbors: [4], shape: [5. 3.])
Quick-setup of the same lattice:
>>> import lattpy as lp
>>> import matplotlib.pyplot as plt
>>> latt = lp.Lattice.square(atoms={(0.0, 0.0): "A"}, cons={("A", "A"): 1})
>>> _ = latt.build((5, 3))
>>> _ = latt.plot()
>>> plt.show()
"""
def __init__(self, basis: basis_t, **kwargs):
    """Initialize the lattice structure and an empty lattice-data cache.

    If ``shape`` is contained in ``kwargs`` the finite lattice is built right
    away, using ``periodic`` from ``kwargs`` when present.
    """
    super().__init__(basis, **kwargs)

    # Cache of the finite lattice; populated by `build`
    self.data = LatticeData()
    self.shape = None
    self.pos = None
    self.periodic_axes = []
    self.primitive = None

    if "shape" not in kwargs:
        return
    self.build(kwargs["shape"], periodic=kwargs.get("periodic"))
@property
def num_sites(self) -> int:
    """int: Number of sites stored in the cached lattice data (after building)."""
    lattice_data = self.data
    return lattice_data.num_sites
@property
def num_cells(self) -> int:
    """int: Number of distinct unit cells in the cached lattice data (after building)."""
    # Drop the trailing atom index (alpha) so only the cell translations remain
    cell_indices = self.data.indices[:, :-1]
    unique_cells = np.unique(cell_indices, axis=0)
    return unique_cells.shape[0]
@property
def indices(self):
    """np.ndarray: Lattice indices ``(n_1, ..., n_D, alpha)`` of the cached data."""
    lattice_data = self.data
    return lattice_data.indices
@property
def positions(self):
    """np.ndarray: Cartesian site positions of the cached lattice data."""
    lattice_data = self.data
    return lattice_data.positions
def volume(self) -> float:
    """Return the total volume of the built lattice.

    The volume is the number of distinct unit cells in the cached data
    multiplied by the volume of a single unit cell.

    Returns
    -------
    vol : float
        The volume of the finite lattice structure.
    """
    return self.num_cells * self.cell_volume
def alpha(self, idx: int) -> int:
    """Return the atom index (alpha) of a site in the cached lattice data.

    Parameters
    ----------
    idx : int
        The super-index of a site in the cached lattice data.

    Returns
    -------
    alpha : int
        The index of the site's atom in the unit cell, i.e. the last
        component of its lattice index.
    """
    lattice_indices = self.data.indices
    return lattice_indices[idx, -1]
def atom(self, idx: int) -> Atom:
    """Return the atom occupying a site of the cached lattice data.

    Parameters
    ----------
    idx : int
        The super-index of a site in the cached lattice data.

    Returns
    -------
    atom : Atom
        The atom selected by the site's atom index (alpha).
    """
    atom_index = self.alpha(idx)
    return self._atoms[atom_index]
def position(self, idx: int) -> np.ndarray:
    """Return the cartesian position of a site in the cached lattice data.

    Parameters
    ----------
    idx : int
        The super-index of a site in the cached lattice data.

    Returns
    -------
    pos : (D, ) np.ndarray
        The position of the lattice site.
    """
    site_positions = self.data.positions
    return site_positions[idx]
def limits(self) -> np.ndarray:
    """Return the spatial bounding box of the lattice sites.

    Returns
    -------
    limits : (2, D) np.ndarray
        Row 0 holds the minimum and row 1 the maximum coordinate of all site
        positions along each dimension.
    """
    site_positions = self.positions
    lower = np.min(site_positions, axis=0)
    upper = np.max(site_positions, axis=0)
    return np.array([lower, upper])
def relative_position(self, fractions) -> np.ndarray:
    """Compute a position given as fractions of the lattice extent.

    Parameters
    ----------
    fractions : float or (D, ) array_like
        The position relative to the size of the lattice (0 = lower limit,
        1 = upper limit). The fractions ``[0.5, ..., 0.5]`` give the center.

    Returns
    -------
    relpos : (D, ) np.ndarray
        The corresponding cartesian position in the lattice model.
    """
    bounds = self.limits()
    lower = bounds[0]
    span = bounds[1] - bounds[0]
    return lower + np.asanyarray(fractions) * span
def center(self) -> np.ndarray:
    """Return the spatial center of the lattice's bounding box.

    Returns
    -------
    center : (D, ) np.ndarray
        The point halfway between the lower and upper limits in every
        dimension.
    """
    halves = np.full(self.dim, 0.5)
    return self.relative_position(halves)
def center_of_gravity(self) -> np.ndarray:
    """Compute the mass-weighted center of the lattice model.

    Returns
    -------
    center : (D, ) np.ndarray
        The center of gravity of all sites in the cached lattice data.

    Notes
    -----
    Requires that the `mass` attribute is set for each atom. If no mass is set
    the default mass of `1.0` is used.
    """
    alphas = self.data.indices[:, -1]
    # Look up the mass once per distinct atom type instead of once per site,
    # then broadcast it back onto the sites.
    unique_alphas, inverse = np.unique(alphas, return_inverse=True)
    atom_masses = np.array([self._atoms[a].get("mass", 1.0) for a in unique_alphas])
    masses = atom_masses[inverse.ravel()]
    center = np.sum(masses[:, None] * self.positions, axis=0) / np.sum(masses)
    return center
def superindex_from_pos(self, pos: ArrayLike, atol: float = 1e-4) -> Optional[int]:
    """Find the super-index of the site located at a given position.

    Parameters
    ----------
    pos : (D, ) array_like
        The position of the site in cartesian coordinates.
    atol : float, optional
        Absolute tolerance applied to every coordinate when comparing.

    Returns
    -------
    index : int or None
        The super-index of the first matching site in the cached lattice data,
        or ``None`` if no site lies within the tolerance.
    """
    target = np.array(pos)[None, :]
    is_close = np.all(np.abs(self.data.positions - target) < atol, axis=1)
    matches = np.flatnonzero(is_close)
    return matches[0] if len(matches) else None
def superindex_from_index(self, ind: ArrayLike) -> Optional[int]:
    """Return the super-index of a site defined by its lattice index.

    Parameters
    ----------
    ind : (D + 1, ) array_like
        The lattice index ``(n_1, ..., n_D, alpha)`` of the site.

    Returns
    -------
    index : int or None
        The super-index of the first matching site in the cached lattice data,
        or ``None`` if the index is not part of the lattice.
    """
    # Lattice indices are integers, so compare them exactly instead of through
    # a floating-point tolerance on their difference.
    matches = np.all(self.data.indices == np.asarray(ind)[None, :], axis=1)
    indices = np.flatnonzero(matches)
    if len(indices) == 0:
        return None
    return indices[0]
def neighbors(
    self, site: int, distidx: int = None, unique: bool = False
) -> np.ndarray:
    """Returns the neighbors of a given site in the cached lattice data.

    Parameters
    ----------
    site : int
        The super-index of a site in the cached lattice data.
    distidx : int, optional
        Index of the distance level of the neighbors (0 = nearest neighbors).
        The default is ``None``, which is passed unchanged to
        ``LatticeData.get_neighbors`` (presumably selecting the neighbors of
        all distance levels; verify against ``LatticeData``).
    unique : bool, optional
        If True, each unique pair is only returned once.

    Returns
    -------
    indices : np.ndarray of int
        The super-indices of the neighbors.
    """
    return self.data.get_neighbors(site, distidx, unique=unique)
def nearest_neighbors(self, idx: int, unique: bool = False) -> np.ndarray:
    """Return the nearest neighbors (distance index 0) of a site.

    Parameters
    ----------
    idx : int
        The super-index of a site in the cached lattice data.
    unique : bool, optional
        If True, each unique pair is only returned once.

    Returns
    -------
    indices : (N, ) np.ndarray of int
        The super-indices of the nearest neighbors.
    """
    nearest_level = 0
    return self.neighbors(idx, nearest_level, unique)
def iter_neighbors(
    self, site: int, unique: bool = False
) -> Iterator[Tuple[int, np.ndarray]]:
    """Iterate over the neighbors of every distance level of a site.

    Parameters
    ----------
    site : int
        The super-index of a site in the cached lattice data.
    unique : bool, optional
        If True, each unique pair is only returned once.

    Yields
    ------
    distidx : int
        The distance index of the neighbor indices.
    neighbors : (N, ) np.ndarray
        The super-indices of the neighbors at that distance level.
    """
    lattice_data = self.data
    return lattice_data.iter_neighbors(site, unique)
def check_neighbors(self, idx0: int, idx1: int) -> Optional[int]:
    """Check whether two sites are neighbors and return their distance level.

    Parameters
    ----------
    idx0 : int
        The first super-index of a site in the cached lattice data.
    idx1 : int
        The second super-index of a site in the cached lattice data.

    Returns
    -------
    distidx : int or None
        The (smallest) distance index at which ``idx1`` is a neighbor of
        ``idx0``, or ``None`` if they are not neighbors.
    """
    for distidx in range(self.num_distances):
        if idx1 in self.neighbors(idx0, distidx):
            return distidx
    return None
def _update_shape(self):
    """Refresh ``pos`` and ``shape`` from the spatial limits of the cached data."""
    bounds = self.data.get_limits()
    self.pos = bounds[0]
    self.shape = bounds[1] - bounds[0]
def build(
    self,
    shape: Union[float, Sequence[float], AbstractShape],
    primitive: bool = False,
    pos: Union[float, Sequence[float]] = None,
    check: bool = True,
    min_neighbors: int = None,
    num_jobs: int = -1,
    periodic: Union[bool, int, Sequence[int]] = None,
    callback: Callable = None,
    dtype: Union[int, str, np.dtype] = None,
):
    """Constructs the indices and neighbors of a finite size lattice.

    Parameters
    ----------
    shape : (N, ) array_like or float or AbstractShape
        Shape of finite size lattice to build.
    primitive : bool, optional
        If True the shape will be multiplied by the cell size of the model.
        The default is False.
    pos : (N, ) array_like or int, optional
        Optional position of the section to build. If ``None`` the origin is used.
    check : bool, optional
        If True the positions of the translation vectors are checked and
        filtered. The default is True. This should only be disabled if
        filtered later.
    min_neighbors : int, optional
        The minimum number of neighbors a site must have. This can be used to
        remove dangling sites at the edge of the lattice.
    num_jobs : int, optional
        Number of jobs to schedule for parallel processing of neighbors.
        If ``-1`` is given all processors are used. The default is ``-1``.
    periodic : bool or int or array_like, optional
        Optional periodic axes to set. See ``set_periodic`` for more details.
    callback : callable, optional
        Forwarded to ``build_indices``; the indices and positions are passed
        to it as arguments.
    dtype : int or str or np.dtype, optional
        Optional data-type for storing the lattice indices. Using a smaller
        bit-size may help reduce memory usage. By default, the given limits are
        checked to determine the smallest possible data-type.

    Returns
    -------
    shape : AbstractShape
        The shape used to build the lattice. A ``Shape`` is created if a plain
        size was passed.

    Raises
    ------
    ValueError
        Raised if the dimension of the position doesn't match the dimension of
        the lattice.
    NoConnectionsError
        Raised if no connections have been set up.
    NotAnalyzedError
        Raised if the lattice distances and base-neighbors haven't been computed.
    """
    # Validate the structure first so a failing call keeps previously built data
    self._assert_connections()
    self._assert_analyzed()

    self.data.reset()
    # Periodic boundary conditions of a previous build no longer apply
    self.periodic_axes = list()

    if not isinstance(shape, AbstractShape):
        basis = self.vectors if primitive else None
        shape = Shape(shape, pos=pos, basis=basis)

    logger.debug("Building lattice: %s at %s", shape, pos)

    # Build indices and positions
    indices, positions = self.build_indices(
        shape, primitive, pos, check, callback, dtype, return_pos=True
    )

    # Compute the neighbors and distances between the sites
    neighbors, distances_ = self.compute_neighbors(indices, positions, num_jobs)
    if min_neighbors is not None:
        data = _filter_dangling(
            indices, positions, neighbors, distances_, min_neighbors
        )
        indices, positions, neighbors, distances_ = data

    # Set data of the lattice and update shape
    self.data.set(indices, positions, neighbors, distances_)
    self.primitive = primitive
    self._update_shape()

    if periodic is not None:
        self.set_periodic(periodic)

    logger.debug(
        "Lattice shape: %s (%s)",
        self.shape,
        frmt_num(self.data.nbytes, unit="iB", div=1024),
    )
    return shape
def _build_periodic_translation_vector(self, axes, primitive=False, indices=None):
    """Compute the lattice translation vector that spans the lattice along `axes`.

    Parameters
    ----------
    axes : int or array_like of int
        Axis or axes along which the translation vector is built.
    primitive : bool, optional
        If True the extent is measured in lattice-index space, otherwise in
        cartesian space (converted back to lattice coordinates and rounded).
    indices : (N, D + 1) np.ndarray, optional
        Lattice indices to measure. Defaults to the cached lattice indices.

    Returns
    -------
    nvec : (D, ) np.ndarray of int64
        Translation vector in lattice coordinates; components of axes not in
        `axes` are zero (in the primitive case).
    """
    if indices is None:
        indices = self.indices
    axes = np.atleast_1d(axes)
    if not primitive:
        # Spatial extent of the occupied unit cells (atom index alpha stripped)
        cells = np.unique(indices[:, :-1], axis=0)
        positions = self.transform(cells)
        shape = np.max(positions, axis=0) - np.min(positions, axis=0)
        # Point just beyond the lattice edge along each periodic axis; the
        # extra 2/3 cell presumably makes rounding land on the next cell.
        ppoint = np.zeros(self.dim)
        for ax in axes:
            ppoint[ax] = shape[ax] + self.cell_size[ax] * 2 / 3
        pnvec = np.round(self.itransform(ppoint), decimals=0)
    else:
        # Extent in index space. Widen to int64 first: subtracting min from max
        # in a narrow index dtype (e.g. int8) could overflow.
        cell_idx = np.asarray(indices)[:, :-1].astype(np.int64)
        idx_size = cell_idx.max(axis=0) - cell_idx.min(axis=0)
        pnvec = np.zeros_like(idx_size, dtype=np.int64)
        pnvec[axes] = idx_size[axes] + 1
    return pnvec.astype(np.int64)
def periodic_translation_vectors(self, axes, primitive=False):
    """Constructs all translation vectors for periodic boundary conditions.

    Parameters
    ----------
    axes : int or (N, ) array_like
        One or multiple axes to compute the translation vectors for.
    primitive : bool, optional
        Flag if the specified axes are in cartesian or lattice coordinates.
        If ``True`` the passed position will be multiplied with the lattice vectors.
        The default is ``False`` (cartesian coordinates).

    Returns
    -------
    nvecs : list of tuple
        The translation vectors for the periodic boundary conditions.
        The first item of each element is the axis, the second the
        corresponding translation vector.
    """
    # One axis (Python or NumPy integer, or a length-1 sequence): no combinations
    if np.ndim(axes) == 0 or len(axes) == 1:
        return [(axes, self._build_periodic_translation_vector(axes, primitive))]

    # Add all pairwise combinations of the periodic axes
    items = list()
    for ax in itertools.combinations_with_replacement(axes, r=2):
        nvec = self._build_periodic_translation_vector(ax, primitive)
        items.append((ax, nvec))
        if ax[0] != ax[1]:
            # Mirror the second axis of the pair to also reach the opposite
            # (anti-diagonal) corner. For a single axis (a, a) the mirrored
            # vector would merely be -nvec, which adds nothing new.
            nvec2 = np.copy(nvec)
            nvec2[ax[1]] *= -1
            items.append((ax, nvec2))
    return items
def _build_periodic(self, indices, positions, nvec, out_ind=None, out_pos=None):
    """Translate lattice indices and positions by the lattice vector ``nvec``.

    If both ``out_ind`` and ``out_pos`` are given, the translated data is
    written into them in place; otherwise new arrays are returned.

    Returns
    -------
    ind, pos : np.ndarray
        The translated lattice indices and positions.
    """
    shift_pos = self.translate(nvec)
    # The atom index alpha (last column) is not translated
    shift_idx = np.append(nvec, 0)
    if out_ind is None or out_pos is None:
        return indices + shift_idx, positions + shift_pos
    out_ind[:] = indices + shift_idx
    out_pos[:] = positions + shift_pos
    return out_ind, out_pos
def kdtree(self, positions=None, eps=0.0, boxsize=None):
    """Build a ``KDTree`` for neighbor queries on lattice positions.

    Parameters
    ----------
    positions : (N, D) array_like, optional
        Positions to index. Defaults to the positions of the cached lattice data.
    eps : float, optional
        Forwarded to ``KDTree``. The default is 0.0.
    boxsize : optional
        Forwarded to ``KDTree``. The default is ``None``.

    Returns
    -------
    tree : KDTree
        Tree whose query count ``k`` is the total number of raw neighbors plus
        one, and whose maximal distance is the largest lattice distance plus 10%
        of the smallest entry of the raw distance matrix.
    """
    points = self.data.positions if positions is None else positions
    num_queried = np.sum(np.sum(self._raw_num_neighbors, axis=1)) + 1
    cutoff = np.max(self.distances) + 0.1 * np.min(self._raw_distance_matrix)
    return KDTree(points, num_queried, cutoff, eps=eps, boxsize=boxsize)
def _compute_pneighbors(
    self, axis, primitive=False, indices=None, positions=None, num_jobs=-1
):
    """Compute the periodic neighbors of the sites along the given axes.

    Parameters
    ----------
    axis : int or array_like of int
        Periodic axis or axes.
    primitive : bool, optional
        Forwarded to ``periodic_translation_vectors``.
    indices, positions : np.ndarray, optional
        Lattice data to use. If ``indices`` is ``None`` both are taken from
        the cached lattice data.
    num_jobs : int, optional
        Number of jobs for the KD-tree query.

    Returns
    -------
    pidx, pdists, pnvecs, paxs : dict
        Keyed by site super-index: the periodic neighbor super-indices, their
        distances, the translation vectors and the axes used to reach them.
        Sites without periodic neighbors have no entry.
    """
    if indices is None:
        indices = self.data.indices
        positions = self.data.positions

    axis = np.atleast_1d(axis)
    invalid_idx = len(indices)

    # Build tree of the untranslated sites
    k = np.sum(np.sum(self._raw_num_neighbors, axis=1)) + 1
    max_dist = np.max(self.distances) + 0.1 * np.min(self._raw_distance_matrix)
    tree = KDTree(positions, k, max_dist)

    # Buffers for the translated data, reused for every translation vector
    ind_t = np.zeros_like(indices)
    pos_t = np.zeros_like(positions)

    pidx, pdists, pnvecs, paxs = dict(), dict(), dict(), dict()
    for ax, nvec in self.periodic_translation_vectors(axis, primitive):
        # Translate positions along periodic axis
        self._build_periodic(indices, positions, nvec, ind_t, pos_t)

        # Query neighbors with translated points and filter
        neighbors, distances_ = tree.query(pos_t, num_jobs, self.DIST_DECIMALS)
        neighbors, distances_ = self._filter_neighbors(
            indices, neighbors, distances_, ind_t
        )

        # Keep only sites with at least one finite-distance neighbor
        idx = np.where(np.isfinite(distances_).any(axis=1))[0]
        distances_ = distances_[idx]
        neighbors = neighbors[idx]

        for i, site in enumerate(idx):
            valid = neighbors[i] < invalid_idx
            inds = neighbors[i][valid]
            dists = distances_[i][valid]

            # Update dict for indices `inds`
            pidx.setdefault(site, list()).extend(inds)  # noqa
            pdists.setdefault(site, list()).extend(dists)
            paxs.setdefault(site, list()).extend([ax] * len(inds))
            pnvecs.setdefault(site, list()).extend([nvec] * len(inds))

            # Update dict for neighbor indices of `inds` (reverse direction)
            for j, d in zip(inds, dists):
                pidx.setdefault(j, list()).append(site)  # noqa
                pdists.setdefault(j, list()).append(d)
                paxs.setdefault(j, list()).append(ax)
                pnvecs.setdefault(j, list()).append(-nvec)

    # Convert values of dict to np.ndarray's
    for key in list(pidx.keys()):
        candidates = np.asarray(pidx[key])
        # Drop periodic neighbors that are already regular neighbors (this
        # happens for small lattices). The same mask is applied to all four
        # lists so distances, axes and vectors stay aligned with the sites.
        keep = ~np.isin(candidates, self.neighbors(key))
        if not np.any(keep):
            del pidx[key]
            del pdists[key]
            del paxs[key]
            del pnvecs[key]
            continue
        sites, first = np.unique(candidates[keep], return_index=True)
        pidx[key] = sites
        pdists[key] = np.asarray(pdists[key])[keep][first]
        paxs[key] = np.asarray(paxs[key])[keep][first]
        pnvecs[key] = np.asarray(pnvecs[key])[keep][first]

    return pidx, pdists, pnvecs, paxs
def set_periodic(
    self, axis: Union[bool, int, Sequence[int]] = None, primitive: bool = None
):
    """Sets periodic boundary conditions along the given axis.

    Parameters
    ----------
    axis : bool or int or (N, ) array_like
        One or multiple axes to apply the periodic boundary conditions.
        ``True`` selects all axes. If the axis is ``None`` (or ``False``) the
        periodic boundary conditions will be removed.
    primitive : bool, optional
        Flag if the specified axes are in cartesian or lattice coordinates.
        If ``True`` the passed position will be multiplied with the lattice vectors.
        The default is ``False`` (cartesian coordinates).

        .. deprecated:: 0.8.0
            The `primitive` argument will be removed in lattpy 0.9.0

    Raises
    ------
    NotBuiltError
        Raised if the lattice hasn't been built yet.

    Notes
    -----
    The lattice has to be built before applying the periodic boundary conditions.
    The lattice also has to be at least three atoms big in the specified directions.
    Uses the same coordinate system (cartesian or primitive basis vectors)
    as chosen for building the lattice. If no periodic neighbors are found,
    ``periodic_axes`` is left empty.
    """
    if isinstance(axis, bool):
        axis = np.arange(self.dim) if axis else None

    if primitive is not None:
        warnings.warn(
            "The `primitive` argument is deprecated and will be removed in "
            "lattpy 0.9.0. The value for building the lattice is reused!",
            DeprecationWarning,
            stacklevel=2,
        )
    primitive = self.primitive

    logger.debug("Computing periodic neighbors along axis %s", axis)
    if self.shape is None:
        raise NotBuiltError()

    # Drop previous periodic data and keep `periodic_axes` in sync with it
    self.data.remove_periodic()
    self.periodic_axes = list()
    if axis is None:
        return

    axis = np.atleast_1d(axis)
    pidx, pdists, pnvecs, paxs = self._compute_pneighbors(axis, primitive)
    if not pidx:
        # No periodic neighbors (e.g. lattice too small): nothing to set
        return
    self.data.set_periodic(pidx, pdists, pnvecs, paxs)
    self.periodic_axes = axis
def _compute_connection_neighbors(self, positions1, positions2):
    """Find neighboring site pairs between two sets of positions.

    Parameters
    ----------
    positions1 : (N1, D) np.ndarray
        Positions of the first set of sites.
    positions2 : (N2, D) np.ndarray
        Positions of the second set of sites.

    Returns
    -------
    pairs : (N, 2) np.ndarray of int
        Pairs ``(i, j)`` with ``i`` indexing ``positions1`` and ``j`` indexing
        ``positions2``. Shape ``(0, 2)`` if no pairs are found.
    distances : (N, ) np.ndarray of float
        The distance of each pair.
    """
    # Set neighbor query parameters
    k = np.sum(np.sum(self._raw_num_neighbors, axis=1)) + 1
    max_dist = np.max(self.distances) + 0.1 * np.min(self._raw_distance_matrix)

    # Build sub-lattice trees
    tree1 = KDTree(positions1, k=k, max_dist=max_dist)
    tree2 = KDTree(positions2, k=k, max_dist=max_dist)

    pairs = list()
    distances_ = list()
    connections = tree1.query_ball_tree(tree2, max_dist)
    for i, conns in enumerate(connections):
        if not conns:
            continue
        conns = np.asarray(conns)
        dists = cdist(np.asarray([positions1[i]]), positions2[conns])[0]
        pairs.extend((i, j) for j in conns)
        distances_.extend(dists)

    if not pairs:
        # Keep the documented (N, 2) layout even when nothing is connected
        return np.empty((0, 2), dtype=np.intp), np.empty(0, dtype=np.float64)
    return np.array(pairs), np.array(distances_)
def compute_connections(self, latt):
    """Compute the connecting neighbor pairs between this and another lattice.

    Parameters
    ----------
    latt : Lattice
        The other lattice.

    Returns
    -------
    neighbors : (N, 2) np.ndarray
        The connecting pairs. The first entry of each row indexes the cached
        data of this lattice, the second one the data of ``latt``.
    distances : (N) np.ndarray
        The corresponding distances for the connections.
    """
    other_positions = latt.data.positions
    own_positions = self.data.positions
    return self._compute_connection_neighbors(own_positions, other_positions)
def minimum_distances(self, site, primitive=None):
    """Compute the minimum distances between one site and all lattice sites.

    This method can be used to find the distances in a lattice with
    periodic boundary conditions: besides the direct distance, the distances
    to the site translated by every periodic translation vector (in both
    directions) are considered.

    Parameters
    ----------
    site : int
        The super-index i of a site in the cached lattice data.
    primitive : bool, optional
        Flag if the periodic boundary conditions are set up along cartesian or
        primitive basis vectors. The default is ``False`` (cartesian coordinates).

        .. deprecated:: 0.8.0
            The `primitive` argument will be removed in lattpy 0.9.0

    Returns
    -------
    min_dists : (N, ) np.ndarray
        The minimum distances between the lattice site i and the other sites.

    Notes
    -----
    Uses the same coordinate system (cartesian or primitive basis vectors)
    as chosen for building the lattice.
    """
    if primitive is not None:
        warnings.warn(
            "The `primitive` argument is deprecated and will be removed in "
            "lattpy 0.9.0. The value for building the lattice is reused!",
            DeprecationWarning,
            stacklevel=2,
        )
    positions = self.positions
    origin = positions[site]

    # Normal distances
    dists = [distances(origin, positions)]

    # Periodic distances (to the site translated in both directions)
    paxs = self.periodic_axes
    for _, vec in self.periodic_translation_vectors(paxs, self.primitive):
        dists.append(distances(self.translate(vec, origin), positions))
        dists.append(distances(self.translate(-vec, origin), positions))

    return np.min(dists, axis=0)
def _append(
    self,
    ind,
    pos,
    neighbors,
    dists,
    ax=0,
    side=+1,
    sort_axis=None,
    sort_reverse=False,
    primitive=False,
):
    """Append a section of lattice data along an axis and connect it.

    The section (``ind``, ``pos``, ``neighbors``, ``dists``) is copied and
    translated beyond the current lattice (``side > 0``) or before it
    (``side <= 0``). It is then appended to the cached data and linked to the
    existing sites. If ``sort_axis`` is not ``None`` the data is sorted
    afterwards. Finally ``shape``/``pos`` are refreshed.
    """
    new_indices = np.copy(ind)
    new_positions = np.copy(pos)
    new_neighbors = np.copy(neighbors)
    new_dists = np.copy(dists)

    # The translation is measured on the existing data when appending on the
    # positive side, and on the new section when prepending on the other side.
    reference = self.data.indices if side > 0 else new_indices
    nvec = self._build_periodic_translation_vector(ax, primitive, reference)
    if side <= 0:
        nvec = -nvec
    shift = self.translate(nvec)

    # Positions of the existing sites, captured before the data grows
    old_positions = self.data.positions

    # Move the new section into place and append it
    new_indices[:, :-1] += nvec
    new_positions += shift
    self.data.append(new_indices, new_positions, new_neighbors, new_dists)

    # Link existing sites with the appended ones (in both directions)
    pairs, pair_dists = self._compute_connection_neighbors(old_positions, new_positions)
    offset = len(old_positions)
    for (i, j), dist in zip(pairs, pair_dists):
        self.data.add_neighbors(i, j + offset, dist)
        self.data.add_neighbors(j + offset, i, dist)

    if sort_axis is not None:
        self.data.sort(sort_axis, reverse=sort_reverse)

    self._update_shape()
# noinspection PyShadowingNames
def append(
    self, latt, ax=0, side=+1, sort_ax=None, sort_reverse=False, primitive=None
):
    """Append another `Lattice`-instance along an axis.

    Parameters
    ----------
    latt : Lattice
        The other lattice to append to this instance.
    ax : int, optional
        The axis along which the other lattice is appended. The default is 0
        (x-axis).
    side : int, optional
        The side at which the new lattice is appended. If, for example, axis 0
        is used, the other lattice is appended on the right side if ``side=+1``
        and on the left side if ``side=-1``.
    sort_ax : int, optional
        The axis to sort the lattice indices by after the other lattice has
        been added. If ``None`` (default), the data is not sorted.
    sort_reverse : bool, optional
        If True, the lattice indices are sorted in reverse order.
    primitive : bool, optional
        Flag if the periodic boundary conditions are set up along cartesian or
        primitive basis vectors. The default is ``False`` (cartesian coordinates).

        .. deprecated:: 0.8.0
            The `primitive` argument will be removed in lattpy 0.9.0

    Notes
    -----
    Uses the same coordinate system (cartesian or primitive basis vectors)
    as chosen for building the lattice.

    Examples
    --------
    >>> latt = Lattice(np.eye(2))
    >>> latt.add_atom(neighbors=1)
    >>> _ = latt.build((5, 2))
    >>> print(latt.shape)
    [5. 2.]
    >>> latt2 = Lattice(np.eye(2))
    >>> latt2.add_atom(neighbors=1)
    >>> _ = latt2.build((2, 2))
    >>> print(latt2.shape)
    [2. 2.]
    >>> latt.append(latt2, ax=0)
    >>> print(latt.shape)
    [8. 2.]
    """
    if primitive is not None:
        warnings.warn(
            "The `primitive` argument is deprecated and will be removed in "
            "lattpy 0.9.0. The value for building the lattice is reused!",
            DeprecationWarning,
            stacklevel=2,
        )
    ind = latt.data.indices
    pos = latt.data.positions
    neighbors = latt.data.neighbors
    # Convert the stored distance indices back to distance values
    dists = latt.data.distvals[latt.data.distances]
    self._append(
        ind, pos, neighbors, dists, ax, side, sort_ax, sort_reverse, self.primitive
    )
def extend(self, size, ax=0, side=1, num_jobs=1, sort_ax=None, sort_reverse=False):
    """Extend the lattice along an axis.

    A new section with the current shape, but size ``size`` along ``ax``, is
    built and appended to the lattice.

    Parameters
    ----------
    size : float
        The size by which the lattice is extended in direction of ``ax``.
    ax : int, optional
        The axis along which the lattice is extended. The default is 0 (x-axis).
    side : int, optional
        The side at which the new section is appended. If, for example, axis 0
        is used, the lattice is extended to the right side if ``side=+1``
        and to the left side if ``side=-1``.
    num_jobs : int, optional
        Number of jobs to schedule for parallel processing of neighbors for new
        sites. If ``-1`` is given all processors are used. The default is ``1``.
    sort_ax : int, optional
        The axis to sort the lattice indices by after the lattice has been
        extended. If ``None`` (default), the data is not sorted.
    sort_reverse : bool, optional
        If True, the lattice indices are sorted in reverse order.

    Examples
    --------
    >>> latt = Lattice(np.eye(2))
    >>> latt.add_atom(neighbors=1)
    >>> _ = latt.build((5, 2))
    >>> print(latt.shape)
    [5. 2.]
    >>> latt.extend(2, ax=0)
    >>> print(latt.shape)
    [8. 2.]
    >>> latt.extend(2, ax=1)
    >>> print(latt.shape)
    [8. 5.]
    """
    # Build indices and positions of new section
    shape = np.copy(self.shape)
    shape[ax] = size
    ind, pos = self.build_indices(shape, primitive=self.primitive, return_pos=True)

    # Compute the neighbors and distances between the sites of new section
    neighbors, dists = self.compute_neighbors(ind, pos, num_jobs)

    # Append new section
    self._append(
        ind, pos, neighbors, dists, ax, side, sort_ax, sort_reverse, self.primitive
    )
def repeat(self, num=1, ax=0, side=1, sort_ax=None, sort_reverse=False):
    """Repeat the lattice along an axis.

    The lattice data present when this method is called is appended ``num``
    times along ``ax``.

    Parameters
    ----------
    num : int
        The number of times the lattice will be repeated in direction ``ax``.
    ax : int, optional
        The axis along which the lattice is extended. The default is 0 (x-axis).
    side : int, optional
        The side at which the copies are appended. If, for example, axis 0
        is used, the lattice is extended to the right side if ``side=+1``
        and to the left side if ``side=-1``.
    sort_ax : int, optional
        The axis to sort the lattice indices by after the lattice has been
        extended. If ``None`` (default), the data is not sorted.
    sort_reverse : bool, optional
        If True, the lattice indices are sorted in reverse order.

    Examples
    --------
    >>> latt = Lattice(np.eye(2))
    >>> latt.add_atom(neighbors=1)
    >>> _ = latt.build((5, 2))
    >>> print(latt.shape)
    [5. 2.]
    >>> latt.repeat()
    >>> print(latt.shape)
    [11. 2.]
    >>> latt.repeat(3)
    >>> print(latt.shape)
    [47. 2.]
    >>> latt.repeat(ax=1)
    >>> print(latt.shape)
    [47. 5.]
    """
    # Snapshot of the data to repeat (taken once, before anything is appended)
    ind = self.data.indices
    pos = self.data.positions
    neighbors = self.data.neighbors
    dists = self.data.distvals[self.data.distances]
    prim = self.primitive
    for _ in range(num):
        self._append(
            ind, pos, neighbors, dists, ax, side, sort_ax, sort_reverse, prim
        )
def dmap(self) -> DataMap:
    """DataMap : The data-map of the cached lattice data (``LatticeData.map``)."""
    lattice_data = self.data
    return lattice_data.map()
def neighbor_pairs(self, unique=False):
    """Returns all neighbor pairs with their corresponding distances in the lattice.

    Parameters
    ----------
    unique : bool, optional
        If True, only unique pairs with i < j are returned. The default is False.

    Returns
    -------
    pairs : (N, 2) np.ndarray
        An array containing all neighbor pairs of the lattice. If `unique=True`,
        the first index is always smaller than the second index in each element.
    distindices : (N, ) np.ndarray
        The corresponding distance indices of the neighbor pairs.

    Examples
    --------
    >>> latt = Lattice.chain()
    >>> latt.add_atom(neighbors=1)
    >>> _ = latt.build(3)
    >>> idx, distidx = latt.neighbor_pairs()
    >>> idx
    array([[0, 1],
           [1, 2],
           [1, 0],
           [2, 3],
           [2, 1],
           [3, 2]], dtype=uint8)
    >>> distidx
    array([0, 0, 0, 0, 0, 0], dtype=uint8)
    >>> idx, distidx = latt.neighbor_pairs(unique=True)
    >>> idx
    array([[0, 1],
           [1, 2],
           [2, 3]], dtype=uint8)
    """
    # Build index pairs and corresponding distance array. Row i of the
    # neighbor array holds the neighbors of site i, so each entry is paired
    # with its row index.
    neighbors = self.data.neighbors
    dtype = np.min_scalar_type(self.num_sites)
    sites = np.arange(self.num_sites, dtype=dtype)
    sources = np.repeat(sites, neighbors.shape[1])
    pairs = np.stack([sources, neighbors.ravel()], axis=1)
    distindices = self.data.distances.flatten()

    # Filter pairs with invalid indices
    mask = distindices != self.data.invalid_distidx
    pairs = pairs[mask]
    distindices = distindices[mask]

    if unique:
        # Filter for unique pairs (i < j)
        mask = pairs[:, 0] < pairs[:, 1]
        pairs = pairs[mask]
        distindices = distindices[mask]

    return pairs, distindices
def adjacency_matrix(self):
    """Computes the adjacency matrix for the neighbor data of the lattice.

    The entry ``(i, j)`` is the distance index of the pair plus one (so
    nearest neighbors are marked with 1); non-neighbors are 0.

    Returns
    -------
    adj_mat : (N, N) csr_matrix
        The adjacency matrix of the lattice, where N is the number of sites
        (sites without neighbors still get a row and a column).

    See Also
    --------
    neighbor_pairs : Generates a list of neighbor indices.

    Examples
    --------
    >>> latt = Lattice.chain()
    >>> latt.add_atom(neighbors=1)
    >>> _ = latt.build(3)
    >>> adj_mat = latt.adjacency_matrix()
    >>> adj_mat.toarray()
    array([[0, 1, 0, 0],
           [1, 0, 1, 0],
           [0, 1, 0, 1],
           [0, 0, 1, 0]], dtype=int8)
    """
    pairs, distindices = self.neighbor_pairs(unique=False)
    rows, cols = pairs.T  # noqa
    data = distindices + 1
    # Pass the shape explicitly: otherwise it is inferred from the largest
    # index, which drops trailing sites that have no neighbors.
    num_sites = self.num_sites
    return csr_matrix(
        (data, (rows, cols)), shape=(num_sites, num_sites), dtype=np.int8
    )
# ==================================================================================
def copy(self) -> "Lattice":
    """Lattice : A deep copy of this lattice instance, including its cached data."""
    duplicate = deepcopy(self)
    return duplicate
def todict(self) -> dict:
    """Create a dictionary describing this lattice instance.

    Extends the dictionary of the parent structure with the ``shape`` of the
    built lattice (``None`` if the lattice has not been built).

    Returns
    -------
    d : dict
        The information defining the current instance.
    """
    info = super().todict()
    info.update(shape=self.shape)
    return info
def dumps(self):  # pragma: no cover
    """Creates a string containing the information of the lattice instance.

    Each entry of ``todict()`` is written on its own line as ``key:`` padded
    to 15 characters, followed by the ``"; "``-joined string forms of its
    items. Non-iterable values (e.g. ``shape``, which is ``None`` before the
    lattice has been built) are written with ``str``.

    Returns
    -------
    s : str
        The information defining the current instance.
    """
    lines = list()
    for key, values in self.todict().items():
        head = key + ":"
        try:
            items = iter(values)
        except TypeError:
            # None / scalars cannot be joined item-wise
            text = str(values)
        else:
            text = "; ".join(str(x) for x in items)
        lines.append(f"{head:<15}" + text)
    return "\n".join(lines)
def dump(self, file: Optional[Union[str, int, bytes]] = None) -> None:  # pragma: no cover
    """Save the ``Lattice`` instance to a file using ``pickle``.

    Parameters
    ----------
    file : str or int or bytes, optional
        File name to store the lattice. If ``None`` (default) the file name
        ``"<hash>.latt"``, built from the hash of the lattice, is used.
    """
    if file is None:
        file = f"{self.__hash__()}.latt"
    with open(file, "wb") as f:
        pickle.dump(self, f)
@classmethod
def load(cls, file: Union[str, int, bytes]) -> "Lattice":  # pragma: no cover
    """Load a ``Lattice`` instance saved with ``dump``.

    Parameters
    ----------
    file : str or int or bytes
        File name to load the lattice from.

    Returns
    -------
    latt : Lattice
        The lattice restored from the file content.

    Warnings
    --------
    The file is read with ``pickle.load``, which can execute arbitrary code.
    Only load files from trusted sources.
    """
    with open(file, "rb") as stream:
        return pickle.load(stream)
def __hash__(self):
    """Return an integer derived from the MD5 digest of ``self.dumps()``."""
    import hashlib

    # MD5 serves here as a deterministic content digest of the text dump.
    payload = self.dumps().encode("utf-8")
    digest = hashlib.md5(payload).digest()
    return int.from_bytes(digest, "big")
def __eq__(self, other):
    """Compare two lattices by their hashes.

    Operands that are not instances of this object's class yield
    ``NotImplemented``, so Python falls back to the other operand's ``__eq__``
    (and finally to identity). This avoids calling ``__hash__`` on arbitrary
    objects, which raises ``TypeError`` for unhashable ones such as lists and
    could report equality with unrelated objects whose hash happens to match.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.__hash__() == other.__hash__()
def plot(
    self,
    lw: float = None,
    margins: Union[Sequence[float], float] = 0.1,
    legend: bool = None,
    grid: bool = False,
    pscale: float = 0.5,
    show_periodic: bool = True,
    show_indices: bool = False,
    index_offset: float = 0.1,
    con_colors: Sequence = None,
    adjustable: str = "box",
    ax: Union[plt.Axes, Axes3D] = None,
    show: bool = False,
) -> Union[plt.Axes, Axes3D]:  # pragma: no cover
    """Plot the cached lattice.

    Parameters
    ----------
    lw : float, optional
        Line width of the neighbor connections.
    margins : Sequence[float] or float, optional
        The margins of the plot. A single number is used for every axis.
    legend : bool, optional
        Flag if legend is shown. If ``None``, a legend is shown only if the
        lattice has more than one atom in the basis.
    grid : bool, optional
        If True, draw a grid in the plot. Only applied to 1D and 2D plots.
    pscale : float, optional
        The scale for drawing periodic connections. The default is half of the
        normal length.
    show_periodic : bool, optional
        If True the periodic connections will be shown.
    show_indices : bool, optional
        If True the index of the sites will be shown.
    index_offset : float, optional
        The positional offset of the index text labels. Only used if
        `show_indices=True`.
    con_colors : Sequence[tuple], optional
        list of colors to override the default connection color. Each element
        has to be a tuple with the first two elements being the atom indices of
        the pair and the third element the color, for example ``[(0, 0, 'r')]``.
    adjustable : None or {'box', 'datalim'}, optional
        If not None, this defines which parameter will be adjusted to meet
        the equal aspect ratio. If 'box', change the physical dimensions of
        the Axes. If 'datalim', change the x or y data limits.
        Only applied to 2D plots.
    ax : plt.Axes or plt.Axes3D or None, optional
        Parent plot. If None, a new plot is initialized.
    show : bool, optional
        If True, show the resulting plot.

    Returns
    -------
    ax : plt.Axes or Axes3D
        The axes containing the plot.

    Raises
    ------
    ValueError
        If the lattice has more than three dimensions.
    """
    logger.debug("Plotting lattice")
    if self.dim > 3:
        raise ValueError(f"Plotting in {self.dim} dimensions is not supported!")

    # Connections get a lower z-order than sites so they are drawn below them.
    hopz, atomz = 0, 1
    fig, ax = subplot(self.dim, adjustable, ax=ax)

    # Draw sites
    for alpha in range(self.num_base):
        atom = self.atoms[alpha]
        col = atom.color or f"C{alpha}"
        points = self.data.get_positions(alpha)
        label = atom.name
        draw_sites(ax, points, atom.radius, color=col, label=label, zorder=atomz)

    # Draw connections
    ccolor = "k"
    pcolor = "0.5"
    positions = self.positions
    hop_colors = connection_color_array(self.num_base, ccolor, con_colors)
    per_colors = connection_color_array(self.num_base, pcolor)
    for i in range(self.num_sites):
        at1 = self.alpha(i)
        p1 = positions[i]
        # Regular (non-periodic) connections, each pair drawn once
        for j in self.data.get_neighbors(i, periodic=False, unique=True):
            p2 = positions[j]
            at2 = self.alpha(j)
            color = hop_colors[at1][at2]
            draw_vectors(ax, p2 - p1, p1, color=color, lw=lw, zorder=hopz)
        if show_periodic:
            mask = self.data.neighbor_mask(i, periodic=True)
            idx = self.data.neighbors[i, mask]
            pnvecs = self.data.pnvecs[i, mask]
            neighbor_pos = self.data.positions[idx]
            for j, x in enumerate(neighbor_pos):
                at2 = self.alpha(idx[j])
                # Presumably maps the neighbor onto its periodic image next to
                # site i (translation by the negated periodic neighbor vector).
                x = self.translate(-pnvecs[j], x)
                color = per_colors[at1][at2]
                # Periodic connections are drawn shortened by ``pscale``
                vec = pscale * (x - p1)
                draw_vectors(ax, vec, p1, color=color, lw=lw, zorder=hopz)

    # Add index labels
    if show_indices:
        index_positions = [self.position(i) for i in range(self.num_sites)]
        draw_indices(ax, index_positions, index_offset)

    # Configure legend
    if legend is None:
        legend = self.num_base > 1
    if legend:
        ax.legend()

    # Configure grid. The visibility flag is passed positionally because the
    # ``b`` keyword was renamed to ``visible`` in Matplotlib 3.5 and later removed.
    if grid and self.dim < 3:
        ax.set_axisbelow(True)
        ax.grid(True, which="major")

    # Adjust margin: a single number (int or float) applies to every axis
    if isinstance(margins, (int, float, np.number)):
        margins = [margins] * self.dim
    ax.margins(*margins)

    fig.tight_layout()
    if show:
        plt.show()
    return ax
def __repr__(self) -> str:
    # ``str(None)`` already yields "None", so the shape needs no special case.
    shape_text = str(self.shape)
    parts = (
        f"dim: {self.dim}",
        f"num_base: {self.num_base}",
        f"num_neighbors: {self.num_neighbors}",
        f"shape: {shape_text}",
    )
    return self.__class__.__name__ + "(" + ", ".join(parts) + ")"