# Source code for aim2dat.plots.phase

"""
Module plotting quantities with respect to the chemical composition.
"""

# Standard library imports
import re
import math

# from statistics import mean

# Third party library imports
import numpy as np

# Internal library imports
from aim2dat.plots.base_plot import _BasePlot
from aim2dat.ext_interfaces.import_opt_dependencies import _return_ext_interface_modules
from aim2dat.fct.hull import get_convex_hull, get_minimum_maximum_points
import aim2dat.utils.chem_formula as utils_cf
import aim2dat.utils.space_groups as utils_sg
from aim2dat.utils.element_properties import get_element_symbol


def PhaseDiagram(*args, **kwargs):
    """
    Deprecated alias that forwards to ``PhasePlot``.

    Emits a ``DeprecationWarning`` and returns a ``PhasePlot`` instance created from the
    given positional and keyword arguments.
    """
    import warnings

    # stacklevel 2 attributes the warning to the caller of this wrapper.
    warnings.warn(
        "This class will be removed, please use `PhasePlot` instead.",
        DeprecationWarning,
        2,
    )
    return PhasePlot(*args, **kwargs)
def _get_concentration(entry, elements): conc = None if all([el in elements for el in entry["chem_formula"]]): conc = 0.0 if elements[0] in entry["chem_formula"]: conc = entry["chem_formula"][elements[0]] / sum(entry["chem_formula"].values()) return conc
class PhasePlot(_BasePlot):
    """
    Plot properties of material systems with respect to the chemical composition.

    The plotting code of this class handles systems spanned by exactly two elements; any
    other number of elements raises a ``NotImplementedError`` when the plot is prepared.

    Attributes
    ----------
    plot_property : str
        Key of the attribute that is used as y-value in the scatter plot.
    show_convex_hull : bool or list
        Whether to calculate and show the convex hull in the plot. If a list is given the
        entry at the index of the data set is used (the last entry for exceeding indices).
    show_lower_hull : bool or list
        Whether to show the minimum values per concentration (same conventions as
        ``show_convex_hull``).
    show_upper_hull : bool or list
        Whether to show the maximum values per concentration (same conventions as
        ``show_convex_hull``).
    hist_bin_size : float
        Bin width used for the ``'numbers'`` plot type.
    """

    _crystal_system_mapping = {
        "triclinic": 0,
        "monoclinic": 1,
        "orthorhombic": 2,
        "tetragonal": 3,
        "trigonal": 4,
        "hexagonal": 5,
        "cubic": 6,
    }
    _supported_plot_types = ["scatter", "numbers"]
    _default_y_labels = {
        "formation_energy": r"$E_{form}$ in eV/atom",
        "stability": "Stability in eV/atom",
        "numbers": "Nr. of structures",
    }

    def __init__(
        self,
        plot_type="scatter",
        plot_property="formation_energy",
        show_crystal_system=False,
        show_convex_hull=True,
        show_lower_hull=False,
        show_upper_hull=False,
        top_labels=None,
        hist_bin_size=0.1,
        **kwargs,
    ):
        """
        Initialize object.

        ``top_labels`` defaults to ``None`` (no labels) instead of a shared mutable list;
        all remaining keyword arguments are forwarded to ``_BasePlot.__init__``.
        """
        _BasePlot.__init__(self, **kwargs)
        self.plot_type = plot_type
        self.plot_property = plot_property
        self.show_crystal_system = show_crystal_system
        self.show_convex_hull = show_convex_hull
        self.show_lower_hull = show_lower_hull
        self.show_upper_hull = show_upper_hull
        self.top_labels = top_labels
        self.hist_bin_size = hist_bin_size
        self._all_elements = []
        self._elements = None

    @property
    def elements(self):
        """
        list: List of elements that are included in the plot. If set to ``None`` all elements
        are included.
        """
        return self._elements

    @elements.setter
    def elements(self, value):
        # ``None`` is the documented way to select all elements, so it has to be accepted.
        if value is None:
            self._elements = None
            return
        if not isinstance(value, (list, tuple)):
            raise TypeError("`elements` needs to be of type list or tuple.")
        self._elements = [get_element_symbol(val0) for val0 in value]

    @property
    def plot_type(self):
        """
        str: Plot type. Supported options are: ``'scatter'`` and ``'numbers'``.
        """
        return self._plot_type

    @plot_type.setter
    def plot_type(self, value):
        if value not in self._supported_plot_types:
            raise ValueError(
                f"`plot_type` '{value}' is not suppported. Supported options are '"
                + "', '".join(self._supported_plot_types)
                + "'."
            )
        self._plot_type = value

    @property
    def show_crystal_system(self):
        """
        Show crystal system of the phases.
        """
        return self._show_crystal_system

    @show_crystal_system.setter
    def show_crystal_system(self, value):
        self._show_crystal_system = value

    @property
    def top_labels(self):
        """
        list or str or dict: Chemical formulas that are shown as labels in the plot.
        """
        return self._top_labels

    @top_labels.setter
    def top_labels(self, value):
        if value is None:
            value = []
        elif not isinstance(value, (list, tuple)):
            value = [value]
        # Build a new list: the input may be a tuple (immutable) and a list passed by the
        # caller should not be altered in place.
        labels = []
        for formula in value:
            if isinstance(formula, str):
                formula = utils_cf.transform_str_to_dict(formula)
            elif isinstance(formula, (list, tuple)):
                formula = utils_cf.transform_list_to_dict(formula)
            labels.append(formula)
        self._top_labels = labels

    def add_data_point(
        self,
        data_label,
        formula,
        formation_energy=None,
        stability=None,
        unit=None,
        space_group=None,
        attributes=None,
    ):
        """
        Add datapoint to the dataset. If the ``data_label`` does not exist, a new data set with
        label ``data_label`` is created.

        Parameters
        ----------
        data_label : str
            Internal label used to plot and compare multiple data sets.
        formula : dict, str, list or tuple
            Chemical formula of the material, e.g. ``{'Cs': 1, 'Sb': 2}``. Strings and
            lists/tuples are converted to a dictionary.
        formation_energy : float (optional)
            Formation energy of the material.
        stability : float (optional)
            Stability of the material.
        unit : str (optional)
            Unit of the formation energy and stability.
        space_group : str or int (optional)
            Space group of the material, as symbol or number. The default value is ``None``.
        attributes : dict (optional)
            Additional attributes of the material that can be plotted. Keys that are already
            set via ``formation_energy`` or ``stability`` are not overwritten.

        Raises
        ------
        TypeError
            If ``formula`` or ``attributes`` have an unsupported type.
        """
        if isinstance(formula, (list, tuple)):
            formula = utils_cf.transform_list_to_dict(formula)
        elif isinstance(formula, str):
            formula = utils_cf.transform_str_to_dict(formula)
        elif not isinstance(formula, dict):
            raise TypeError("`formula` needs to be of type list/tuple/str/dict.")

        entry = {"chem_formula": formula, "attributes": {}}
        if formation_energy is not None:
            entry["attributes"]["formation_energy"] = {"value": formation_energy, "unit": unit}
        if stability is not None:
            entry["attributes"]["stability"] = {"value": stability, "unit": unit}
        if space_group is None:
            entry["space_group"] = None
        else:
            entry["space_group"] = utils_sg.transform_to_nr(space_group)
        if attributes is not None:
            if not isinstance(attributes, dict):
                raise TypeError("`attributes` needs to be of type dict.")
            for attr_key, attr_val in attributes.items():
                # Explicitly passed formation energy/stability take precedence.
                if attr_key not in entry["attributes"]:
                    entry["attributes"][attr_key] = attr_val

        if data_label in self._data:
            self._data[data_label].append(entry)
        else:
            self._data[data_label] = [entry]
        # Keep track of all elements in order of first appearance.
        for el in formula:
            if el not in self._all_elements:
                self._all_elements.append(el)

    def import_from_structure_collection(self, data_label, structure_collection):
        """
        Import data from a StructureCollection object.

        Parameters
        ----------
        data_label : str
            Internal label used to plot and compare multiple data sets.
        structure_collection : aim2dat.strct.StructureCollection
            Instance of StructureCollection containing all structures.
        """
        for structure in structure_collection:
            # Work on a shallow copy so the attributes of the structure are not altered.
            attributes = dict(structure["attributes"])
            dp_kwargs = {}
            for attr in ["formation_energy", "stability"]:
                value = attributes.pop(attr, None)
                dp_kwargs[attr] = value
                if isinstance(value, dict):
                    dp_kwargs[attr] = value["value"]
                    # The unit of the first attribute that defines one is used for both.
                    if "unit" in value and "unit" not in dp_kwargs:
                        dp_kwargs["unit"] = value["unit"]
            dp_kwargs["space_group"] = attributes.pop("space_group", None)
            dp_kwargs["attributes"] = attributes
            self.add_data_point(
                data_label, utils_cf.transform_list_to_dict(structure["elements"]), **dp_kwargs
            )

    def import_from_pandas_df(
        self, data_label, data_frame, structure_column="optimized_structure"
    ):
        """
        Import data from pandas data frame.

        The composition is taken from (in this order of precedence while scanning the
        columns) the column ``structure_column``, a column containing ``'chem_formula'`` or
        columns containing ``'nr_atoms'`` that end with ``'_<element symbol>'``. A column
        containing ``'space_group'`` is used as space group, all other columns are imported
        as attributes.

        Parameters
        ----------
        data_label : str
            Internal label used to plot and compare multiple data sets.
        data_frame : pandas.DataFrame
            Pandas data frame containing the total energy or formation energy and the
            structural details.
        structure_column : str (optional)
            Column containing AiiDA structure nodes used to determine structural and
            compositional properties. The default value is ``'optimized_structure'``.

        Raises
        ------
        ValueError
            If no column defining the chemical composition could be identified.
        """
        self._check_data_label(data_label)
        # Column names are expected as '<label>' or '<label> (<unit>)'.
        pattern = re.compile(r"([\w-]+)?\s*\(?(\w+)?\)?")
        attributes = {}
        comp_type = None
        comp_cols = {}
        sg_col = None
        for col in data_frame.columns:
            col_splitted = col.split("_")
            if structure_column == col:
                comp_type = "aiida_structure"
                comp_cols = col
            elif "chem_formula" in col and comp_type is None:
                comp_type = "chem_formula"
                comp_cols = col
            elif (
                "nr_atoms" in col
                and len(col_splitted) > 2
                and (comp_type is None or comp_type == "atoms_per_el")
            ):
                symbol = col_splitted[-1]
                comp_type = "atoms_per_el"
                comp_cols[symbol] = col
            elif "space_group" in col:
                sg_col = col
            else:
                # The pattern has two groups and can match the empty string, hence the first
                # match always exists and is a (label, unit) tuple.
                label, unit = pattern.findall(col)[0]
                attributes[label] = {"value": None, "unit": unit, "col": col}

        if comp_type is None:
            raise ValueError(
                "Could not identify a column defining the chemical composition in `data_frame`."
            )
        extract_formula = getattr(self, "_extract_formula_from_" + comp_type)

        for _, row in data_frame.iterrows():
            chem_f = extract_formula(row, comp_cols)
            row_attr = {
                attr_label: {"value": row[attr_details["col"]], "unit": attr_details["unit"]}
                for attr_label, attr_details in attributes.items()
            }
            # The space group column is optional.
            space_group = None if sg_col is None else row[sg_col]
            self.add_data_point(data_label, chem_f, attributes=row_attr, space_group=space_group)

    def _prepare_to_plot(self, data_labels, subplot_assignment):
        """
        Create the data sets for the plot.

        Returns a 6-tuple: the list of data sets, four ``None`` placeholders (their meaning
        is defined by the base class and not visible here) and a list with the additional
        x-axis ticks/labels derived from ``top_labels``.

        Raises
        ------
        NotImplementedError
            If the number of plotted elements is not two.
        """
        plot_elements = self._elements
        if plot_elements is None:
            plot_elements = self._all_elements
        if len(plot_elements) != 2:
            raise NotImplementedError("Feature is not yet supported.")

        # TODO add y-label?
        self._auto_set_axis_properties(
            x_label=r"$x_{" + plot_elements[0] + r"}$",
        )

        data_sets = []
        label_tl = []
        values_tl = []
        for top_label in self._top_labels:
            # Only labels that lie within the plotted system get a tick and a vertical line.
            if all(el in plot_elements for el in top_label.keys()):
                x_value = 0.0
                if plot_elements[0] in top_label:
                    x_value = top_label[plot_elements[0]] / sum(top_label.values())
                label_tl.append(utils_cf.transform_dict_to_latexstr(top_label))
                values_tl.append(x_value)
                data_sets.append(
                    {
                        "x": x_value,
                        "color": "black",
                        "linestyle": "dashed",
                        "scaled": True,
                        "type": "vline",
                    }
                )

        create_function = getattr(self, "_create_2d_" + self.plot_type + "_data_sets")
        for data_set_idx, data_label in enumerate(data_labels):
            create_function(data_sets, data_set_idx, data_label, plot_elements)
        return (
            data_sets,
            None,
            None,
            None,
            None,
            [{"ticks": values_tl, "tick_labels": label_tl, "coord": "x"}],
        )

    @staticmethod
    def _extract_formula_from_aiida_structure(row, comp_cols):
        """Extract chemical formula from aiida structure nodes."""
        backend_module = _return_ext_interface_modules("aiida")
        struct = backend_module._load_data_node(row[comp_cols])
        return utils_cf.transform_str_to_dict(struct.get_formula())

    @staticmethod
    def _extract_formula_from_atoms_per_el(row, comp_cols):
        """Extract chemical formula from atoms-per-element lists."""
        return {symbol: row[col] for symbol, col in comp_cols.items()}

    @staticmethod
    def _extract_formula_from_chem_formula(row, comp_cols):
        """Return the content of the chemical formula column of the row."""
        return row[comp_cols]

    def _create_2d_scatter_data_sets(self, data_sets, ds_idx, data_label, elements):
        """
        Process entries for 2D phase diagram.

        Appends one data set per plotted entry plus the requested hull data sets to
        ``data_sets`` (in place).
        """
        cs_map = {"Phases": 0}
        if self.show_crystal_system:
            cs_map = self._crystal_system_mapping

        entries = self._return_data_set(data_label)
        new_data_sets = []
        space_groups = []
        used_labels = []
        data_points = []
        for entry in entries:
            conc = _get_concentration(entry, elements)
            if conc is None:
                continue
            if self.plot_property not in entry["attributes"]:
                print(
                    f"Property '{self.plot_property}' missing, could not plot entry "
                    + utils_cf.transform_dict_to_str(entry["chem_formula"])
                    + "."
                )
                continue

            y_value = entry["attributes"][self.plot_property]
            if isinstance(y_value, dict):
                y_value = y_value["value"]
            # Points are collected for the hulls before the space group check below.
            data_points.append((conc, y_value))

            if self.show_crystal_system and entry["space_group"] is None:
                print(
                    "Space group missing, could not plot entry "
                    + utils_cf.transform_dict_to_str(entry["chem_formula"])
                    + "."
                )
                continue

            crystal_system = "Phases"
            if self.show_crystal_system:
                space_groups.append(entry["space_group"])
                crystal_system = utils_sg.get_crystal_system(entry["space_group"])
            data_set = {
                "x_values": [conc],
                "y_values": [y_value],
                "marker": cs_map[crystal_system],
                "linestyle": "none",
                "markerfacecolor": "none",
                "markeredgewidth": 1.7,
                "color": ds_idx,
                "legendgrouptitle_text": data_label,
                "legendgroup": data_label,
                "group_by": "color",
            }
            # Only the first data set of each crystal system gets a legend label.
            if crystal_system not in used_labels:
                data_set["label"] = crystal_system
                used_labels.append(crystal_system)
            new_data_sets.append(data_set)

        # Sort by space group; skipped if nothing is plotted since unpacking an empty
        # sequence would fail.
        if self.show_crystal_system and new_data_sets:
            zipped = sorted(zip(space_groups, new_data_sets), key=lambda point: point[0])
            new_data_sets = [data_set for _, data_set in zipped]

        # create_hulls
        show_hull = {}
        for hull_type in ["convex_hull", "lower_hull", "upper_hull"]:
            hull_attr = getattr(self, "show_" + hull_type)
            # Non-bool values are treated as per-data-set sequences; the last entry is used
            # for data sets beyond the length of the sequence.
            if not isinstance(hull_attr, bool):
                if len(hull_attr) > ds_idx:
                    hull_attr = hull_attr[ds_idx]
                else:
                    hull_attr = hull_attr[-1]
            show_hull[hull_type] = hull_attr

        if show_hull["convex_hull"]:
            x_values, y_values = get_convex_hull(data_points, upper_hull=False)
            new_data_sets.append(
                {
                    "x_values": list(x_values),
                    "y_values": list(y_values),
                    "color": ds_idx,
                    "legendgrouptitle_text": data_label,
                    "legendgroup": data_label,
                    "label": "Convex hull",
                    "group_by": "color",
                }
            )
        if show_hull["lower_hull"] or show_hull["upper_hull"]:
            x_values, min_values, max_values = get_minimum_maximum_points(data_points)
            if show_hull["lower_hull"] and show_hull["upper_hull"]:
                # Both hulls: fill the area between minimum and maximum values.
                new_data_sets.append(
                    {
                        "x_values": x_values,
                        "y_values": min_values,
                        "y_values_2": max_values,
                        "color": ds_idx,
                        "alpha": ds_idx,
                        "use_fill_between": True,
                    }
                )
            else:
                y_values = min_values if show_hull["lower_hull"] else max_values
                new_data_sets.append(
                    {
                        "x_values": x_values,
                        "y_values": y_values,
                        "color": ds_idx,
                    }
                )
        data_sets += new_data_sets

    def _create_2d_numbers_data_sets(self, data_sets, _, data_label, elements):
        """
        Process data for bar plot.

        Appends one (stacked) bar data set per crystal system to ``data_sets`` (in place)
        and counts the entries per concentration bin.
        """
        cs_map = {"Phases": 0}
        if self.show_crystal_system:
            cs_map = self._crystal_system_mapping

        entries = self._return_data_set(data_label)
        x_values = [
            float(val) for val in np.arange(0.0, 1.0 + self.hist_bin_size, self.hist_bin_size)
        ]
        for crystal_system, color_idx in cs_map.items():
            data_sets.append(
                {
                    "x_values": x_values,
                    "bottom": [0.0] * len(x_values),
                    "heights": [0.0] * len(x_values),
                    "width": self.hist_bin_size - 0.005,
                    "type": "bar",
                    "color": color_idx,
                    "label": crystal_system,
                }
            )

        for entry in entries:
            if self.show_crystal_system and entry["space_group"] is None:
                print(
                    "Space group missing, could not plot entry "
                    + utils_cf.transform_dict_to_str(entry["chem_formula"])
                    + "."
                )
                continue

            conc = _get_concentration(entry, elements)
            if conc is None:
                continue

            crystal_system = "Phases"
            cs_idx = 0
            if self.show_crystal_system:
                crystal_system = utils_sg.get_crystal_system(entry["space_group"])
                cs_idx = self._crystal_system_mapping[crystal_system]
            # Index of the bin whose center is closest to the concentration.
            x_idx = math.ceil((conc - 0.5 * self.hist_bin_size) / self.hist_bin_size)
            # The bar data sets of this call are the last ``len(cs_map)`` items of the list.
            ds_idx = len(data_sets) - len(cs_map) + cs_idx
            data_sets[ds_idx]["heights"][x_idx] += 1
            # Shift all bars stacked above the current crystal system.
            for upper_idx in range(ds_idx + 1, len(data_sets)):
                data_sets[upper_idx]["bottom"][x_idx] += 1