"""Helper functions for machine learning tasks."""
# Standard library imports
from random import shuffle
import math
# Third party library imports
import numpy as np
from sklearn.model_selection import train_test_split
# Internal library imports
from aim2dat.utils.chem_formula import transform_list_to_dict
from aim2dat.strct import StructureCollection
def _get_all_elements(X, distinguish_kinds):
"""Create element or kind pairs."""
el_type = "elements"
if distinguish_kinds:
el_type = "kinds"
all_elements = []
for strct in X:
all_elements += strct[el_type]
return sorted(set(all_elements))
def _retrieve_target_value(structure, target_attribute):
"""Retrieve attribute from structure dictionary."""
if target_attribute not in structure["attributes"]:
raise ValueError(
f"Target 'target_attribute' not available for structure '{structure['label']}'."
)
if (
isinstance(structure["attributes"][target_attribute], dict)
and "value" in structure["attributes"][target_attribute]
):
return structure["attributes"][target_attribute]["value"]
else:
return structure["attributes"][target_attribute]
def _remove_structures(structure_list, exclude_labels):
"""Remove structures from list."""
if len(exclude_labels) > 0:
ind2del = []
for idx, strct in enumerate(structure_list):
if strct["label"] in exclude_labels:
ind2del.append(idx)
ind2del.sort(reverse=True)
for idx in ind2del:
del structure_list[idx]
def _check_train_test_size(n_structures, train_size, test_size):
"""Chcek sizes of training and test subset."""
def _get_n_dataset(dataset_size, n_structures, round_type):
n_ds = None
if dataset_size is not None:
if dataset_size > 1.0:
n_ds = int(dataset_size)
else:
n_ds = getattr(math, round_type)(dataset_size * n_structures)
return n_ds
n_train = _get_n_dataset(train_size, n_structures, "floor")
n_test = _get_n_dataset(test_size, n_structures, "ceil")
if n_train is None and n_test is None:
raise ValueError("`train_size` or `test_size` need to be set.")
if n_test is None:
n_test = n_structures - n_train
if n_train is None:
n_train = n_structures - n_test
if n_structures < n_train + n_test:
raise ValueError(
"`train_size`+`test_size` need to be smaller than the overall dataset size."
)
return n_train, n_test
def _build_stratified_subset(subset_size, strct_list, hist_data, used_indices):
    """Create a stratified subset based on the target attribute and/or the composition.

    ``hist_data`` maps a property key to a tuple ``(values, hist, bin_edges)``
    where ``values[idx]`` is the property value of ``strct_list[idx]``,
    ``hist`` the per-bin fractions of the whole dataset and ``bin_edges`` the
    histogram bin limits. ``used_indices`` is mutated in place so that
    consecutive calls (e.g. train and test subset) never pick the same
    structure twice. Returns the selected structures and their target values.
    """
    # Number of structures already selected per bin for each property.
    hist_subset = {key: np.zeros(len(val[1])) for key, val in hist_data.items()}
    subset = []
    target = []
    for idx, strct in enumerate(strct_list):
        if idx in used_indices:
            continue
        add2subset = True
        bin_indices = {}
        for key, hist in hist_data.items():
            value = hist[0][idx]
            # Locate the bin of `value`: the first edge is skipped so that at
            # the break `bin_idx` is the index of the containing bin; values at
            # or beyond the last edge end up in the last bin.
            for bin_idx, bin_e in enumerate(hist[2][1:]):
                if value < bin_e:
                    break
            # Accept only if the bin quota (dataset fraction scaled to the
            # requested subset size, rounded down) is not yet exhausted.
            if hist_subset[key][bin_idx] < math.floor(subset_size * hist[1][bin_idx]):
                bin_indices[key] = bin_idx
            else:
                add2subset = False
                break
        if add2subset:
            # Commit the bin bookkeeping only once all properties accepted.
            for key, val in bin_indices.items():
                hist_subset[key][val] += 1
            subset.append(strct)
            target.append(hist_data["target"][0][idx])
            used_indices.append(idx)
        if len(subset) == subset_size:
            break
    return subset, target
