Source code for aim2dat.plots.band_structure_dos

"""
Module to plot the band structure and the density of states separately or combined.

This module is still a work in progress.
"""

# Internal library imports
from aim2dat.plots.base_plot import _BasePlot
from aim2dat.plots.base_mixin import _HLineMixin, _VLineMixin
from aim2dat.plots.base_band_structure import _BaseBandStructure
from aim2dat.plots.base_dos import _BaseDensityOfStates


def BandStructure(*args, **kwargs):
    """
    Deprecated alias that creates a ``BandStructurePlot``.

    A ``DeprecationWarning`` is issued (attributed to the caller) and all
    positional and keyword arguments are forwarded unchanged.
    """
    import warnings

    warnings.warn(
        "This class will be removed, please use `BandStructurePlot` instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    return BandStructurePlot(*args, **kwargs)
def DensityOfStates(*args, **kwargs):
    """
    Deprecated alias that creates a ``DOSPlot``.

    A ``DeprecationWarning`` is issued (attributed to the caller) and all
    positional and keyword arguments are forwarded unchanged.
    """
    import warnings

    warnings.warn(
        "This class will be removed, please use `DOSPlot` instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    return DOSPlot(*args, **kwargs)
def BandStructureDensityOfStates(*args, **kwargs):
    """
    Deprecated alias that creates a ``BandStructureDOSPlot``.

    A ``DeprecationWarning`` is issued (attributed to the caller) and all
    positional and keyword arguments are forwarded unchanged.
    """
    import warnings

    warnings.warn(
        "This class will be removed, please use `BandStructureDOSPlot` instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    return BandStructureDOSPlot(*args, **kwargs)
class BandStructurePlot(_BasePlot, _HLineMixin, _VLineMixin, _BaseBandStructure):
    """
    Plot the band structure of a crystalline structure.
    """

    _object_title = "Band Structure Plot"

    def __init__(self, **kwargs):
        """Initialize class, forwarding ``kwargs`` to the base plot."""
        _BasePlot.__init__(self, **kwargs)
        _BaseBandStructure.__init__(self)

    def _print_extra_properties(self):
        """
        Assemble the plot-specific part of the textual summary.

        Returns
        -------
        str
            States whether a reference cell is set, lists the 3x3 reciprocal
            cell entries if so and, when path specifications exist, the start
            and end k-point plus label of every path segment.
        """
        has_ref_cell = self._reciprocal_cell is not None
        chunks = [f" Reference cell set: {has_ref_cell}.\n"]
        if has_ref_cell:
            chunks.append(" Rec. cell vectors: ")
            for row in range(3):
                chunks.extend(f"{self._reciprocal_cell[row][col]:.5f} " for col in range(3))
                # Continuation rows start on a fresh line.
                if row < 2:
                    chunks.append("\n ")
            chunks.append("\n")
        chunks.append("\n")

        if self._path_specifications is not None:
            chunks.append(" Path segments:\n")
            for number, segment in enumerate(self._path_specifications, start=1):
                header = f" {number}. Segment: "
                chunks.append(header)
                chunks.extend(f"{segment['kpoint_st'][axis]:.5f} " for axis in range(3))
                chunks.append(segment["label_st"] + "\n")
                # Pad the end-point row so it lines up below the start point.
                chunks.append(" " * len(header))
                chunks.extend(f"{segment['kpoint_e'][axis]:.5f} " for axis in range(3))
                chunks.append(segment["label_e"] + "\n")
        return "".join(chunks)

    def _prepare_to_plot(self, data_labels, subplot_assignment):
        """
        Collect the data sets and path ticks for every subplot.

        Parameters
        ----------
        data_labels : list
            Labels of the data sets to plot.
        subplot_assignment : list
            Subplot index for each entry of ``data_labels``.

        Returns
        -------
        tuple
            ``(data_sets, path_ticks, None, path_labels, None, None)`` where
            ``data_sets`` holds one list of plot dictionaries per subplot.
        """
        if not self.x_range:
            self.x_range = [0.0, self._path_specifications[-1]["pos_e"]]

        path_ticks, path_labels = self._process_path_labels()

        # Guide lines every subplot receives: one vertical line per path tick
        # and a dashed horizontal line at zero energy.
        shared_lines = [
            {"x": tick, "color": "black", "linestyle": "-", "scaled": True, "type": "vline"}
            for tick in path_ticks
        ]
        shared_lines.append(
            {
                "y": 0.0,
                "xmin": 0.0,
                "xmax": self._path_specifications[-1]["pos_e"],
                "color": "black",
                "linestyle": "dashed",
                "scaled": False,
                "type": "hline",
            }
        )

        n_subplots = max(subplot_assignment) + 1
        data_sets = [list(shared_lines) for _ in range(n_subplots)]
        for position, (label, subplot_idx) in enumerate(zip(data_labels, subplot_assignment)):
            data_sets[subplot_idx].extend(self._generate_data_sets_for_plot(position, label))

        self._auto_set_axis_properties(y_label="Energy in eV")
        return data_sets, path_ticks, None, path_labels, None, None
class DOSPlot(_BasePlot, _VLineMixin, _BaseDensityOfStates):
    """
    Class to plot the density of states.

    Attributes
    ----------
    dos_comp_threshold : float
        Threshold to compare the density of states if ``detect_equivalent_kinds`` is set to
        ``True`` when importing projected density of states data sets.
    sum_pdos : bool
        Whether to sum all pDOS data sets to obtain a tDOS.
    per_atom : bool
        Normalize all density of states data sets to the number of atoms.
    """

    _object_title = "DOS Plot"

    def __init__(
        self,
        ratio=(10, 4),
        pdos_plot_type="line",
        tdos_plot_type="fill",
        dos_comp_threshold=0.47,
        smearing_method="gaussian",
        smearing_delta=0.005,
        smearing_sigma=5.0,
        sum_pdos=False,
        per_atom=False,
        **kwargs,
    ):
        """
        Initialize class.

        ``ratio`` and ``kwargs`` are forwarded to the base plot; all other
        arguments are forwarded to the density-of-states base class.
        """
        _BasePlot.__init__(self, ratio=ratio, **kwargs)
        _BaseDensityOfStates.__init__(
            self,
            pdos_plot_type,
            tdos_plot_type,
            dos_comp_threshold,
            smearing_method,
            smearing_delta,
            smearing_sigma,
            sum_pdos,
            per_atom,
        )

    def _prepare_to_plot(self, data_labels, subplot_assignment):
        """
        Collect the tDOS/pDOS plot data sets for every subplot.

        Parameters
        ----------
        data_labels : list
            Labels of the data sets to plot.
        subplot_assignment : list
            Subplot index for each entry of ``data_labels``.

        Returns
        -------
        tuple
            Six entries; the first is a list with one list of plot
            dictionaries per subplot, the remaining five are ``None``.
        """
        n_subplots = max(subplot_assignment) + 1
        pdos_all = [[] for _ in range(n_subplots)]
        tdos_all = [[] for _ in range(n_subplots)]
        color_map = {}
        ls_map = {}
        # Bound up front so the axis label exists even if the loop never runs.
        y_label = "DOS in states/eV/cell"
        for data_label, subp_a in zip(data_labels, subplot_assignment):
            dos_data = self._return_data_set(data_label)
            nr_atoms = 1
            # The label is reset per data set, so the last one decides it.
            y_label = "DOS in states/eV/cell"
            if self.per_atom and "chem_formula" in dos_data:
                nr_atoms = sum(dos_data["chem_formula"].values())
                y_label = "DOS in states/eV/atom"

            # Derive the tDOS from the pDOS when requested.
            if "pdos" in dos_data and self.sum_pdos:
                dos_data["tdos"] = self._sum_pdos(dos_data["pdos"])
            if "tdos" in dos_data:
                # NOTE(review): the tDOS is passed 1.0 here while the pDOS
                # is passed nr_atoms -- confirm this is intended when
                # ``per_atom`` is set.
                tdos_all[subp_a] += self._dos_create_plot_data_set(
                    color_map, ls_map, dos_data["tdos"], 1.0, self.tdos_plot_type, 0
                )
            if "pdos" in dos_data:
                pdos_all[subp_a] += self._dos_create_plot_data_set(
                    color_map, ls_map, dos_data["pdos"], nr_atoms, self.pdos_plot_type, 1
                )

        # Every subplot starts with a dashed vertical line at zero energy,
        # followed by the tDOS and then the pDOS data sets.
        data_sets = [
            [{"x": 0.0, "color": "black", "linestyle": "--", "scaled": True, "type": "vline"}]
            + tdos_ds
            + pdos_ds
            for pdos_ds, tdos_ds in zip(pdos_all, tdos_all)
        ]
        self._auto_set_axis_properties(x_label="Energy in eV", y_label=y_label)
        return data_sets, None, None, None, None, None
class BandStructureDOSPlot(
    _BasePlot, _HLineMixin, _VLineMixin, _BaseBandStructure, _BaseDensityOfStates
):
    """
    Class to plot the band structure combined with the density of states.

    Attributes
    ----------
    dos_comp_threshold : float
        Threshold to compare the density of states if ``detect_equivalent_kinds`` is set to
        ``True`` when importing projected density of states data sets.
    sum_pdos : bool
        Whether to sum all pDOS data sets to obtain a tDOS.
    per_atom : bool
        Normalize all density of states data sets to the number of atoms.
    """

    def __init__(
        self,
        ratio=(12, 7),
        show_legend=[False, True],
        subplot_hspace=5,
        subplot_wspace=0.1,
        subplot_nrows=1,
        subplot_ncols=3,
        subplot_sharex=False,
        subplot_sharey=True,
        subplot_gridspec=[(0, 1, 0, 2), (0, 1, 2, 3)],
        pdos_plot_type="line",
        tdos_plot_type="fill",
        dos_comp_threshold=0.47,
        smearing_method="gaussian",
        smearing_delta=0.005,
        smearing_sigma=5.0,
        sum_pdos=False,
        per_atom=False,
        **kwargs,
    ):
        """
        Initialize class.

        The layout arguments and ``kwargs`` are forwarded to the base plot;
        the DOS-related arguments are forwarded to the density-of-states
        base class.
        """
        # The list defaults in the signature are shared between all calls.
        # Hand copies to the base class so that changing one instance's
        # settings can never modify the defaults of later instances.
        if isinstance(show_legend, list):
            show_legend = list(show_legend)
        if isinstance(subplot_gridspec, list):
            subplot_gridspec = list(subplot_gridspec)

        _BasePlot.__init__(
            self,
            ratio=ratio,
            show_legend=show_legend,
            subplot_hspace=subplot_hspace,
            subplot_wspace=subplot_wspace,
            subplot_nrows=subplot_nrows,
            subplot_ncols=subplot_ncols,
            subplot_sharex=subplot_sharex,
            subplot_sharey=subplot_sharey,
            subplot_gridspec=subplot_gridspec,
            **kwargs,
        )
        _BaseBandStructure.__init__(self)
        _BaseDensityOfStates.__init__(
            self,
            pdos_plot_type,
            tdos_plot_type,
            dos_comp_threshold,
            smearing_method,
            smearing_delta,
            smearing_sigma,
            sum_pdos,
            per_atom,
        )

    def shift_bands_and_dos_to_vbm(self, data_label):
        """
        Shift the bands and the density of states such that the VBM is zero.

        For spin-polarized data the larger of the two spin-channel VBM
        energies is used as the common reference.

        Parameters
        ----------
        data_label : str
            Data label of the data set.

        Raises
        ------
        ValueError
            If the band structure contains no occupations, so that the
            valence band cannot be identified.
        """
        data_set = self._return_data_set(data_label, deepcopy=False)
        band_structure = data_set["band_structure"]
        if "occupations" in band_structure:
            spin_suffixes = [""]
        elif "occupations_spinup" in band_structure:
            spin_suffixes = ["_spinup", "_spindown"]
        else:
            # Previously this fell through to an unbound local variable.
            raise ValueError(
                f"Band structure of data set '{data_label}' has no occupations, "
                "the VBM cannot be determined."
            )

        vbm_energies = []
        for spin in spin_suffixes:
            lowest_unocc_idx = self._check_occupations(band_structure["occupations" + spin])
            valence_band = band_structure["bands" + spin][lowest_unocc_idx - 1]
            _, vbm = self._analyse_band(valence_band, band_structure["kpoints"])
            vbm_energies.append(vbm["energy"])
        vbm_energy_max = max(vbm_energies)

        for spin in spin_suffixes:
            band_structure["bands" + spin] = self._shift_bands(
                band_structure["bands" + spin], -1.0 * vbm_energy_max
            )
        self.shift_dos(data_label, -1.0 * vbm_energy_max)

    def shift_bands_and_dos(self, data_label, shift):
        """
        Shift band structure and density of states.

        Every bands entry present in the band structure (``bands``,
        ``bands_spinup``, ``bands_spindown``) is shifted, so both the
        non-spin-polarized and the spin-polarized case are covered.

        Parameters
        ----------
        data_label : str
            Data label of the data set.
        shift : float
            Value to shift band structure and density of states.
        """
        data_set = self._return_data_set(data_label, deepcopy=False)
        if "band_structure" in data_set:
            band_structure = data_set["band_structure"]
            for bands_key in ("bands", "bands_spinup", "bands_spindown"):
                if bands_key in band_structure:
                    band_structure[bands_key] = self._shift_bands(
                        band_structure[bands_key], shift
                    )
        self.shift_dos(data_label, shift)

    def _prepare_to_plot(self, data_labels, subplot_assignment):
        """
        Collect the data sets for the band-structure and the DOS subplot.

        Parameters
        ----------
        data_labels : list
            Labels of the data sets to plot.
        subplot_assignment : list
            Not used by this plot; the bands always go to the first and the
            DOS to the second subplot.

        Returns
        -------
        tuple
            ``(data_sets, [path_ticks, None], None, [path_labels, None], None, None)``
            where ``data_sets`` is ``[band data sets, DOS data sets]``.
        """
        path_ticks, path_labels = self._process_path_labels()

        # Band-structure subplot: zero-energy line plus one vertical line
        # per path tick, followed by the bands of every data set.
        data_set_bands = [
            {
                "y": 0.0,
                "color": "black",
                "linestyle": "dashed",
                "scaled": True,
                "type": "hline",
            }
        ]
        for x_val in path_ticks:
            data_set_bands.append(
                {"x": x_val, "color": "black", "linestyle": "-", "scaled": True, "type": "vline"}
            )

        # Count how many data sets contribute a tDOS/pDOS; the data label is
        # only appended to the legend labels when there is more than one.
        n_tdos = 0
        n_pdos = 0
        for idx, data_label in enumerate(data_labels):
            data_set_bands += self._generate_data_sets_for_plot(idx, data_label)
            dos_data = self._return_data_set(data_label)
            if "pdos" in dos_data:
                n_pdos += 1
            if "tdos" in dos_data or ("pdos" in dos_data and self.sum_pdos):
                n_tdos += 1

        data_set_dos = [
            {"y": 0.0, "color": "black", "linestyle": "--", "scaled": True, "type": "hline"}
        ]
        # NOTE(review): every iteration writes the same key, so this ends up
        # as {"_": len(data_labels) - 1} (or {} without data labels);
        # presumably it seeds the pDOS color index -- confirm.
        color_map_pdos = {"_": i0 for i0 in range(len(data_labels))}
        color_map_tdos = {}
        ls_map = {}
        # Bound up front so the axis label exists even without data labels.
        x_label = [None, "DOS in states/eV/cell"]
        for data_label in data_labels:
            dos_label_suffix = ["", ""]
            if n_tdos > 1:
                dos_label_suffix[0] = " " + data_label
            if n_pdos > 1:
                dos_label_suffix[1] = " " + data_label
            dos_data = self._return_data_set(data_label)
            nr_atoms = 1
            # The label is reset per data set, so the last one decides it.
            x_label = [None, "DOS in states/eV/cell"]
            if self.per_atom and "chem_formula" in dos_data:
                nr_atoms = sum(dos_data["chem_formula"].values())
                x_label[1] = "DOS in states/eV/atom"

            # Derive the tDOS from the pDOS when requested.
            if "pdos" in dos_data and self.sum_pdos:
                dos_data["tdos"] = self._sum_pdos(dos_data["pdos"])

            proc_dos_data = []
            if "tdos" in dos_data:
                # Re-key the labels (and their orbital keywords) with the
                # data-label suffix; keys are snapshotted first because the
                # dictionaries are modified while being traversed.
                dos_labels = list(dos_data["tdos"]["labels"].keys())
                for label in dos_labels:
                    orb_labels = dos_data["tdos"]["labels"].pop(label)
                    orb_keywords = list(orb_labels.keys())
                    for keyw in orb_keywords:
                        orb_labels[(keyw[0] + dos_label_suffix[0], keyw[1])] = orb_labels.pop(keyw)
                    dos_data["tdos"]["labels"][label + dos_label_suffix[0]] = orb_labels
                # NOTE(review): the tDOS is passed 1.0 here while the pDOS
                # is passed nr_atoms -- confirm this is intended when
                # ``per_atom`` is set.
                proc_dos_data += self._dos_create_plot_data_set(
                    color_map_tdos, ls_map, dos_data["tdos"], 1.0, self.tdos_plot_type, 0
                )
            if "pdos" in dos_data:
                dos_labels = list(dos_data["pdos"]["labels"].keys())
                for label in dos_labels:
                    dos_data["pdos"]["labels"][label + dos_label_suffix[1]] = dos_data["pdos"][
                        "labels"
                    ].pop(label)
                proc_dos_data += self._dos_create_plot_data_set(
                    color_map_pdos, ls_map, dos_data["pdos"], nr_atoms, self.pdos_plot_type, 1
                )

            # The DOS is drawn sideways next to the bands: swap the axes.
            for data_set0 in proc_dos_data:
                new_ds = dict(data_set0)
                new_ds["x_values"] = data_set0["y_values"]
                new_ds["y_values"] = data_set0["x_values"]
                data_set_dos.append(new_ds)

        data_set_plot = [data_set_bands, data_set_dos]

        # The x-range of the band subplot is always the full path; a
        # user-supplied range only applies to the DOS subplot.
        path_range = [0.0, self._path_specifications[-1]["pos_e"]]
        if self.x_range is None:
            self.x_range = [path_range, None]
        elif isinstance(self.x_range[0], (list, tuple)) or self.x_range[0] is None:
            self.x_range = [path_range] + list(self.x_range[1:])
        else:
            self.x_range = [path_range, self.x_range]

        self._auto_set_axis_properties(x_label=x_label, y_label=["Energy in eV", None])
        return data_set_plot, [path_ticks, None], None, [path_labels, None], None, None