Source code for aim2dat.io.cif

"""Read and write cif files."""

# Standard library imports
from warnings import warn

# Third party library imports
import numpy as np
import re
from scipy.spatial.distance import cdist

# Internal library imports
from aim2dat.io.utils import read_structure
from aim2dat.io.base_parser import FLOAT
from aim2dat.strct.strct_misc import _get_cell_from_lattice_p
from aim2dat.utils.element_properties import get_element_symbol
from aim2dat.utils.space_groups import get_space_group_details
import aim2dat.utils.chem_formula as utils_cf


class _CIFDataBlock:
    """Class to process and store data blocks of cif-files."""

    _string_limiters = ["'", '"']
    _cell_fields = [
        "cell_length_a",
        "cell_length_b",
        "cell_length_c",
        "cell_angle_alpha",
        "cell_angle_beta",
        "cell_angle_gamma",
    ]
    _atomic_site_coord_fields = [
        "atom_site_fract_x",
        "atom_site_fract_y",
        "atom_site_fract_z",
    ]
    _symmetry_fields = ["symmetry_equiv_pos_as_xyz", "space_group_symop_operation_xyz"]
    _space_group_fields = [
        "space_group_name_hall",
        "symmetry_space_group_name_hall",
        "symmetry_space_group_name_h-m",
        "symmetry_space_group_name_h-m",
        "symmetry_int_tables_number",
        "space_group_it_number",
    ]
    _chem_formula_fields = ["chemical_formula_sum"]
    _pred_element_mapping = {"Ow": "O", "Hw": "H"}
    _patterns = [
        (re.compile(r"^([+-]?[0-9]+)$"), int),
        (re.compile(r"^" + FLOAT), float),
    ]
    _sym_op_pattern = re.compile(
        rf"(?P<sign>[-+])?(?P<num>({FLOAT}))?(\/(?P<den>{FLOAT}))?(?P<coord>[x-z])?"
    )

    def __init__(self, title_line):
        self.title = "_".join(title_line.split("_")[1:])
        self.fields = {}
        self.loops = []
        self.in_loop = False
        self.in_ml_field = False
        self.current_multi_line_field = None
        self.current_loop = None

    def add_line(self, line_idx, line):
        line_tr = line.strip().lower()

        # Omit comment or empty line:
        if line.startswith("#") or line == "":
            return None

        # Truncate if comment starts somewhere midline and is not part of string value:
        if "#" in line:
            lt = line.split("#")
            if len(lt) == 1:
                line = lt[0]
            else:
                l1 = lt[0]
                l2 = "#".join(lt[1:])
                if not any([str_l in l1 and str_l in l2 for str_l in self._string_limiters]):
                    line = l1

        # Start of loop:
        if line_tr.startswith("loop_"):
            # In some cases multi line fields are not properly limited via semicolons.
            self._finalize_current_loop(line_idx, "")
            self._force_finalize_multi_line()

            self.current_loop = {"keys": [], "values": []}
            line_split = line.split()
            if len(line_split) > 1:
                line = " ".join(line.split()[1:])
            else:
                line = ""
            self._add_line_to_loop(line_idx, line)

        # Process loop content:
        elif self.current_loop is not None:
            self._add_line_to_loop(line_idx, line)

        # Field:
        elif line_tr.startswith("_"):
            # In some cases multi line fields are not properly limited via semicolons.
            self._force_finalize_multi_line()

            line_split = line.split()
            if len(line_split) > 1:
                self.fields[self._transf_key(line_split[0])] = self._transf_value(
                    " ".join(line_split[1:])
                )
            else:
                self.current_multi_line_field = [self._transf_key(line_split[0])]

        # Multi-line field:
        elif self.current_multi_line_field is not None:
            self._extract_multi_line_value(line_idx, line)

    def finalize(self, line_idx):
        self._finalize_current_loop(line_idx, "")
        self._force_finalize_multi_line()

    def get_output(self):
        outp_dict = self.fields.copy()
        outp_dict["loops"] = self.loops
        return outp_dict

    def get_cell(self):
        if all(f0 in self.fields for f0 in self._cell_fields):
            return _get_cell_from_lattice_p(*[self.fields[f0] for f0 in self._cell_fields])

    def get_atomic_sites(self, check_chem_formula, get_sym_op_from_sg):
        # Get original sites:
        kinds = []
        scaled_coords = []
        kind_el_mapping = {}
        prel_site_attributes = {}
        for loop in self.loops:
            if "atom_site_label" not in loop:
                continue

            if all(k0 in loop for k0 in self._atomic_site_coord_fields):
                scaled_coords0 = []
                for key in self._atomic_site_coord_fields:
                    scaled_coords0.append(loop[key])
                scaled_coords += list(zip(*scaled_coords0))
                kinds += loop["atom_site_label"]
            if "atom_site_type_symbol" in loop:
                for kind, el in zip(loop["atom_site_label"], loop["atom_site_type_symbol"]):
                    kind_el_mapping[kind] = self._extract_element(el)
            for key, values in loop.items():
                if (
                    key
                    not in ["atom_site_label", "atom_site_type_symbol"]
                    + self._atomic_site_coord_fields
                ):
                    prel_site_attributes[key] = values.copy()

        # Try to get element symbols from site labels:
        if len(kind_el_mapping) == 0:
            for kind in kinds:
                kind_el_mapping[kind] = self._extract_element(kind)

        # In case site attributes are given, check consistency:
        site_attributes = {}
        for key, val in prel_site_attributes.items():
            if len(val) == len(kinds):
                site_attributes[key] = val

        # Add sites from symmetry operations:
        sym_ops = self.get_symmetry_operations(get_sym_op_from_sg)
        sym_ops = [(np.array(rot), np.array(trans)) for rot, trans in sym_ops]
        scaled_coords_bf = [[round(p0, 15) % 1 for p0 in pos] for pos in scaled_coords]
        n_sites = len(kinds)
        for sym_op in sym_ops:
            for idx in range(n_sites):
                new_pos = np.dot(sym_op[0], np.array(scaled_coords[idx])) + sym_op[1]
                new_pos_bf = [round(val, 15) % 1 for val in new_pos]
                mask = [kind == kinds[idx] for kind in kinds]
                dists = cdist(np.array([new_pos_bf]), np.array(scaled_coords_bf)[mask])[0]
                if any(dist < 1e-3 for dist in dists):
                    continue
                kinds.append(kinds[idx])
                scaled_coords.append(tuple(new_pos.tolist()))
                scaled_coords_bf.append(new_pos_bf)
                for val in site_attributes.values():
                    val.append(val[idx])

        # In case elements and kinds coincide only elements is considered:
        elements = [kind_el_mapping[kind] for kind in kinds]
        if all(el == k for el, k in zip(elements, kinds)):
            kinds = None

        # Check sum formula in case given:
        if check_chem_formula:
            chem_formula = self.get_chem_formula()
            if chem_formula is not None:
                if not utils_cf.compare_formulas(
                    chem_formula, utils_cf.transform_list_to_dict(elements), reduce_formulas=True
                ):
                    raise ValueError("Chemical formula doesn't match with number of sites.")

        return elements, kinds, scaled_coords, site_attributes

    def get_symmetry_operations(self, get_sym_op_from_sg):
        sym_op_strings = []
        for loop in self.loops:
            for field_key in self._symmetry_fields:
                if field_key in loop:
                    sym_op_strings += loop[field_key]
        sym_ops = []
        for sym_op in sym_op_strings:
            rot_matrix = np.zeros((3, 3))
            shift = np.zeros(3)
            for coord_idx, sym_str in enumerate(sym_op.split(",")):
                sym_str = sym_str.replace(" ", "").lower()
                for m in self._sym_op_pattern.finditer(sym_str):
                    m = m.groupdict()
                    if all(val is None for val in m.values()):
                        continue
                    pref = -1.0 if m["sign"] == "-" else 1.0
                    pref *= float(m["num"]) if m["num"] is not None else 1.0
                    pref /= float(m["den"]) if m["den"] is not None else 1.0
                    if m["coord"] is None:
                        shift[coord_idx] += pref
                    else:
                        rot_matrix[coord_idx]["xyz".index(m["coord"])] += pref
            sym_ops.append((rot_matrix.tolist(), shift.tolist()))
        if get_sym_op_from_sg and len(sym_ops) == 0:
            sg_details = self.get_space_group_details(return_sym_operations=True)
            if sg_details is not None:
                warn(
                    "Could not determine symmetry operations directly, using space group details.",
                    UserWarning,
                )
                sym_ops += sg_details["symmetry_operations"]
        return sym_ops

    def get_space_group_details(self, return_sym_operations=False):
        for key in self._space_group_fields:
            if key in self.fields:
                return get_space_group_details(
                    self.fields[key], return_sym_operations=return_sym_operations
                )

    def get_chem_formula(self):
        for key in self._chem_formula_fields:
            if key in self.fields:
                return utils_cf.transform_str_to_dict(self.fields[key])

    def _add_line_to_loop(self, line_idx, line):
        # Adding loop labels:
        if line.startswith("_"):
            # If line starts with new field label and values have been added we assume that the
            # loop is completed.
            if len(self.current_loop["values"]) > 0:
                self._finalize_current_loop(line_idx, line)
                self.add_line(line_idx, line)
            else:
                self.current_loop["keys"] += [self._transf_key(sp[1:]) for sp in line.split()]

        # Adding loop values:
        else:
            # Multi-line field in loop:
            if line.startswith(";") or self.current_multi_line_field is not None:
                if self.current_multi_line_field is None:
                    self.current_multi_line_field = [""]
                self._extract_multi_line_value(line_idx, line)
            # Single value fields:
            else:
                self._add_loop_values(self._extract_loop_values(line))

    def _add_loop_values(self, loop_values):
        if len(self.current_loop["values"]) == 0:
            start_idx = 0
            self.current_loop["values"] = [[] for _ in self.current_loop["keys"]]
        else:
            val_lengths = [len(val) for val in self.current_loop["values"]]
            start_idx = val_lengths.index(min(val_lengths))
        for val_idx, val in enumerate(loop_values):
            val_idx += start_idx
            if len(self.current_loop["keys"]) > 0:
                val_idx %= len(self.current_loop["keys"])
            self.current_loop["values"][val_idx].append(val)

    def _extract_multi_line_value(self, line_idx, line):
        finalize = False
        add_str = None
        new_str = ""
        line_split = line.split(";")
        if len(line_split) == 1:
            add_str = line
        else:
            if len(line_split) + len(self.current_multi_line_field) > 3:
                finalize = True
            add_str = line_split[0]
            new_str = line_split[-1]
            if line_split[0] == "" and len(self.current_multi_line_field) == 1:
                add_str = line_split[1]
                if len(line_split) == 2:
                    new_str = ""
        if add_str is not None:
            if len(self.current_multi_line_field) == 1:
                self.current_multi_line_field.append(add_str)
            else:
                self.current_multi_line_field[1] += "\n" + add_str

        if finalize:
            self._force_finalize_multi_line()
        self.add_line(line_idx, new_str)

    def _force_finalize_multi_line(self):
        if self.current_multi_line_field is None:
            return None

        val = ""
        if len(self.current_multi_line_field) == 2:
            val = self.current_multi_line_field[1]
        if self.current_loop is None:
            self.fields[self.current_multi_line_field[0]] = val
        else:
            self._add_loop_values([val])
        self.current_multi_line_field = None

    def _extract_loop_values(self, line):
        line_split = line.split()
        loop_values = []
        str_val_limiter = None
        string_val = None
        for val in line_split:
            if val[0] in self._string_limiters and string_val is None:
                str_val_limiter = val[0]
                string_val = val[1:] if len(val) > 1 else ""
            elif string_val is not None:
                string_val += " " + val
                if val.endswith(str_val_limiter):
                    string_val = string_val[:-1]
                    loop_values.append(string_val)
                    string_val = None
            else:
                loop_values.append(val)

        return [self._transf_value(val) for val in loop_values]

    def _finalize_current_loop(self, line_idx, line):
        if self.current_loop is None:
            return None

        self._force_finalize_multi_line()
        if any(
            len(self.current_loop["values"][0]) != len(self.current_loop["values"][idx])
            for idx in range(len(self.current_loop["values"]))
        ):
            raise ValueError(f"Number of values differ for loop finishing on line {line_idx}.")
        self.loops.append(
            {
                key: self.current_loop["values"][idx]
                for idx, key in enumerate(self.current_loop["keys"])
            }
        )
        self.current_loop = None

    @staticmethod
    def _transf_key(key):
        return key.strip("_").lower()

    def _transf_value(self, value):
        value = value.strip()
        for sl in self._string_limiters:
            value = value.strip(sl)
        for pattern, t in self._patterns:
            match = pattern.match(value)
            if match and len(match.group(1)) > 0:
                return t(match.group(1))
        return value

    def _extract_element(self, value):
        value = re.split(r"(\d)|(_)|(-)|(\+)", value)[0]
        el = self._pred_element_mapping.get(value, None)
        if el is None:
            try:
                el = get_element_symbol(value)
            except ValueError:
                raise ValueError(f"Could not determine element of '{value}'.")
        return el


