"""Methods to fit the cell parameters of crystalline materials."""
# Standard library imports
import itertools
# Third party library imports
import numpy as np
from scipy.spatial.transform import Rotation
# Internal library imports
from aim2dat.strct import StructureOperations, StructureCollection, Structure
from aim2dat.utils.space_groups import get_lattice_type
from aim2dat.utils.maths import calc_angle
class CellGridSearch:
    """
    Fit the cell parameters of an initial structure via a brute-force grid search.

    Every parameter set of the search space scales the cell lengths and angles of the initial
    structure. The resulting trial structures are scored either against a target structure
    (F-fingerprint comparison) or against a model prediction, and the trial with the lowest
    score wins. The space group of each trial has to match the one of the initial structure.

    Attributes
    ----------
    length_scaling_factors : list
        Scaling factors for the cell lengths.
    angle_scaling_factors : list
        Scaling factors for the cell angles.
    symprec : float
        Tolerance for spglib and length and angle comparison.
    angle_tolerance : float
        Tolerance parameter for spglib.
    hall_number : int (optional)
        The argument to constrain the space-group-type search only for the Hall symbol
        corresponding to it.
    ffprint_r_max : float
        Cut-off value for the maximum distance between two atoms.
    ffprint_delta_bin : float (optional)
        Bin size to discretize the function.
    ffprint_sigma : float (optional)
        Smearing parameter for the Gaussian function.
    ffprint_use_weights : bool (optional)
        Whether to use importance weights for the element pairs.
    ffprint_distinguish_kinds: bool (optional)
        Whether different kinds should be distinguished e.g. Ni0 and Ni1 would be considered as
        different elements if ``True``.
    target_value : float (optional)
        Target value used to calculate score if a model is set via the ``set_model`` function.
    """

    def __init__(
        self,
        length_scaling_factors=(0.8, 1.0, 1.2),
        angle_scaling_factors=(0.9, 1.0, 1.1),
        symprec=0.005,
        angle_tolerance=-1.0,
        hall_number=0,
        ffprint_r_max=10.0,
        ffprint_delta_bin=0.005,
        ffprint_sigma=0.05,
        ffprint_use_weights=True,
        ffprint_distinguish_kinds=False,
        target_value=0.0,
    ):
        """Construct object."""
        self._strct_ops = StructureOperations(structures=StructureCollection())
        self._transformer = None
        self._model = None
        self._model_fct = None
        self._fit_info = None
        # Store copies so that instances never share (or mutate) the default sequences or a
        # sequence owned by the caller.
        self.length_scaling_factors = list(length_scaling_factors)
        self.angle_scaling_factors = list(angle_scaling_factors)
        self.symprec = symprec
        self.angle_tolerance = angle_tolerance
        self.hall_number = hall_number
        self.ffprint_r_max = ffprint_r_max
        self.ffprint_delta_bin = ffprint_delta_bin
        self.ffprint_sigma = ffprint_sigma
        self.ffprint_use_weights = ffprint_use_weights
        self.ffprint_distinguish_kinds = ffprint_distinguish_kinds
        self.target_value = target_value

    def set_initial_structure(self, structure):
        """
        Set initial crystal structure.

        A copy of the structure is stored under the label ``"initial"``.

        Parameters
        ----------
        structure : aim2dat.strct.Structure
            Initial structure.
        """
        self._strct_ops.structures["initial"] = structure.copy()

    def set_model(self, model, function_name="predict", single=False, transformer=None):
        """
        Set scikit-learn model to predict the target value.

        Parameters
        ----------
        model :
            Object that takes structures or features as input to predict a target value.
        function_name : str (optional)
            Function name to retrieve the property prediction.
        single : bool (optional)
            Whether a single structure/features or a list of structures/features is predicted at
            once.
        transformer : aim2dat.ml.transformers (optional)
            Structure transformer. A previously set transformer is kept if ``None`` is given.
        """
        if transformer is not None:
            self._transformer = transformer
        self._model_fct = (function_name, single)
        self._model = model

    def set_target_structure(self, structure):
        """
        Set target crystal structure.

        A copy of the structure is stored under the label ``"target"``.

        Parameters
        ----------
        structure : aim2dat.strct.Structure
            Target structure.
        """
        self._strct_ops.structures["target"] = structure.copy()

    def get_optimized_structure(self):
        """
        Get optimized structure with the lowest score.

        If no fit has been performed yet, ``fit`` is called with the default search space.

        Returns
        -------
        : aim2dat.strct.Structure
            Optimized structure.
        """
        if self._fit_info is None:
            self.fit()
        return self._strct_ops.structures[self._fit_info[0]].copy()

    def return_search_space(self):
        """
        Return list of parameter sets that are varied to fit the initial to the final structure.

        Each parameter set consists of six scaling factors: three for the cell lengths followed
        by three for the cell angles. Which factors are varied (and which are coupled) depends on
        the lattice type derived from the space group of the initial structure.

        Returns
        -------
        list
            List of parameter sets that are varied.

        Raises
        ------
        ValueError
            If the cell of the initial structure does not show the lengths/angles expected for
            the detected lattice type.
        """
        search_space = []
        space_group = self._strct_ops["initial"].determine_space_group(
            symprec=self.symprec,
            angle_tolerance=self.angle_tolerance,
            hall_number=self.hall_number,
        )
        lattice_type = get_lattice_type(space_group["space_group"]["number"])
        print(
            "Space group of initial crystal: ",
            space_group["space_group"]["number"],
            "(" + lattice_type + ")",
        )
        cell = self._strct_ops.structures["initial"]["cell"]
        length_sfs = self.length_scaling_factors
        angle_sfs = self.angle_scaling_factors
        fixed_angles = [1.0, 1.0, 1.0]

        if lattice_type == "triclinic":
            # All six parameters are independent; the angles are varied with the angle
            # scaling factors.
            for l_comb in itertools.product(length_sfs, repeat=3):
                for a_comb in itertools.product(angle_sfs, repeat=3):
                    search_space.append(list(l_comb + a_comb))
        elif lattice_type == "monoclinic":
            comb = self._check_length_angles(cell, self.symprec, same_length=False, angles=[90.0])
            if len(comb) != 2:
                raise ValueError("Could not detect monoclinic lattice type.")
            # The vector shared by both 90-degree pairs is the unique axis; the angle with the
            # same index (enclosed by the two other vectors) is the only free angle.
            angle_idx = (set(comb[0]) & set(comb[1])).pop()
            for l_comb in itertools.product(length_sfs, repeat=3):
                for angle_sf in angle_sfs:
                    param = list(l_comb) + fixed_angles
                    param[3 + angle_idx] = angle_sf
                    search_space.append(param)
        elif lattice_type == "orthorhombic":
            comb = self._check_length_angles(cell, self.symprec, same_length=False, angles=[90.0])
            if len(comb) != 3:
                raise ValueError("Could not detect orthorhombic lattice type.")
            for l_comb in itertools.product(length_sfs, repeat=3):
                search_space.append(list(l_comb) + fixed_angles)
        elif lattice_type == "tetragonal":
            comb = self._check_length_angles(cell, self.symprec, same_length=True, angles=[90.0])
            if len(comb) != 1:
                raise ValueError("Could not detect tetragonal lattice type.")
            # The two equally long vectors share one factor, the third one gets its own.
            for sf_pair, sf_other in itertools.product(length_sfs, repeat=2):
                lengths = [sf_pair if idx in comb[0] else sf_other for idx in range(3)]
                search_space.append(lengths + fixed_angles)
        elif lattice_type in ("trigonal", "hexagonal"):
            comb = self._check_length_angles(
                cell, self.symprec, same_length=True, angles=[60.0, 120.0]
            )
            if len(comb) != 1:
                raise ValueError("Could not detect trigonal or hexagonal lattice type.")
            for sf_pair, sf_other in itertools.product(length_sfs, repeat=2):
                lengths = [sf_pair if idx in comb[0] else sf_other for idx in range(3)]
                search_space.append(lengths + fixed_angles)
        elif lattice_type == "cubic":
            comb = self._check_length_angles(cell, self.symprec, same_length=True, angles=[90.0])
            if len(comb) != 3:
                raise ValueError("Could not detect cubic lattice type.")
            for sf_abc in length_sfs:
                search_space.append([sf_abc, sf_abc, sf_abc] + fixed_angles)
        return search_space

    def fit(self, search_space=None):
        """
        Fit the initial to the final structure by varying the cell parameters.

        For each parameter set a trial structure with scaled cell lengths and angles (and
        unchanged scaled atomic positions) is stored under the label of its index in the search
        space. All trials are scored and the one with the lowest score is remembered.

        Parameters
        ----------
        search_space : list or None
            Defines the cell parameter variations. If set to ``None`` the parameters are obtained
            via the ``return_search_space``-function.

        Returns
        -------
        min_score : float
            Score of the best match.
        min_params : list
            Parameters that give the best match.

        Raises
        ------
        ValueError
            If the search space is empty, if a parameter set does not yield a valid cell, if the
            space group of a trial structure differs from the initial one or if neither a
            target structure nor a model is set.
        """
        if search_space is None:
            search_space = self.return_search_space()
        if len(search_space) == 0:
            raise ValueError("The search space is empty.")
        initial_sg = self._strct_ops["initial"].determine_space_group(
            symprec=self.symprec,
            angle_tolerance=self.angle_tolerance,
            hall_number=self.hall_number,
        )["space_group"]["number"]
        initial_strct = self._strct_ops.structures["initial"]

        labels = []
        for idx, params in enumerate(search_space):
            label = str(idx)
            self._strct_ops.structures[label] = Structure(
                elements=initial_strct.elements,
                positions=initial_strct.scaled_positions,
                pbc=initial_strct.pbc,
                cell=self._scale_cell(initial_strct["cell"], params),
                is_cartesian=False,
            )
            trial_sg = self._strct_ops[label].determine_space_group(
                symprec=self.symprec,
                angle_tolerance=self.angle_tolerance,
                hall_number=self.hall_number,
            )["space_group"]["number"]
            if initial_sg != trial_sg:
                raise ValueError("Space groups don't match!")
            labels.append(label)

        scores = self._calculate_scores(labels)
        # ``min`` returns the first minimum, i.e. ties are resolved in favour of the earlier
        # parameter set.
        best_idx = min(range(len(scores)), key=scores.__getitem__)
        self._fit_info = (labels[best_idx], scores[best_idx], search_space[best_idx])
        return scores[best_idx], search_space[best_idx]

    def return_initial_score(self):
        """
        Return score of the initial structure.

        Returns
        -------
        float
            Score of the initial structure.
        """
        return self._calculate_scores(["initial"])[0]

    def _calculate_scores(self, labels):
        """
        Score the structures with the given labels.

        A target structure takes precedence over a model.

        Raises
        ------
        ValueError
            If neither a target structure nor a model has been set.
        """
        if "target" in self._strct_ops.structures.labels:
            return self._compare_with_target_structure_ffprint(labels)
        if self._model is not None:
            return self._get_model_predictions(labels)
        raise ValueError("Either a target structure or a model needs to be set.")

    def _compare_with_target_structure_ffprint(self, labels):
        """Return the F-fingerprint comparison of each labelled structure with the target."""
        comparisons = self._strct_ops.compare_structures_via_ffingerprint(
            labels,
            ["target"] * len(labels),
            r_max=self.ffprint_r_max,
            delta_bin=self.ffprint_delta_bin,
            sigma=self.ffprint_sigma,
            use_weights=self.ffprint_use_weights,
            distinguish_kinds=self.ffprint_distinguish_kinds,
        )
        return [comparisons[(label, "target")] for label in labels]

    def _get_model_predictions(self, labels):
        """
        Return the absolute deviation of the model prediction from ``target_value``.

        The returned list is aligned with ``labels``.
        """
        predict_fct = getattr(self._model, self._model_fct[0])
        # Look the structures up by label so that the scores follow the order of ``labels``
        # rather than the iteration order of the collection.
        features = [self._strct_ops.structures[label] for label in labels]
        if self._transformer is not None:
            features = self._transformer.transform(features)
        if self._model_fct[1]:
            return [abs(predict_fct(feat) - self.target_value) for feat in features]
        return np.absolute(predict_fct(features) - self.target_value).tolist()

    @staticmethod
    def _scale_cell(cell, params):
        """
        Return a cell with the lengths and angles of ``cell`` scaled by ``params``.

        The first three entries of ``params`` scale the lengths of the three cell vectors. The
        entries four to six scale the cell angles, where the angle with index ``k`` is enclosed
        by the two cell vectors other than ``k``. Missing trailing entries default to 1.0. The
        new cell is built from the scaled lengths and angles and expressed in the frame of the
        original cell: the first vector keeps its direction, the second one stays in the plane
        of the original first two vectors and the handedness is preserved.

        Raises
        ------
        ValueError
            If the original cell is degenerate or the scaled angles do not define a cell.
        """
        cell = np.array(cell, dtype=float)
        factors = np.ones(6)
        given = np.asarray(params, dtype=float)
        factors[: given.size] = given

        normal = np.cross(cell[0], cell[1])
        volume = np.dot(normal, cell[2])
        if np.isclose(volume, 0.0):
            raise ValueError("The cell vectors are linearly dependent.")

        lengths = np.linalg.norm(cell, axis=1)
        cos_angles = np.array(
            [
                np.dot(cell[idx_a], cell[idx_b]) / (lengths[idx_a] * lengths[idx_b])
                for idx_a, idx_b in ((1, 2), (0, 2), (0, 1))
            ]
        )
        # Clip to guard against rounding errors slightly outside of [-1, 1].
        angles = np.arccos(np.clip(cos_angles, -1.0, 1.0)) * factors[3:]
        if np.any(angles <= 0.0) or np.any(angles >= np.pi):
            raise ValueError(f"The parameters {list(params)} do not yield a valid cell.")
        alpha, beta, gamma = angles

        # Unit vectors in the standard setting: first vector along x, second one in the
        # xy-plane.
        c_x = np.cos(beta)
        c_y = (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma)
        c_z_squared = 1.0 - c_x**2 - c_y**2
        if c_z_squared <= 0.0:
            raise ValueError(f"The parameters {list(params)} do not yield a valid cell.")
        unit_vectors = np.array(
            [
                [1.0, 0.0, 0.0],
                [np.cos(gamma), np.sin(gamma), 0.0],
                # The sign keeps the handedness of the original cell.
                [c_x, c_y, np.sign(volume) * np.sqrt(c_z_squared)],
            ]
        )
        coords = unit_vectors * (lengths * factors[:3])[:, np.newaxis]

        # Orthonormal frame attached to the original cell.
        e_1 = cell[0] / lengths[0]
        e_3 = normal / np.linalg.norm(normal)
        e_2 = np.cross(e_3, e_1)
        return np.dot(coords, np.array([e_1, e_2, e_3]))

    @staticmethod
    def _check_length_angles(cell, tol, same_length=False, angles=None):
        """
        Return the pairs of cell vectors that fulfil the length and angle criteria.

        Parameters
        ----------
        cell : list or np.array
            Cell vectors.
        tol : float
            Tolerance used for both, the length difference and the angle difference (the
            latter is compared in degrees).
        same_length : bool (optional)
            Whether both vectors of a pair need to have the same length.
        angles : list or None (optional)
            Reference angles in degrees; the angle of a pair needs to match at least one of
            them. No angle check is performed if set to ``None``.

        Returns
        -------
        list
            List of index tuples out of ``(0, 1)``, ``(0, 2)`` and ``(1, 2)``.
        """
        cell = np.array(cell)
        found_combinations = []
        for idx_a, idx_b in [(0, 1), (0, 2), (1, 2)]:
            if same_length:
                length_diff = np.linalg.norm(cell[idx_a]) - np.linalg.norm(cell[idx_b])
                if abs(length_diff) > tol:
                    continue
            if angles is not None:
                angle = calc_angle(cell[idx_a], cell[idx_b]) * 180.0 / np.pi
                if all(abs(angle - ref) > tol for ref in angles):
                    continue
            found_combinations.append((idx_a, idx_b))
        return found_combinations