def train_test_split_crystals(
    structure_collection,
    target_attribute,
    train_size=None,
    test_size=None,
    target_bins=None,
    composition_bins=None,
    elements=None,
    exclude_labels=None,
    return_structure_collections=False,
):
    """
    Split dataset of crystals into a training and test dataset. The target attribute and/or the
    composition can be stratified based on binning.

    Parameters
    ----------
    structure_collection : aim2dat.strct.StructureCollection
        ``StructureCollection`` containing the crystals.
    target_attribute : str
        Label of the target attribute.
    train_size : float, int or None (optional)
        Training set size. Values greater than 1 are interpreted as absolute counts,
        otherwise as fraction of the dataset.
    test_size : float, int or None (optional)
        Test set size. Same convention as ``train_size``.
    target_bins : int or sequence of scalars or str or None (optional)
        Input for np.histogram function. If set to ``None`` binning is not performed. If
        ``target_bins`` and ``composition_bins`` is set to ``None`` the ``train_test_split``
        function of scikit learn is used.
    composition_bins : int or sequence of scalars or str or None (optional)
        Input for np.histogram function. If set to ``None`` binning is not performed. If
        ``target_bins`` and ``composition_bins`` is set to ``None`` the ``train_test_split``
        function of scikit learn is used.
    elements : list or None
        Elements that are considered for composition binning. If set to ``None`` all elements are
        taken into account.
    exclude_labels : list or None
        Structure labels that should be excluded from the train and test dataset.
    return_structure_collections : bool
        Whether to return the train and test dataset as ``StructureCollection`` objects.

    Returns
    -------
    subset_train : list or StructureCollection
        Training set returned as list or ``StructureCollection`` object.
    subset_test : list or StructureCollection
        Test set returned as list or ``StructureCollection`` object.
    target_train : list
        List of target values of the training set.
    target_test : list
        List of target values of the test set.
    """
    # Avoid a mutable default argument; ``None`` means "exclude nothing".
    if exclude_labels is None:
        exclude_labels = []
    strct_list = structure_collection.get_all_structures()
    if target_bins is not None or composition_bins is not None:
        shuffle(strct_list)
        # Bug fix: remove excluded structures *before* collecting the target and
        # composition values, otherwise those lists are index-misaligned with
        # `strct_list` inside `_build_stratified_subset`.
        _remove_structures(strct_list, exclude_labels)
        if elements is None:
            all_elements = _get_all_elements(strct_list, False)
            # Concentrations sum to 1, so the last element is redundant for binning.
            all_elements = all_elements[:-1]
        else:
            all_elements = elements
        # Fractional composition per element and target value per structure,
        # both index-aligned with `strct_list`.
        el_comps = {el: [] for el in all_elements}
        target = []
        for strct in strct_list:
            chem_f = transform_list_to_dict(strct["elements"])
            for el in all_elements:
                el_comps[el].append(chem_f.get(el, 0.0) / len(strct["elements"]))
            target.append(_retrieve_target_value(strct, target_attribute))
        n_strct_list = len(strct_list)
        n_train, n_test = _check_train_test_size(n_strct_list, train_size, test_size)

        # Histograms (normalized to dataset fractions) used for stratification.
        hist_data = {}
        if target_bins is None:
            target_bins = 1
        hist, bin_edges = np.histogram(target, bins=target_bins)
        hist_data["target"] = (target, hist / n_strct_list, bin_edges)
        if composition_bins is None:
            composition_bins = 1
        for el, vals in el_comps.items():
            hist, bin_edges = np.histogram(vals, bins=composition_bins)
            # Bug fix: key the histogram by the element symbol; the previous
            # literal key "el" let every element overwrite the one before it.
            hist_data[el] = (vals, hist / n_strct_list, bin_edges)

        used_indices = []
        subset_train, target_train = _build_stratified_subset(
            n_train, strct_list, hist_data, used_indices
        )
        subset_test, target_test = _build_stratified_subset(
            n_test, strct_list, hist_data, used_indices
        )
        # Stratification can undershoot the requested sizes (bin quotas are
        # floored); top up with the first structures that are still unused.
        for subset, subset_t, n_subset in [
            (subset_train, target_train, n_train),
            (subset_test, target_test, n_test),
        ]:
            idx0 = 0
            while len(subset) < n_subset:
                if idx0 not in used_indices:
                    used_indices.append(idx0)
                    subset.append(strct_list[idx0])
                    subset_t.append(target[idx0])
                idx0 += 1
    else:
        # No binning requested: defer the split to scikit-learn.
        _remove_structures(strct_list, exclude_labels)
        target = [_retrieve_target_value(strct, target_attribute) for strct in strct_list]
        subset_train, subset_test, target_train, target_test = train_test_split(
            strct_list, target, train_size=train_size, test_size=test_size
        )
    if return_structure_collections:
        return (
            StructureCollection(subset_train),
            StructureCollection(subset_test),
            target_train,
            target_test,
        )
    else:
        return subset_train, subset_test, target_train, target_test