Source code for aim2dat.strct.ext_analysis.graphs

"""Methods to create graphs from molecules and crystals."""

# Standard library imports
from typing import List, Optional

# Third party library imports
import networkx as nx

# Internal library imports
from aim2dat.ext_interfaces import _return_ext_interface_modules
from aim2dat.strct.strct import Structure
from aim2dat.strct.ext_analysis.decorator import external_analysis_method


# Default edge-rank colours; a tuple so the default can never be mutated between calls.
_DEFAULT_EDGE_RANK_COLORS = ("blue", "red", "green", "orange", "darkblue")


def _rank_neighbours_by_distance(distances, tolerance=1e-5):
    """Return ``(neighbour_index, rank)`` pairs ordered by increasing distance.

    Neighbours are visited in order of increasing distance; equal distances keep their
    original relative order because Python's sort is stable. A neighbour whose distance
    differs from the previously visited one by less than ``tolerance`` inherits that
    neighbour's rank; otherwise its rank is its 0-based position in the sorted order.
    For example, distances ``[1.0, 1.0, 2.0]`` yield ranks ``0, 0, 2``. An empty
    ``distances`` list yields an empty result.
    """
    order = sorted(range(len(distances)), key=lambda idx: distances[idx])
    ranked = []
    prev_distance = 0.0
    prev_rank = 0
    for position, neigh_idx in enumerate(order):
        distance = distances[neigh_idx]
        # Compare only against the immediately preceding distance, as the tie criterion.
        rank = prev_rank if abs(prev_distance - distance) < tolerance else position
        ranked.append((neigh_idx, rank))
        prev_rank = rank
        prev_distance = distance
    return ranked


@external_analysis_method
def create_graph(
    structure: Structure,
    get_graphviz_graph: bool = False,
    graphviz_engine: str = "circo",
    graphviz_edge_rank_colors: Optional[List[str]] = None,
    **cn_kwargs,
):
    """
    Create graph based on the coordination.

    Every site of the coordination analysis becomes a node carrying its ``element``. Every
    neighbour listed for a site becomes a directed edge from that site to the neighbour's
    ``site_index``. Each edge stores a ``rank`` attribute: the site's neighbours are visited
    in order of increasing distance; a neighbour within 1e-5 of the previously visited
    distance shares its rank, otherwise the rank is the neighbour's 0-based position in that
    order (e.g. distances ``[1.0, 1.0, 2.0]`` give ranks ``0, 0, 2``).

    Parameters
    ----------
    structure : aim2dat.strct.Structure
        Structure object.
    get_graphviz_graph : bool
        Whether to also output a graphviz.Digraph object.
    graphviz_engine : str
        Graphviz engine used to create the graph. The default value is ``'circo'``.
    graphviz_edge_rank_colors : list or None
        List of colors of the different edge ranks. If ``None`` (default),
        ``["blue", "red", "green", "orange", "darkblue"]`` is used.
    cn_kwargs :
        Optional keyword arguments passed on to the ``calculate_coordination`` function.

    Returns
    -------
    nx_graph : nx.MultiDiGraph
        networkx graph of the structure.
    graphviz_graph : graphviz.Digraph
        graphviz graph of the structure (if ``get_graphviz_graph`` is set to ``True``).
    """
    if graphviz_edge_rank_colors is None:
        # Fresh list per call instead of one mutable default object shared by all calls.
        graphviz_edge_rank_colors = list(_DEFAULT_EDGE_RANK_COLORS)

    coord = structure.calculate_coordination(**cn_kwargs)
    nx_graph = nx.MultiDiGraph()
    for site_idx, site in enumerate(coord["sites"]):
        nx_graph.add_node(site_idx, element=site["element"])

    for site_idx, site in enumerate(coord["sites"]):
        neighbours = site["neighbours"]
        distances = [neigh["distance"] for neigh in neighbours]
        # NOTE(review): ranks are not dense (ties followed by a new distance skip values),
        # so a rank can exceed the number of colours; confirm the graphviz backend
        # handles that case.
        for neigh_idx, rank in _rank_neighbours_by_distance(distances):
            nx_graph.add_edge(site_idx, neighbours[neigh_idx]["site_index"], rank=rank)

    # NOTE(review): the leading ``None`` appears to be part of the return convention
    # expected by ``external_analysis_method`` — confirm against the decorator.
    if get_graphviz_graph:
        backend_module = _return_ext_interface_modules("graphviz")
        graphviz_graph = backend_module._networkx2graphviz(
            nx_graph, graphviz_engine, graphviz_edge_rank_colors
        )
        return None, (nx_graph, graphviz_graph)
    return None, nx_graph