Source code for aim2dat.plots.partial_rdf

"""Plot class for radial distribution functions."""

# Third party library imports
import numpy as np

# Internal library imports:
from aim2dat.plots.base_plot import _BasePlot
from aim2dat.utils.units import UnitConverter


[docs] class PartialRDFPlot(_BasePlot): """ Plot the partial radial distribution function. """ def __init__(self, custom_linestyles=["solid", "dashed", "dotted", "dashdot"], **kwargs): """Initialize object.""" _BasePlot.__init__(self, custom_linestyles=custom_linestyles, **kwargs) self._el_pairs_color_indices = {} self._x_unit = None @property def x_unit(self): """Set unit of the x coordinate.""" return self._x_unit @x_unit.setter def x_unit(self, value): value = value.lower() if value not in UnitConverter.available_units: raise ValueError(f"{value} as unit not supported.") self._x_unit = value
[docs] def import_ffingerprint( self, data_label, bins, fingerprints, x_unit=None, ): """ Import F-Fingerprint functions. Parameters ---------- data_label : str Internal label used to plot and compare multiple data sets. bins : list Bins of the distance. fingerprints : dict Dictionary with the keys being the element pairs as tuples and the fingerprint functions as values. x_unit : str or None (optional) Unit of the x-axis. """ self._check_data_label(data_label) data_sets = [] coord_factor, self._x_unit = self._set_unit_conv_factor(x_unit, self._x_unit) bins = np.array(bins) * coord_factor for el_pair, fingerprint in fingerprints.items(): if el_pair not in self._el_pairs_color_indices: self._el_pairs_color_indices[el_pair] = ( max(list(self._el_pairs_color_indices.values()) + [-1]) + 1 ) data_sets.append( { "x_values": bins, "y_values": fingerprint, "type": "scatter", "color": self._el_pairs_color_indices[el_pair], "legendgrouptitle_text": el_pair[0] + "-" + el_pair[1], "legendgroup": el_pair[0] + "-" + el_pair[1], "label": data_label, "group_by": "color", } ) self._data[data_label] = data_sets
def _prepare_to_plot(self, data_labels, subplot_assignment): data_sets_plot = [[] for idx0 in range(max(subplot_assignment) + 1)] for idx, (data_label, subp_a) in enumerate(zip(data_labels, subplot_assignment)): data_sets = self._return_data_set(data_label) for data_set in data_sets: data_set["linestyle"] = idx if len(data_labels) == 1: data_set["label"] = data_set["legendgroup"] del data_set["legendgroup"] del data_set["group_by"] del data_set["legendgrouptitle_text"] data_sets_plot[subp_a] += data_sets x_label = f"Distance ({self._x_unit})" if self._x_unit in ["angstrom", "ang"]: x_label = r"Distance ($\mathrm{\AA}$)" self._auto_set_axis_properties(x_label=x_label) return data_sets_plot, None, None, None, None, None @staticmethod def _set_unit_conv_factor(input_unit, plot_unit): """Set unit conversion factor.""" conv_factor = 1.0 if plot_unit is None: plot_unit = input_unit elif input_unit is not None: conv_factor = UnitConverter.convert_units(1.0, input_unit, plot_unit) return conv_factor, plot_unit