"""Methods to fit the cell parameters of crystalline materials."""
# Standard library imports
import itertools
# Third party library imports
import numpy as np
from scipy.spatial.transform import Rotation
# Internal library imports
from aim2dat.strct import StructureOperations, StructureCollection, Structure
from aim2dat.utils.space_groups import get_lattice_type
from aim2dat.utils.maths import calc_angle
class CellGridSearch:
    """
    Class to fit the cell parameters of an initial structure to a final structure using a
    brute-force grid search approach. The space group is maintained during the fitting process.

    Each trial structure is scored either by comparing it with a target structure via the
    ffingerprint (if a target structure is set, this takes precedence) or by the absolute
    deviation of a model prediction from ``target_value`` (see ``set_model``). The trial
    structure with the lowest score is selected.

    Attributes
    ----------
    length_scaling_factors : list
        Scaling factors for the cell lengths.
    angle_scaling_factors : list
        Scaling factors for the cell angles.
    symprec : float
        Tolerance for spglib and length and angle comparison.
    angle_tolerance : float
        Tolerance parameter for spglib.
    hall_number : int (optional)
        The argument to constrain the space-group-type search only for the Hall symbol
        corresponding to it.
    ffprint_r_max : float
        Cut-off value for the maximum distance between two atoms.
    ffprint_delta_bin : float (optional)
        Bin size to discretize the function.
    ffprint_sigma : float (optional)
        Smearing parameter for the Gaussian function.
    ffprint_use_weights : bool (optional)
        Whether to use importance weights for the element pairs.
    ffprint_distinguish_kinds: bool (optional)
        Whether different kinds should be distinguished e.g. Ni0 and Ni1 would be considered as
        different elements if ``True``.
    target_value : float (optional)
        Target value used to calculate score if a model is set via the ``set_model`` function.
    """

    def __init__(
        self,
        length_scaling_factors=(0.8, 1.0, 1.2),
        angle_scaling_factors=(0.9, 1.0, 1.1),
        symprec=0.005,
        angle_tolerance=-1.0,
        hall_number=0,
        ffprint_r_max=10.0,
        ffprint_delta_bin=0.005,
        ffprint_sigma=0.05,
        ffprint_use_weights=True,
        ffprint_distinguish_kinds=False,
        target_value=0.0,
    ):
        """Construct object."""
        self._strct_ops = StructureOperations(structures=StructureCollection())
        self._transformer = None
        self._model = None
        self._model_fct = None
        self._fit_info = None
        # Copy the scaling factors so instances never share (and accidentally mutate) one list.
        self.length_scaling_factors = list(length_scaling_factors)
        self.angle_scaling_factors = list(angle_scaling_factors)
        self.symprec = symprec
        self.angle_tolerance = angle_tolerance
        self.hall_number = hall_number
        self.ffprint_r_max = ffprint_r_max
        self.ffprint_delta_bin = ffprint_delta_bin
        self.ffprint_sigma = ffprint_sigma
        self.ffprint_use_weights = ffprint_use_weights
        self.ffprint_distinguish_kinds = ffprint_distinguish_kinds
        self.target_value = target_value

    def set_initial_structure(self, structure):
        """
        Set initial crystal structure.

        A copy of the structure is stored; any previous fit result is discarded.

        Parameters
        ----------
        structure : aim2dat.strct.Structure
            Initial structure.
        """
        self._strct_ops.structures["initial"] = structure.copy()
        self._fit_info = None

    def set_model(self, model, function_name="predict", single=False, transformer=None):
        """
        Set scikit-learn model to predict the target value.

        Any previous fit result is discarded since the scores depend on the model.

        Parameters
        ----------
        model :
            Object that takes structures or features as input to predicts a target value.
        function_name : str (optional)
            Function name to retrieve the property prediction.
        single : bool (optional)
            Whether a single structure/features or a list of structures/features is predicted at
            once.
        transformer : aim2dat.ml.transformers (optional)
            Structure transformer. If ``None``, a previously set transformer is kept.
        """
        if transformer is not None:
            self._transformer = transformer
        self._model_fct = (function_name, single)
        self._model = model
        self._fit_info = None

    def set_target_structure(self, structure):
        """
        Set target crystal structure.

        A copy of the structure is stored; any previous fit result is discarded.

        Parameters
        ----------
        structure : aim2dat.strct.Structure
            Target structure.
        """
        self._strct_ops.structures["target"] = structure.copy()
        self._fit_info = None

    def get_optimized_structure(self):
        """
        Get optimized structure with the lowest score.

        Runs ``fit`` with the default search space if no fit result is available yet.

        Returns
        -------
        : aim2dat.strct.Structure
            Optimized structure.
        """
        if self._fit_info is None:
            self.fit()
        return self._strct_ops.structures[self._fit_info[0]].copy()

    def return_search_space(self):
        """
        Return list of parameter sets that are varied to fit the initial to the final structure.

        Each parameter set holds six scaling factors: three for the lengths of the cell vectors
        followed by three for the cell angles, where entry ``3 + k`` scales the angle enclosed by
        the two cell vectors other than vector ``k``. Only the parameters that are free for the
        lattice type of the initial structure are varied; all others are kept at ``1.0``.

        Returns
        -------
        list
            List of parameter sets that are varied.

        Raises
        ------
        ValueError
            If the cell does not match the detected lattice type or the lattice type is not
            supported.
        """
        sg_number = self._determine_space_group_number("initial")
        lattice_type = get_lattice_type(sg_number)
        print(
            "Space group of initial crystal: ",
            sg_number,
            "(" + lattice_type + ")",
        )
        cell = self._strct_ops.structures["initial"]["cell"]
        length_sfs = self.length_scaling_factors
        angle_sfs = self.angle_scaling_factors
        search_space = []
        if lattice_type == "triclinic":
            # All three lengths and all three angles are independent of each other.
            for l_comb in itertools.product(length_sfs, repeat=3):
                for a_comb in itertools.product(angle_sfs, repeat=3):
                    search_space.append(list(l_comb) + list(a_comb))
        elif lattice_type == "monoclinic":
            comb = self._check_length_angles(cell, self.symprec, same_length=False, angles=[90.0])
            if len(comb) != 2:
                raise ValueError("Could not detect monoclinic lattice type.")
            # Both right-angled vector pairs contain the unique axis; the free angle is the one
            # enclosed by the two other vectors and is stored at index ``3 + unique_idx``.
            unique_idx = next(idx for idx in range(3) if idx in comb[0] and idx in comb[1])
            for l_comb in itertools.product(length_sfs, repeat=3):
                for a_sf in angle_sfs:
                    params = list(l_comb) + [1.0, 1.0, 1.0]
                    params[3 + unique_idx] = a_sf
                    search_space.append(params)
        elif lattice_type == "orthorhombic":
            comb = self._check_length_angles(cell, self.symprec, same_length=False, angles=[90.0])
            if len(comb) != 3:
                raise ValueError("Could not detect orthorhombic lattice type.")
            for l_comb in itertools.product(length_sfs, repeat=3):
                search_space.append(list(l_comb) + [1.0, 1.0, 1.0])
        elif lattice_type == "tetragonal":
            comb = self._check_length_angles(cell, self.symprec, same_length=True, angles=[90.0])
            if len(comb) != 1:
                raise ValueError("Could not detect tetragonal lattice type.")
            # comb[0] holds the two equal-length vectors; the remaining one is scaled separately.
            for sf_equal, sf_other in itertools.product(length_sfs, repeat=2):
                lengths = [sf_equal if idx in comb[0] else sf_other for idx in range(3)]
                search_space.append(lengths + [1.0, 1.0, 1.0])
        elif lattice_type in ("trigonal", "hexagonal"):
            comb = self._check_length_angles(
                cell, self.symprec, same_length=True, angles=[60.0, 120.0]
            )
            if len(comb) != 1:
                raise ValueError("Could not detect trigonal or hexagonal lattice type.")
            for sf_ab, sf_c in itertools.product(length_sfs, repeat=2):
                lengths = [sf_ab if idx in comb[0] else sf_c for idx in range(3)]
                search_space.append(lengths + [1.0, 1.0, 1.0])
        elif lattice_type == "cubic":
            comb = self._check_length_angles(cell, self.symprec, same_length=True, angles=[90.0])
            if len(comb) != 3:
                raise ValueError("Could not detect cubic lattice type.")
            for sf_abc in length_sfs:
                search_space.append([sf_abc, sf_abc, sf_abc, 1.0, 1.0, 1.0])
        else:
            raise ValueError(f"Lattice type '{lattice_type}' is not supported.")
        return search_space

    def fit(self, search_space=None):
        """
        Fit the initial to the final structure by varying the cell parameters.

        For every parameter set a trial structure with the fractional coordinates of the initial
        structure and a scaled cell is stored under the label ``str(index)``. The trial structure
        with the lowest score is selected (the first one wins ties).

        Parameters
        ----------
        search_space : list or None
            Defines the cell parameter variations (three length scaling factors followed by
            three angle scaling factors per entry). If set to ``None`` the parameters are
            obtained via the ``return_search_space``-function.

        Returns
        -------
        min_score : float
            Score of the best match (lowest score).
        min_params : list
            Parameters that give the best match.

        Raises
        ------
        ValueError
            If the search space is empty, if a trial structure changes the space group, or if
            neither a target structure nor a model is set.
        """
        if search_space is None:
            search_space = self.return_search_space()
        if len(search_space) == 0:
            raise ValueError("The search space is empty.")
        initial_sg = self._determine_space_group_number("initial")
        initial_strct = self._strct_ops.structures["initial"]
        labels = []
        for trial_idx, params in enumerate(search_space):
            label = str(trial_idx)
            self._strct_ops.structures[label] = Structure(
                elements=initial_strct.elements,
                positions=initial_strct.scaled_positions,
                pbc=initial_strct.pbc,
                cell=self._apply_cell_scaling(initial_strct["cell"], params),
                is_cartesian=False,
            )
            if initial_sg != self._determine_space_group_number(label):
                raise ValueError("Space groups don't match!")
            labels.append(label)
        scores = self._calculate_scores(labels)
        # ``min`` keeps the first of several equal scores, i.e. earlier parameter sets win ties.
        best_idx = min(range(len(labels)), key=lambda idx: scores[idx])
        self._fit_info = (labels[best_idx], scores[best_idx], search_space[best_idx])
        return scores[best_idx], search_space[best_idx]

    def return_initial_score(self):
        """
        Return score of the initial structure.

        Returns
        -------
        float
            Score of the initial structure.
        """
        return self._calculate_scores(["initial"])[0]

    def _determine_space_group_number(self, label):
        """Return the space group number of the structure stored under ``label``."""
        space_group = self._strct_ops[label].determine_space_group(
            symprec=self.symprec,
            angle_tolerance=self.angle_tolerance,
            hall_number=self.hall_number,
        )
        return space_group["space_group"]["number"]

    @staticmethod
    def _apply_cell_scaling(cell, params):
        """
        Return a scaled float copy of ``cell``; the input is not modified.

        ``params[:3]`` scale the lengths of the three cell vectors. ``params[3 + k]`` scales the
        angle enclosed by the two cell vectors other than vector ``k``. The angle is changed by
        rotating the second of these two vectors about cell vector ``k``. This keeps all lengths
        and the other two angles unchanged; the angle change is exact when vector ``k`` is
        perpendicular to the other two and approximate otherwise.
        """
        new_cell = np.array(cell, dtype=float)
        for vec_idx, scaling_factor in enumerate(params[:3]):
            new_cell[vec_idx] *= scaling_factor
        for axis_idx, scaling_factor in enumerate(params[3:]):
            if scaling_factor == 1.0:
                continue
            fixed_idx, moved_idx = [idx for idx in range(3) if idx != axis_idx]
            axis = new_cell[axis_idx] / np.linalg.norm(new_cell[axis_idx])
            normal = np.cross(new_cell[fixed_idx], new_cell[moved_idx])
            angle = np.arctan2(
                np.linalg.norm(normal), np.dot(new_cell[fixed_idx], new_cell[moved_idx])
            )
            delta = angle * (scaling_factor - 1.0)
            # A positive right-handed rotation about ``axis`` opens the angle only if
            # fixed x moved points along ``axis``; otherwise the direction is flipped.
            if np.dot(normal, axis) < 0.0:
                delta = -delta
            rotation = Rotation.from_rotvec(delta * axis)
            new_cell[moved_idx] = rotation.apply(new_cell[moved_idx])
        return new_cell

    def _calculate_scores(self, labels):
        """
        Return one score per label, in the order of ``labels``.

        The target-structure comparison takes precedence over the model prediction.

        Raises
        ------
        ValueError
            If neither a target structure nor a model is set.
        """
        if "target" in self._strct_ops.structures.labels:
            return self._compare_with_target_structure_ffprint(labels)
        if self._model is not None:
            return self._get_model_predictions(labels)
        raise ValueError(
            "Either a target structure or a model needs to be set to calculate scores."
        )

    def _compare_with_target_structure_ffprint(self, labels):
        """Return the ffingerprint comparison of each labelled structure with the target."""
        comparisons = self._strct_ops.compare_structures_via_ffingerprint(
            labels,
            ["target"] * len(labels),
            r_max=self.ffprint_r_max,
            delta_bin=self.ffprint_delta_bin,
            sigma=self.ffprint_sigma,
            use_weights=self.ffprint_use_weights,
            distinguish_kinds=self.ffprint_distinguish_kinds,
        )
        return [comparisons[(label, "target")] for label in labels]

    def _get_model_predictions(self, labels):
        """Return ``abs(prediction - target_value)`` for each labelled structure."""
        predict_fct = getattr(self._model, self._model_fct[0])
        # Look structures up by label so the predictions line up with ``labels``.
        features = [self._strct_ops.structures[label] for label in labels]
        if self._transformer is not None:
            features = self._transformer.transform(features)
        if self._model_fct[1]:
            return [abs(predict_fct(feat) - self.target_value) for feat in features]
        return np.absolute(predict_fct(features) - self.target_value).tolist()

    @staticmethod
    def _check_length_angles(cell, tol, same_length=False, angles=None):
        """
        Return the index pairs of cell vectors that fulfil the given conditions.

        A pair ``(i, j)`` is returned if (when ``same_length`` is set) the lengths of both
        vectors agree within ``tol``, and (when ``angles`` is given) the angle between them in
        degrees matches one of ``angles`` within ``tol``.
        """
        cell = np.array(cell)
        found_combinations = []
        for comb in [(0, 1), (0, 2), (1, 2)]:
            comb_found = True
            if (
                same_length
                and abs(np.linalg.norm(cell[comb[0]]) - np.linalg.norm(cell[comb[1]])) > tol
            ):
                comb_found = False
            if angles is not None:
                angle = calc_angle(cell[comb[0]], cell[comb[1]]) * 180.0 / np.pi
                if all(abs(angle - ref) > tol for ref in angles):
                    comb_found = False
            if comb_found:
                found_combinations.append(comb)
        return found_combinations