[docs] @read_structure(r".*\.cif", preset_kwargs={"extract_structures": True}) def read_file( file_name, extract_structures=False, strct_check_chem_formula=True, strct_get_sym_op_from_sg=True, ): """ Read cif file. Parameters ---------- file_name : str Path to the cif file. extract_structures : bool (optional) Whether to extract alls crystal structures and add them to the output dictionary with the key ``'structures'``. strct_check_chem_formula : bool (optional) Check the chemical formula given by field matches with the structure. strct_get_sym_op_from_sg : bool (optional) Add symmetry operations based on the space group to add symmetry equivalent sites to the structures. Returns ------- dict Output dictionary. """ cif_blocks = [] current_block = None with open(file_name, "r") as f_obj: for line_idx, line in enumerate(f_obj): line = line.strip() if line.startswith("data_"): if current_block: current_block.finalize(line_idx) cif_blocks.append(current_block) current_block = _CIFDataBlock(line) elif current_block: current_block.add_line(line_idx, line) current_block.finalize(line_idx) cif_blocks.append(current_block) output_dict = {} structures = [] for block in cif_blocks: if block.title in output_dict: warn(f"Two data bloocks have the same title: '{block.title}'.", UserWarning) if extract_structures: if block.title == "structures": warn("Data block 'structures' is overwritten.", UserWarning) cell = block.get_cell() if cell is not None: elements, kinds, positions, site_attributes = block.get_atomic_sites( strct_check_chem_formula, strct_get_sym_op_from_sg ) structures.append( { "cell": cell, "label": block.title, "elements": elements, "kinds": kinds, "site_attributes": site_attributes, "positions": positions, "pbc": True, "is_cartesian": False, } ) output_dict[block.title] = block.get_output() if extract_structures: output_dict["structures"] = structures return output_dict