"""Module implementing a Structure class."""
# Standard library imports
import copy
from typing import List, Union
from collections.abc import Callable

# Third party library imports
import numpy as np
from ase import Atoms

try:
    import aiida
except ImportError:
    aiida = None
try:
    import pymatgen
except ImportError:
    pymatgen = None

# Internal library imports
from aim2dat.ext_interfaces import _return_ext_interface_modules
from aim2dat.strct.strct_io import get_structure_from_file
from aim2dat.io import zeo
from aim2dat.strct.strct_validation import (
    _structure_validate_elements,
    _structure_validate_cell,
    _structure_validate_positions,
)
from aim2dat.strct.mixin import AnalysisMixin, ManipulationMixin
import aim2dat.utils.chem_formula as utils_cf
import aim2dat.utils.print as utils_pr
from aim2dat.utils.maths import calc_angle
def _compare_function_args(args1, args2):
"""Compare function arguments to check if a property needs to be recalculated."""
for kwarg, value1 in args1.items():
if value1 != args2[kwarg]:
return False
return True
def _create_index_dict(value):
index_dict = {}
for idx, val in enumerate(value):
if val in index_dict:
index_dict[val] = [idx]
return index_dict
def _check_calculated_properties(structure, func, func_args):
property_name = "_".join(func.__name__.split("_")[1:])
if structure.store_calculated_properties and property_name in structure._function_args:
if _compare_function_args(structure._function_args[property_name], func_args):
return structure.extras[property_name]
calc_attr, calc_extra = func(structure, **func_args)
if calc_attr is not None:
structure.set_attribute(property_name, calc_attr)
if structure.store_calculated_properties:
if calc_extra is not None:
structure._extras[property_name] = calc_extra
structure._function_args[property_name] = func_args
return calc_extra
def import_method(func):
    """
    Flag ``func`` as an import method.

    The flag ``_is_import_method`` is set on ``func``; the very same object is returned,
    so this can be used as a decorator.
    """
    setattr(func, "_is_import_method", True)
    return func
def export_method(func):
    """
    Flag ``func`` as an export method.

    The flag ``_is_export_method`` is set on ``func``; the very same object is returned,
    so this can be used as a decorator.
    """
    setattr(func, "_is_export_method", True)
    return func
