# Source code for aim2dat.plots.planar_fields

"""Classes to plot planar fields."""

# Standard library imports
import copy

# Third party library imports
import numpy as np

# Internal library imports
from aim2dat.plots.base_plot import _BasePlot
from aim2dat.utils.units import UnitConverter


class PlanarFieldPlot(_BasePlot):
    """
    Plot scalar planar fields.
    """

    # Maps the norm names accepted by the ``norm`` property onto the names of the
    # corresponding matplotlib.colors classes, handed on to the backend as ``norm_type``.
    _supported_norms = {"symlog": "SymLogNorm", "log": "LogNorm"}
    _supported_plot_types = {"heatmap", "contour"}

    def __init__(self, show_x_label=True, show_y_label=True, **kwargs):
        """Initialize object."""
        _BasePlot.__init__(self, **kwargs)
        self._coordinates_unit = None
        self._values_unit = None
        self._norm = None
        self._plot_type = "heatmap"
        self.show_x_label = show_x_label
        self.show_y_label = show_y_label
        # Settings only evaluated if ``plot_type`` is ``"contour"``.
        self.contour_filled = True
        self.contour_levels = None
        self.color_map = "RdBu_r"
        # User-defined color range; ``None`` falls back to the range of the data set.
        self.vmin = None
        self.vmax = None
        # Linear threshold of the symmetric logarithmic norm.
        self.linthresh = 0.01

    @property
    def coordinates_unit(self):
        """Set unit of the two coordinates."""
        return self._coordinates_unit

    @coordinates_unit.setter
    def coordinates_unit(self, value):
        self._coordinates_unit = self._validate_unit(value)

    @property
    def values_unit(self):
        """Set unit of the z-values."""
        return self._values_unit

    @values_unit.setter
    def values_unit(self, value):
        self._values_unit = self._validate_unit(value)

    @property
    def norm(self):
        """
        Set norm of the z-values for matplotlib. Supported values are ``'symlog'``,
        ``'log'`` and ``None`` (linear color scale).
        """
        return self._norm

    @norm.setter
    def norm(self, value):
        # ``None`` is accepted so that a previously set norm can be removed again.
        if value is not None and value not in self._supported_norms:
            raise ValueError(
                f"{value} not supported. Supported norms are: "
                + ", ".join(self._supported_norms.keys())
                + "."
            )
        self._norm = value

    @property
    def plot_type(self):
        """Set plot-type."""
        return self._plot_type

    @plot_type.setter
    def plot_type(self, value):
        if value in self._supported_plot_types:
            self._plot_type = value
        else:
            raise ValueError(
                f"Plot type {value} is not supported. Supported plot types are: "
                + ", ".join(self._supported_plot_types)
                + "."
            )

    def import_field(
        self,
        data_label,
        coordinates,
        values,
        flip_lr=False,
        flip_ud=False,
        coordinates_unit=None,
        values_unit=None,
        text_labels=None,
    ):
        """
        Import field.

        Parameters
        ----------
        data_label : str
            Internal label used to plot and compare multiple data sets.
        coordinates : list
            Nested list of the coordinates, one ``(x, y)`` pair per point.
        values : list
            List or nested list of the values, one entry per coordinate. Nested entries
            (e.g. a 2-dimensional array) are stored as vector field.
        flip_lr : bool (optional)
            Whether to flip the field from left to right.
        flip_ud : bool (optional)
            Whether to flip the field from up to down.
        coordinates_unit : str (optional)
            Unit of coordinates.
        values_unit : str (optional)
            Unit of values.
        text_labels : list (optional)
            List of text labels. Each label is a dict containing at least the keys
            ``'x'``, ``'y'`` and ``'label'``. The input list is not modified.

        Raises
        ------
        ValueError
            If ``coordinates`` is empty, is not a list of ``(x, y)`` pairs, or has a
            different length than ``values``; if a unit or a text label is invalid.
        TypeError
            If a text label is not a dict.
        """
        self._check_data_label(data_label)
        # Copy (and validate) the labels as they are modified in place below.
        text_labels = [
            self._set_text_label(t_label)
            for t_label in copy.deepcopy(text_labels if text_labels is not None else [])
        ]
        if len(coordinates) == 0:
            raise ValueError("`coordinates` must not be empty.")
        if len(coordinates) != len(values):
            raise ValueError("`coordinates` and `values` need to have the same length.")

        # - Check units (input units are validated like in the property setters):
        if coordinates_unit is not None:
            coordinates_unit = self._validate_unit(coordinates_unit)
        if values_unit is not None:
            values_unit = self._validate_unit(values_unit)
        coord_factor, self._coordinates_unit = self._set_unit_conv_factor(
            coordinates_unit, self._coordinates_unit
        )
        val_factor, self._values_unit = self._set_unit_conv_factor(values_unit, self._values_unit)

        coordinates = np.array(coordinates) * coord_factor
        if coordinates.ndim != 2 or coordinates.shape[1] != 2:
            raise ValueError("`coordinates` need to be a list of (x, y) pairs.")
        # Convert the values into the unit of the plot, analogous to the coordinates.
        values = np.asarray(values) * val_factor

        # - Create grid (``np.unique`` already returns sorted values):
        x_values = np.unique(coordinates[:, 0])
        y_values = np.unique(coordinates[:, 1])
        # Every coordinate is an element of the sorted unique arrays, hence
        # ``searchsorted`` returns its exact grid index.
        x_idx = np.searchsorted(x_values, coordinates[:, 0])
        y_idx = np.searchsorted(y_values, coordinates[:, 1])
        if flip_lr:
            # Reverse the column order; the point at x is then shown at x_min + x_max - x.
            x_idx = len(x_values) - 1 - x_idx
            x_min, x_max = x_values[0], x_values[-1]
            for label in text_labels:
                label["x"] = x_min + x_max - label["x"]
        if flip_ud:
            # Reverse the row order; the point at y is then shown at y_min + y_max - y.
            y_idx = len(y_values) - 1 - y_idx
            y_min, y_max = y_values[0], y_values[-1]
            for label in text_labels:
                label["y"] = y_min + y_max - label["y"]

        # - Distribute points on grid (grid points without data remain zero). For
        #   duplicate coordinates only one of the given values ends up on the grid.
        is_vector_field = values.ndim > 1
        values_grid = np.zeros((len(y_values), len(x_values)) + values.shape[1:])
        values_grid[y_idx, x_idx] = values

        self._data[data_label] = {
            "x_values": x_values,
            "y_values": y_values,
            "z_values": values_grid,
            "is_vector_field": is_vector_field,
            "vmin": values.min(),
            "vmax": values.max(),
            "text_labels": text_labels,
        }

    def import_from_aiida_arraydata(
        self, data_label, planedata, flip_lr=False, flip_ud=False, values_unit=None, text_labels=None
    ):
        """
        Import from aiida array data.

        Parameters
        ----------
        data_label : str
            Internal label used to plot and compare multiple data sets.
        planedata : aiida.orm.ArrayData
            Array data node (or a reference resolved via ``_load_data_node``) holding the
            arrays ``'coordinates'`` and ``'values'``; the optional attribute
            ``'coordinates_unit'`` is used as unit of the coordinates.
        flip_lr : bool (optional)
            Whether to flip the field from left to right.
        flip_ud : bool (optional)
            Whether to flip the field from up to down.
        values_unit : str (optional)
            Unit of values.
        text_labels : list (optional)
            List of text labels.
        """
        from aim2dat.ext_interfaces.aiida import _load_data_node

        planedata = _load_data_node(planedata)
        self.import_field(
            data_label,
            planedata.get_array("coordinates"),
            planedata.get_array("values"),
            coordinates_unit=planedata.get_attribute("coordinates_unit", None),
            flip_lr=flip_lr,
            flip_ud=flip_ud,
            values_unit=values_unit,
            text_labels=text_labels,
        )

    def _prepare_to_plot(self, data_labels, subplot_assignment):
        """Collect the plot items of all data sets, grouped by subplot index."""
        if self.backend == "plotly":
            print("Warning: logarithmic scales and vmin/vmax is not supported for this backend.")
        axis_label = self._set_axis_label()
        self._auto_set_axis_properties(x_label=axis_label, y_label=axis_label)
        data_sets = [[] for _ in range(max(subplot_assignment) + 1)]
        for data_label, subp_a in zip(data_labels, subplot_assignment):
            data_sets[subp_a] += self._process_data_set_plot(data_label)
        return data_sets, None, None, None, None, None

    def _process_data_set_plot(self, data_label):
        """Return the list of plot items (field + text labels) of one data set."""
        data_set = self._return_data_set(data_label)
        if data_set.pop("is_vector_field"):
            print("Only scalar fields are supported so far.")
            # An empty list keeps ``_prepare_to_plot`` working (it extends a list).
            return []
        if self.plot_type == "contour":
            data_set["filled"] = self.contour_filled
            if self.contour_levels is not None:
                # User-defined bounds take precedence, otherwise the data range is used.
                lower = data_set["vmin"] if self.vmin is None else self.vmin
                upper = data_set["vmax"] if self.vmax is None else self.vmax
                data_set["levels"] = np.linspace(lower, upper, self.contour_levels)
        if self.norm == "symlog":
            data_set["symlog_scale"] = True
            data_set["linthresh"] = self.linthresh
            data_set["base"] = 10
        elif self.norm == "log":
            data_set["log_scale"] = True
            data_set["linthresh"] = self.linthresh
            data_set["base"] = 10
        self._set_norm(data_set)
        text_labels = data_set.pop("text_labels", [])
        data_set["cmap"] = self.color_map
        data_set["type"] = self.plot_type
        return [data_set] + text_labels

    def _set_norm(self, data_set):
        """Set parameter for the matplotlib.colors.LogNorm or SymLogNorm."""
        if self._norm is not None:
            norm_args = {
                "vmin": data_set["vmin"] if self.vmin is None else self.vmin,
                "vmax": data_set["vmax"] if self.vmax is None else self.vmax,
            }
            if self._norm == "symlog":
                norm_args["linthresh"] = self.linthresh
            elif self._norm == "log":
                # A logarithmic norm cannot handle non-positive bounds; clip them.
                norm_args["vmin"] = max(1e-21, norm_args["vmin"])
                norm_args["vmax"] = max(1e-21, norm_args["vmax"])
            data_set["norm_type"] = self._supported_norms[self._norm]
            data_set["norm_args"] = norm_args
        # The raw data range is removed before the data set is handed to the backend.
        del data_set["vmin"]
        del data_set["vmax"]

    def _set_axis_label(self):
        """Set axis labels."""
        return (
            UnitConverter._available_units[self.coordinates_unit].capitalize()
            + f" in {UnitConverter.plot_labels[self.coordinates_unit]}"
        )

    @staticmethod
    def _validate_unit(value):
        """Return the lower-cased unit or raise ``ValueError`` if it is not supported."""
        value = value.lower()
        if value not in UnitConverter.available_units:
            raise ValueError(f"{value} as unit not supported.")
        return value

    @staticmethod
    def _set_text_label(text_label):
        """Validate a text label dict in place and add the text plot properties."""
        if not isinstance(text_label, dict):
            raise TypeError("Text labels need to be of type dict.")
        for key0 in ["x", "y", "label"]:
            if key0 not in text_label:
                raise ValueError(f"Key '{key0}' needs to be in text label.")
        text_label["type"] = "text"
        text_label["ha"] = "center"
        text_label["va"] = "center"
        return text_label

    @staticmethod
    def _set_unit_conv_factor(input_unit, plot_unit):
        """
        Return the factor converting ``input_unit`` into ``plot_unit`` and the plot unit.

        If no plot unit is set yet, the input unit becomes the plot unit (factor 1.0). If no
        input unit is given, the data is assumed to be in the plot unit (factor 1.0).
        """
        conv_factor = 1.0
        if plot_unit is None:
            plot_unit = input_unit
        elif input_unit is not None:
            conv_factor = UnitConverter.convert_units(1.0, input_unit, plot_unit)
        return conv_factor, plot_unit