aim2dat.ml.utils

Helper functions for machine learning tasks.

Module Contents

Functions

train_test_split_crystals(structure_collection, ...[, ...])

Split a dataset of crystals into a training and a test dataset. The target attribute and/or the composition can be stratified based on binning.

aim2dat.ml.utils.train_test_split_crystals(structure_collection, target_attribute, train_size=None, test_size=None, target_bins=None, composition_bins=None, elements=None, exclude_labels=[], return_structure_collections=False)

Split a dataset of crystals into a training and a test dataset. The target attribute and/or the composition can be stratified based on binning.
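Binning-based stratification can be illustrated with a minimal sketch, assuming NumPy, target values given as a plain list, and a random shuffle within each bin. The helper name stratified_split_by_bins and its arguments are hypothetical; this is not aim2dat's own implementation, only an instance of the general technique: bin the values, then send the same share of every bin to the test set.

```python
import numpy as np


def stratified_split_by_bins(targets, test_size=0.2, bins=5, seed=0):
    """Hypothetical helper: return (train_idx, test_idx) such that each
    histogram bin of `targets` contributes about `test_size` of its members
    to the test set. Illustration only, not aim2dat's implementation."""
    targets = np.asarray(targets, dtype=float)
    rng = np.random.default_rng(seed)
    # Bin edges as np.histogram would compute them for the same `bins` input.
    edges = np.histogram_bin_edges(targets, bins=bins)
    # Bin index per value; using only the inner edges keeps indices in
    # 0..n_bins-1 and puts the maximum value into the last (closed) bin.
    bin_idx = np.digitize(targets, edges[1:-1])
    train_idx, test_idx = [], []
    for b in np.unique(bin_idx):
        # Shuffle the members of this bin, then move a fixed share to the test set.
        members = rng.permutation(np.flatnonzero(bin_idx == b))
        n_test = int(round(test_size * len(members)))
        test_idx.extend(members[:n_test].tolist())
        train_idx.extend(members[n_test:].tolist())
    return sorted(train_idx), sorted(test_idx)


# Two well-separated groups of target values; with two bins, each group
# sends one of its five members to the test set.
targets = [0.1, 0.2, 0.3, 0.4, 0.5, 2.1, 2.2, 2.3, 2.4, 2.5]
train_idx, test_idx = stratified_split_by_bins(targets, test_size=0.2, bins=2)
print(len(train_idx), len(test_idx))  # 8 2
```

Because the test share is drawn per bin, sparsely populated ranges of the target are represented in both sets instead of being left to chance, which is the usual motivation for stratifying a regression target.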

Parameters:
  • structure_collection (aim2dat.strct.StructureCollection) – StructureCollection containing the crystals.

  • target_attribute (str) – Label of the target attribute.

  • train_size (float, int or None (optional)) – Training set size.

  • test_size (float, int or None (optional)) – Test set size.

  • target_bins (int or sequence of scalars or str or None (optional)) – Passed as the bins input of the np.histogram function to bin the target values. If set to None, the target values are not binned. If both target_bins and composition_bins are set to None, the train_test_split function of scikit-learn is used.

  • composition_bins (int or sequence of scalars or str or None (optional)) – Passed as the bins input of the np.histogram function to bin the compositions. If set to None, the compositions are not binned. If both target_bins and composition_bins are set to None, the train_test_split function of scikit-learn is used.

  • elements (list or None) – Elements that are considered for composition binning. If set to None, all elements are taken into account.

  • exclude_labels (list) – Structure labels that should be excluded from both the training and the test dataset.

  • return_structure_collections (bool) – Whether to return the training and test datasets as StructureCollection objects.
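To make the binning parameters concrete, the following minimal sketch shows how target_bins, composition_bins and elements could be turned into one stratum label per crystal. The helper names bin_indices and strata_labels, the {element: count} input format and the use of atomic fractions for composition binning are all assumptions for illustration; aim2dat's internal procedure may differ.

```python
import numpy as np


def bin_indices(values, bins):
    """Map each value to the index of its histogram bin.

    `bins` takes the same forms as the `bins` input of np.histogram: an int,
    a sequence of edges, or a str such as "auto". Values outside explicit
    edges are clipped into the first/last bin (np.histogram would drop them).
    """
    values = np.asarray(values, dtype=float)
    edges = np.histogram_bin_edges(values, bins=bins)
    return np.digitize(values, edges[1:-1])


def strata_labels(targets, compositions, target_bins=None,
                  composition_bins=None, elements=None):
    """Hypothetical helper: one stratum label (tuple of bin indices) per crystal.

    `compositions` are {element: count} dicts. Composition binning is assumed
    here to bin the atomic fraction of each element in `elements`
    (all elements present if None).
    """
    columns = []
    if target_bins is not None:
        columns.append(bin_indices(targets, target_bins))
    if composition_bins is not None:
        if elements is None:
            elements = sorted({el for comp in compositions for el in comp})
        for el in elements:
            fractions = [comp.get(el, 0) / sum(comp.values()) for comp in compositions]
            columns.append(bin_indices(fractions, composition_bins))
    if not columns:
        # No binning requested: a plain (non-stratified) random split applies.
        return None
    return [tuple(int(col[i]) for col in columns) for i in range(len(targets))]


targets = [0.0, 1.0, 9.0, 10.0]
compositions = [{"Li": 1, "O": 1}, {"Li": 2, "O": 1}, {"Li": 1, "O": 1}, {"Li": 2, "O": 1}]
labels = strata_labels(targets, compositions, target_bins=2, composition_bins=2, elements=["Li"])
print(labels)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Crystals that share a label form one stratum, so each stratum can be divided between the training and test sets in the requested proportion. When both bin inputs are None, the sketch returns None, matching the documented fallback to scikit-learn's train_test_split.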

Returns:

  • subset_train (list or StructureCollection) – Training set, returned as a list or, if return_structure_collections is True, as a StructureCollection object.

  • subset_test (list or StructureCollection) – Test set, returned as a list or, if return_structure_collections is True, as a StructureCollection object.

  • target_train (list) – List of target values of the training set.

  • target_test (list) – List of target values of the test set.