class Structure(AnalysisMixin, ManipulationMixin):
    """
    Represents a structure and contains methods to calculate properties of a structure
    (molecule or crystal) or to manipulate a structure.

    Sites are stored as parallel sequences (``elements``, ``kinds``, ``positions`` and
    ``scaled_positions``) plus optional per-site ``site_attributes``. Calculated properties
    can be cached in ``extras`` together with the ``function_args`` used to obtain them.
    """

    def __init__(
        self,
        elements: List[str],
        positions: List[List[float]],
        pbc: List[bool],
        is_cartesian: bool = True,
        wrap: bool = False,
        cell: List[List[float]] = None,
        kinds: List[str] = None,
        label: str = None,
        site_attributes: dict = None,
        store_calculated_properties: bool = True,
        attributes: dict = None,
        extras: dict = None,
        function_args: dict = None,
    ):
        """
        Initialize object.

        Parameters
        ----------
        elements : list
            Element symbols of the sites.
        positions : list
            Positions of the sites (n sites x 3).
        pbc : list or bool
            Periodic boundary conditions, either three booleans or one boolean used for
            all directions.
        is_cartesian : bool (optional)
            Whether ``positions`` are cartesian (``True``) or scaled (``False``).
        wrap : bool (optional)
            Wrap the positions back into the unit cell.
        cell : list or None (optional)
            Cell vectors. Required if ``pbc`` is ``True`` for any direction.
        kinds : list or None (optional)
            Kinds of the sites, defaults to ``None`` for every site.
        label : str or None (optional)
            Label of the structure.
        site_attributes : dict or None (optional)
            Site attributes; every value needs one entry per site.
        store_calculated_properties : bool (optional)
            Store calculated properties in ``extras`` to reuse them later.
        attributes : dict or None (optional)
            Attributes of the structure (stored by reference).
        extras : dict or None (optional)
            Previously calculated properties.
        function_args : dict or None (optional)
            Function arguments that belong to the entries of ``extras``.
        """
        self._inverse_cell = None
        self._site_attributes = {}

        # Order matters: ``kinds`` and ``site_attributes`` are validated against
        # ``elements``, ``pbc`` against ``cell`` and the positions need cell and pbc.
        self.elements = elements
        self.kinds = kinds
        self.cell = cell
        self.pbc = pbc
        self.label = label
        self.site_attributes = site_attributes
        self.store_calculated_properties = store_calculated_properties

        self._attributes = {} if attributes is None else attributes
        self._extras = {} if extras is None else extras
        self._function_args = {} if function_args is None else function_args
        self.set_positions(positions, is_cartesian=is_cartesian, wrap=wrap)

    def __str__(self):
        """Represent object as string."""

        def _parse_vector(vector):
            # Right-align every component to a width of 8 characters.
            return "[" + " ".join("{0:.4f}".format(val).rjust(8) for val in vector) + "]"

        output_str = utils_pr._print_title(f"Structure: {self.label}") + "\n\n"
        output_str += " Formula: " + utils_cf.transform_dict_to_str(self.chem_formula) + "\n"
        output_str += " PBC: [" + " ".join(str(val) for val in self.pbc) + "]\n\n"

        if self.cell is not None:
            output_str += utils_pr._print_subtitle("Cell") + "\n"
            output_str += utils_pr._print_list(
                "Vectors:", [_parse_vector(val) for val in self.cell]
            )
            output_str += " Lengths: " + _parse_vector(self.cell_lengths) + "\n"
            output_str += " Angles: " + _parse_vector(self.cell_angles) + "\n"
            output_str += " Volume: {0:.4f}\n\n".format(self.cell_volume)

        output_str += utils_pr._print_subtitle("Sites") + "\n"
        sites_list = []
        for el, kind, cart_pos, scaled_pos in self.iter_sites(
            get_kind=True, get_scaled_pos=True, get_cart_pos=True
        ):
            site_str = el.ljust(2) + " " + str(kind).ljust(5) + " " + _parse_vector(cart_pos)
            if scaled_pos is not None:
                site_str += " " + _parse_vector(scaled_pos)
            sites_list.append(site_str)
        output_str += utils_pr._print_list("", sites_list)
        output_str += utils_pr._print_hline()
        return output_str

    def __len__(self):
        """int: Get number of sites."""
        return len(self.elements)

    def __iter__(self):
        """Iterate through element and cartesian position."""
        for el, pos in zip(self.elements, self.positions):
            yield el, pos

    def __contains__(self, key: str):
        """Check whether Structure contains the key (a public, non-callable attribute)."""
        # Only the requested attribute is evaluated instead of every public property.
        if not isinstance(key, str) or key.startswith("_") or key not in dir(self):
            return False
        return not callable(getattr(self, key))

    def __getitem__(self, key: Union[str, list, tuple]):
        """
        Return structure property by key or list of keys.

        Raises
        ------
        KeyError
            If a key is not present or the key type is not supported.
        """

        def _lookup(name):
            try:
                return getattr(self, name)
            except AttributeError:
                raise KeyError(f"Key `{name}` is not present.") from None

        if isinstance(key, (list, tuple)):
            return {k: _lookup(k) for k in key}
        if isinstance(key, str):
            return _lookup(key)
        # Previously unsupported key types silently returned ``None``.
        raise KeyError(f"Key `{key}` is not present.")

    def __deepcopy__(self, memo):
        """Create a deepcopy of the object."""
        # NOTE(review): argument list reconstructed from ``__init__``; confirm upstream.
        # ``attributes``, ``extras``, ``function_args`` and ``site_attributes`` are already
        # deep copies because their properties return copies.
        new_strct = Structure(
            elements=self.elements,
            positions=self.positions,
            pbc=self.pbc,
            is_cartesian=True,
            cell=self.cell,
            kinds=self.kinds,
            label=self.label,
            site_attributes=self.site_attributes,
            store_calculated_properties=self.store_calculated_properties,
            attributes=self.attributes,
            extras=self.extras,
            function_args=self.function_args,
        )
        memo[id(self)] = new_strct
        return new_strct

    def keys(self) -> list:
        """Return property names to create the structure."""
        # NOTE(review): list reconstructed; every entry must be both an attribute and an
        # ``__init__`` argument so that ``Structure(**strct.to_dict())`` works.
        return [
            "label",
            "elements",
            "positions",
            "pbc",
            "cell",
            "kinds",
            "site_attributes",
            "attributes",
            "extras",
            "function_args",
        ]

    def copy(self) -> "Structure":
        """Return copy of `Structure` object."""
        return copy.deepcopy(self)

    def get(self, key, value=None):
        """Get attribute by key and return default if not present or ``None``."""
        try:
            result = self[key]
        except KeyError:
            return value
        return value if result is None else result

    @property
    def label(self) -> Union[str, None]:
        """Return label of the structure (especially relevant in StructureCollection)."""
        return self._label

    @label.setter
    def label(self, value):
        if value is not None and not isinstance(value, str):
            raise TypeError("`label` needs to be of type str.")
        self._label = value

    @property
    def elements(self) -> tuple:
        """Return the elements of the structure."""
        return self._elements

    @elements.setter
    def elements(self, value: Union[tuple, list, np.ndarray]):
        elements = _structure_validate_elements(value)
        if self.positions is not None and len(self.positions) != len(elements):
            raise ValueError("Length of `elements` is unequal to length of `positions`.")
        self._elements = elements
        self._element_dict = _create_index_dict(elements)
        self._chem_formula = utils_cf.transform_list_to_dict(elements)

    @property
    def chem_formula(self) -> dict:
        """
        Return chemical formula.
        """
        return self._chem_formula

    @property
    def positions(self) -> tuple:
        """tuple: Return the cartesian positions of the structure."""
        return getattr(self, "_positions", None)

    @property
    def scaled_positions(self) -> Union[tuple, None]:
        """tuple or None: Return the scaled positions of the structure."""
        return getattr(self, "_scaled_positions", None)

    @property
    def pbc(self) -> tuple:
        """Return the pbc of the structure."""
        return self._pbc

    @pbc.setter
    def pbc(self, value: Union[tuple, list, np.ndarray, bool]):
        if isinstance(value, (list, tuple, np.ndarray)):
            if len(value) == 3 and all(isinstance(pbc0, (bool, np.bool_)) for pbc0 in value):
                value = tuple(bool(pbc0) for pbc0 in value)
            else:
                raise ValueError("`pbc` must have a length of 3 and consist of boolean variables.")
        elif isinstance(value, (bool, np.bool_)):
            value = (bool(value),) * 3
        else:
            raise TypeError("`pbc` must be a list, tuple or a boolean.")
        if any(value) and self.cell is None:
            raise ValueError(
                "`cell` must be set if `pbc` is set to true for one or more direction."
            )
        self._pbc = value

    @property
    def cell(self) -> Union[tuple, None]:
        """Return the cell of the structure."""
        return getattr(self, "_cell", None)

    @cell.setter
    def cell(self, value: Union[tuple, list, np.ndarray]):
        if value is None:
            # ``_pbc`` does not exist yet while ``__init__`` sets the cell.
            if any(getattr(self, "_pbc", ())):
                raise ValueError(
                    "`cell` must be set if `pbc` is set to true for one or more direction."
                )
            # Reset all derived quantities instead of keeping a stale cell.
            self._cell = None
            self._inverse_cell = None
            self._cell_volume = None
            self._cell_lengths = None
            self._cell_angles = None
            return
        self._cell, self._inverse_cell = _structure_validate_cell(value)
        self._cell_volume = abs(np.dot(np.cross(self._cell[0], self._cell[1]), self._cell[2]))
        self._cell_lengths = tuple(float(np.linalg.norm(vec)) for vec in self._cell)
        # Angles alpha, beta, gamma in degrees (b-c, a-c, a-b).
        self._cell_angles = tuple(
            float(calc_angle(self._cell[i1], self._cell[i2]) * 180.0 / np.pi)
            for i1, i2 in [(1, 2), (0, 2), (0, 1)]
        )
        # NOTE: the stored positions are not recalculated when the cell changes.

    @property
    def cell_volume(self) -> Union[float, None]:
        """float: Cell volume."""
        return getattr(self, "_cell_volume", None)

    @property
    def cell_lengths(self) -> Union[tuple, None]:
        """tuple: Cell lengths."""
        return getattr(self, "_cell_lengths", None)

    @property
    def cell_angles(self) -> Union[tuple, None]:
        """tuple: Cell angles in degrees."""
        return getattr(self, "_cell_angles", None)

    @property
    def kinds(self) -> Union[tuple, None]:
        """tuple: Kinds of the structure."""
        return self._kinds

    @kinds.setter
    def kinds(self, value: Union[tuple, list]):
        if value is None:
            value = [None] * len(self.elements)
        if not isinstance(value, (list, tuple)):
            raise TypeError("`kinds` must be a list or tuple.")
        if len(value) != len(self.elements):
            raise ValueError("`kinds` must have the same length as `elements`.")
        self._kind_dict = _create_index_dict(value)
        self._kinds = tuple(value)

    @property
    def site_attributes(self) -> Union[dict, None]:
        """
        dict :
            Dictionary containing the label of a site attribute as key and a tuple/list of values
            having the same length as the ``Structure`` object itself (number of sites) containing
            site specific properties or attributes (e.g. charges, magnetic moments, forces, ...).
        """
        return copy.deepcopy(self._site_attributes)

    @site_attributes.setter
    def site_attributes(self, value: dict):
        if value is None:
            value = {}
        # Replace (instead of merge into) the existing site attributes and roll back if
        # one of the new values fails validation.
        previous = self._site_attributes
        self._site_attributes = {}
        try:
            for key, val in value.items():
                self.set_site_attribute(key, val)
        except Exception:
            self._site_attributes = previous
            raise

    @property
    def function_args(self) -> dict:
        """Return function arguments for stored extras."""
        return copy.deepcopy(self._function_args)

    @property
    def attributes(self) -> dict:
        """Return attributes."""
        return copy.deepcopy(self._attributes)

    @property
    def extras(self) -> dict:
        """
        Return extras.
        """
        return copy.deepcopy(self._extras)

    @property
    def store_calculated_properties(self) -> bool:
        """
        Store calculated properties to reuse them later.
        """
        return self._store_calculated_properties

    @store_calculated_properties.setter
    def store_calculated_properties(self, value: bool):
        if not isinstance(value, bool):
            raise TypeError("`store_calculated_properties` needs to be of type bool.")
        self._store_calculated_properties = value

    def iter_sites(
        self,
        get_kind: bool = False,
        get_cart_pos: bool = False,
        get_scaled_pos: bool = False,
        wrap: bool = False,
        site_attributes: Union[str, list] = None,
    ):
        """
        Iterate through the sites of the structure.

        Parameters
        ----------
        get_kind : bool (optional)
            Include kind in tuple.
        get_cart_pos : bool (optional)
            Include cartesian position in tuple.
        get_scaled_pos : bool (optional)
            Include scaled position in tuple.
        wrap : bool (optional)
            Wrap atomic positions back into the unit cell.
        site_attributes : str or list (optional)
            Include site attributes defined by their label.

        Yields
        ------
        str or tuple
            Either element symbol or tuple containing the element symbol, kind string,
            cartesian position, scaled position or specified site attributes.
        """
        if site_attributes is None:
            site_attributes = []
        elif isinstance(site_attributes, str):
            site_attributes = [site_attributes]
        # Hoisted out of the loop; the property returns a deep copy.
        site_attr_dict = self.site_attributes if site_attributes else {}
        positions = self.positions
        scaled_positions = self.scaled_positions

        for idx, el in enumerate(self.elements):
            output = [el]
            if get_kind:
                output.append(self.kinds[idx])
            pos_cart = positions[idx]
            pos_scaled = None if scaled_positions is None else scaled_positions[idx]
            if (get_cart_pos or get_scaled_pos) and wrap:
                pos_cart, pos_scaled = self._wrap_position(pos_cart, pos_scaled)
            if get_cart_pos:
                output.append(pos_cart)
            if get_scaled_pos:
                output.append(pos_scaled)
            for site_attr in site_attributes:
                output.append(site_attr_dict[site_attr][idx])
            if len(output) == 1:
                yield el
            else:
                yield tuple(output)

    def set_positions(
        self, positions: Union[list, tuple], is_cartesian: bool = True, wrap: bool = False
    ):
        """
        Set positions of atoms.

        Parameters
        ----------
        positions : list or tuple
            Nested list or tuple of the coordinates (n atoms x 3).
        is_cartesian : bool (optional)
            Whether the coordinates are cartesian or scaled.
        wrap : bool (optional)
            Wrap atomic positions into the unit cell (ignored if no cell is set).

        Raises
        ------
        ValueError
            If the number of positions does not match the number of elements.
        """
        if len(self.elements) != len(positions):
            raise ValueError("`elements` and `positions` must have the same length.")
        self._positions, self._scaled_positions = _structure_validate_positions(
            positions, is_cartesian, self.cell, self._inverse_cell, self.pbc
        )
        # Without a cell wrapping is a no-op; skipping it keeps ``scaled_positions`` at
        # ``None`` (instead of a tuple of ``None``) and avoids unpacking an empty structure.
        if wrap and self.cell is not None and len(self._positions) > 0:
            new_positions = list(
                self.iter_sites(get_cart_pos=True, get_scaled_pos=True, wrap=True)
            )
            _, cart_positions, scaled_positions = zip(*new_positions)
            self._positions = tuple(cart_positions)
            self._scaled_positions = tuple(scaled_positions)

    def get_positions(self, cartesian: bool = True, wrap: bool = False):
        """
        Return positions of atoms.

        Parameters
        ----------
        cartesian : bool (optional)
            Get cartesian positions. If set to ``False`` scaled positions are returned.
        wrap : bool (optional)
            Wrap atomic positions into the unit cell.

        Returns
        -------
        tuple
            Positions of all sites.
        """
        return tuple(
            pos
            for _, pos in self.iter_sites(
                get_cart_pos=cartesian, get_scaled_pos=not cartesian, wrap=wrap
            )
        )

    def set_attribute(self, key: str, value):
        """
        Set attribute.

        Parameters
        ----------
        key : str
            Key of the attribute.
        value :
            Value of the attribute.
        """
        self._attributes[key] = value

    def set_site_attribute(self, key: str, values: Union[list, tuple]):
        """
        Set site attribute.

        Parameters
        ----------
        key : str
            Key of the site attribute.
        values :
            Values of the attribute, need to have the same length as the ``Structure`` object
            itself (number of sites).

        Raises
        ------
        TypeError
            If ``values`` is not a list or tuple.
        ValueError
            If ``values`` does not have one entry per site.
        """
        if not isinstance(values, (list, tuple)):
            raise TypeError(f"Value of site property `{key}` must be a list or tuple.")
        if len(values) != len(self.elements):
            raise ValueError(
                f"Value of site property `{key}` must have the same length as `elements`."
            )
        self._site_attributes[key] = tuple(values)

    @classmethod
    def _collect_marked_methods(cls, marker: str) -> list:
        """Return names of members along the MRO of ``cls`` that carry the flag ``marker``."""
        names = []
        for klass in cls.__mro__:
            for name, member in vars(klass).items():
                # Depending on the Python version, ``classmethod`` objects do not expose the
                # attributes of the wrapped function, so ``__func__`` is checked as well.
                func = getattr(member, "__func__", member)
                is_marked = getattr(member, marker, False) or getattr(func, marker, False)
                if is_marked and name not in names:
                    names.append(name)
        return names

    # NOTE(review): decorators were reconstructed; confirm whether ``import_methods`` and
    # ``export_methods`` are accessed as classmethods or as (class) properties by callers.
    @classmethod
    def import_methods(cls) -> list:
        """list: Names of the methods flagged with ``import_method``."""
        return cls._collect_marked_methods("_is_import_method")

    @classmethod
    def export_methods(cls) -> list:
        """list: Names of the methods flagged with ``export_method``."""
        return cls._collect_marked_methods("_is_export_method")

    @classmethod
    @import_method
    def from_file(
        cls,
        file_path: str,
        attributes: dict = None,
        label: str = None,
        backend: str = "ase",
        file_format: str = None,
        backend_kwargs: dict = None,
    ) -> Union["Structure", List["Structure"]]:
        """
        Get structure(s) from file using the ase or the internal backend.

        Parameters
        ----------
        file_path : str
            File path.
        attributes : dict
            Attributes stored within the structure object(s).
        label : str
            Label used internally to store the structure in the object. If the file
            contains several structures, ``_{index}`` is appended to the label.
        backend : str (optional)
            Backend to be used to parse the structure file. Supported options are ``'ase'``
            and ``'internal'``.
        file_format : str or None (optional)
            File format of the backend. For ``'ase'``, please refer to the documentation of the
            package for a complete list. For ``'internal'``, the format translates from
            ``io.{module}.read_structure`` to ``'{module}'`` or from
            ``{module}.read_{specification}_structure`` to ``'module-specification'``. If set to
            ``None`` the corresponding function is searched based on the file name and suffix.
        backend_kwargs : dict (optional)
            Arguments passed to the backend function.

        Returns
        -------
        Structure or list
            A single structure, or a list of structures if the file contains several.

        Raises
        ------
        ValueError
            If ``backend`` is not supported.
        """
        # Copy so that the caller's dictionary is not modified below.
        backend_kwargs = {} if backend_kwargs is None else dict(backend_kwargs)
        if backend == "ase":
            backend_module = _return_ext_interface_modules("ase_atoms")
            backend_kwargs.setdefault("format", file_format)
            structure_dicts = backend_module._load_structure_from_file(file_path, backend_kwargs)
        elif backend == "internal":
            structure_dicts = get_structure_from_file(file_path, file_format, backend_kwargs)
        else:
            raise ValueError(f"Backend '{backend}' is not supported.")

        if isinstance(structure_dicts, dict):
            structure_dicts = [structure_dicts]
        if len(structure_dicts) == 1:
            if label is not None:
                structure_dicts[0]["label"] = label
            return cls(**structure_dicts[0], attributes=attributes)

        structures = []
        for idx, structure_dict in enumerate(structure_dicts):
            if label is not None:
                structure_dict["label"] = label + f"_{idx}"
            # Independent attribute dicts for every structure.
            structures.append(cls(**structure_dict, attributes=copy.deepcopy(attributes)))
        return structures

    @classmethod
    @import_method
    def from_ase_atoms(
        cls, ase_atoms: Atoms, attributes: dict = None, label: str = None
    ) -> "Structure":
        """
        Get structure from ase atoms object.

        Parameters
        ----------
        ase_atoms : ase.Atoms
            ase Atoms object.
        attributes : dict
            Attributes stored within the structure object.
        label : str
            Label used internally to store the structure in the object.

        Returns
        -------
        Structure
            Structure.
        """
        backend_module = _return_ext_interface_modules("ase_atoms")
        # NOTE(review): backend helper name reconstructed; confirm it exists in the
        # ``ase_atoms`` interface module.
        return cls(
            **backend_module._extract_structure_from_atoms(ase_atoms),
            attributes=attributes,
            label=label,
        )

    @classmethod
    @import_method
    def from_pymatgen_structure(
        cls,
        pymatgen_structure: Union["pymatgen.core.Molecule", "pymatgen.core.Structure"],
        attributes: dict = None,
        label: str = None,
    ) -> "Structure":
        """
        Get structure from pymatgen structure or molecule object.

        Parameters
        ----------
        pymatgen_structure : pymatgen.core.Structure or pymatgen.core.Molecule
            pymatgen structure or molecule object.
        attributes : dict
            Additional information about the structure.
        label : str
            Label used internally to store the structure in the object.

        Returns
        -------
        Structure
            Structure.
        """
        backend_module = _return_ext_interface_modules("pymatgen")
        # NOTE(review): backend helper name reconstructed; confirm it exists in the
        # ``pymatgen`` interface module.
        return cls(
            **backend_module._extract_structure_from_pymatgen(pymatgen_structure),
            attributes=attributes,
            label=label,
        )

    @classmethod
    @import_method
    def from_aiida_structuredata(
        cls,
        structure_node: Union[int, str, "aiida.orm.StructureData"],
        use_uuid: bool = False,
        label: str = None,
    ) -> "Structure":
        """
        Get structure from AiiDA structure node.

        Parameters
        ----------
        label : str
            Label used internally to store the structure in the object.
        structure_node : int, str or aiida.orm.nodes.data.structure.StructureData
            Primary key, UUID or AiiDA structure node.
        use_uuid : bool (optional)
            Whether to use the uuid (str) to represent AiiDA nodes instead of the primary key.

        Returns
        -------
        Structure
            Structure.
        """
        backend_module = _return_ext_interface_modules("aiida")
        structure_dict = backend_module._extract_dict_from_aiida_structure_node(
            structure_node, use_uuid
        )
        if label is not None:
            structure_dict["label"] = label
        return cls(**structure_dict)

    @export_method
    def to_dict(
        self,
        cartesian: bool = True,
        wrap: bool = False,
        include_calculated_properties: bool = False,
    ) -> dict:
        """
        Export structure to python dictionary.

        Parameters
        ----------
        cartesian : bool (optional)
            Whether cartesian or scaled coordinates are returned.
        wrap : bool (optional)
            Whether the coordinates are wrapped back into the unit cell.
        include_calculated_properties : bool (optional)
            Include ``extras`` and ``function_args`` in the dictionary as well.

        Returns
        -------
        dict
            Dictionary representing the structure. The ``Structure`` object can be retrieved via
            ``Structure(**dict)``.
        """
        # TODO: add test.
        calc_prop_keys = ["extras", "function_args"]
        strct_dict = {}
        for key in self.keys():
            # Positions are added separately below to honour ``cartesian`` and ``wrap``.
            if (not include_calculated_properties and key in calc_prop_keys) or key == "positions":
                continue
            strct_dict[key] = getattr(self, key)
        strct_dict["positions"] = self.get_positions(cartesian=cartesian, wrap=wrap)
        if not cartesian:
            strct_dict["is_cartesian"] = False
        return strct_dict

    @export_method
    def to_file(self, file_path: str) -> None:
        """
        Export structure to file using the ase interface or certain file formats for Zeo++.

        Parameters
        ----------
        file_path : str
            File path. The suffixes ``.cssr``, ``.v1`` and ``.cuc`` are written via the Zeo++
            writer, all other files via the ase interface.
        """
        if file_path.endswith((".cssr", ".v1", ".cuc")):
            zeo.write_to_file(self, file_path)
        else:
            backend_module = _return_ext_interface_modules("ase_atoms")
            backend_module._write_structure_to_file(self, file_path)

    @export_method
    def to_ase_atoms(self) -> Atoms:
        """
        Create ase Atoms object.

        Returns
        -------
        ase.Atoms
            ase Atoms object of the structure.
        """
        backend_module = _return_ext_interface_modules("ase_atoms")
        return backend_module._create_atoms_from_structure(self)

    @export_method
    def to_pymatgen_structure(self) -> Union["pymatgen.core.Molecule", "pymatgen.core.Structure"]:
        """
        Create pymatgen Structure (if cell is not `None`) or Molecule (if cell is `None`) object.

        Returns
        -------
        pymatgen.core.Structure or pymatgen.core.Molecule
            pymatgen structure or molecule object.
        """
        backend_module = _return_ext_interface_modules("pymatgen")
        return backend_module._create_pymatgen_obj(self)

    @export_method
    def to_aiida_structuredata(self, label=None):
        """
        Create AiiDA structuredata.

        Parameters
        ----------
        label : str or None (optional)
            If given, set as label of the created node.

        Returns
        -------
        aiida.orm.StructureData
            AiiDA structure node.
        """
        backend_module = _return_ext_interface_modules("aiida")
        structure_node = backend_module._create_structure_node(self)
        # ``label`` was previously accepted but ignored.
        if label is not None:
            structure_node.label = label
        return structure_node

    def _wrap_position(self, cart_position, scaled_position):
        """Wrap position back into the unit cell along the periodic directions."""
        if self.cell is None:
            return cart_position, scaled_position
        if scaled_position is None:
            scaled_position = np.transpose(np.array(self._inverse_cell)).dot(
                np.array(cart_position, dtype=float)
            )
        else:
            # Float dtype: writing wrapped values into an int array would truncate them.
            scaled_position = np.array(scaled_position, dtype=float)
        for direction, periodic in enumerate(self.pbc):
            if periodic:
                # Rounding first maps e.g. 0.9999999999999999 to 1.0 and thus to 0.0.
                scaled_position[direction] = round(scaled_position[direction], 15) % 1
        cart_position = np.transpose(np.array(self.cell)).dot(scaled_position)
        return tuple(float(p) for p in cart_position), tuple(float(p) for p in scaled_position)

    def _perform_strct_analysis(self, method, kwargs):
        """Run an analysis ``method`` through the calculated-properties cache."""
        return _check_calculated_properties(self, method, kwargs)

    def _perform_strct_manipulation(self, method, kwargs):
        """Run ``method`` and return the new structure, or ``self`` if nothing was returned."""
        new_strct = method(structure=self, **kwargs)
        if isinstance(new_strct, Structure):
            return new_strct
        if isinstance(new_strct, dict):
            return Structure(**new_strct)
        return self