Source code for metatrain.utils.data.dataset

import math
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from metatensor.learn.data import Dataset, group_and_join
from metatensor.torch import TensorMap
from torch.utils.data import Subset

from ..external_naming import to_external_name
from ..units import get_gradient_units
from .target_info import TargetInfo


class DatasetInfo:
    """A class that contains information about datasets.

    This class is used to communicate additional dataset details to the
    training functions of the individual models.

    :param length_unit: Unit of length used in the dataset. Examples are
        ``"angstrom"`` or ``"nanometer"``.
    :param atomic_types: List containing all integer atomic types present in the
        dataset. ``atomic_types`` will be stored as a sorted list of **unique**
        atomic types.
    :param targets: Information about targets in the dataset.
    """

    def __init__(
        self, length_unit: str, atomic_types: List[int], targets: Dict[str, TargetInfo]
    ):
        self.length_unit = length_unit if length_unit is not None else ""
        self._atomic_types = set(atomic_types)
        self.targets = targets

    @property
    def atomic_types(self) -> List[int]:
        """Sorted list of unique integer atomic types."""
        return sorted(self._atomic_types)

    @atomic_types.setter
    def atomic_types(self, value: List[int]):
        self._atomic_types = set(value)

    def __repr__(self):
        return (
            f"DatasetInfo(length_unit={self.length_unit!r}, "
            f"atomic_types={self.atomic_types!r}, targets={self.targets!r})"
        )

    def __eq__(self, other):
        if not isinstance(other, DatasetInfo):
            raise NotImplementedError(
                "Comparison between a DatasetInfo instance and a "
                f"{type(other).__name__} instance is not implemented."
            )
        return (
            self.length_unit == other.length_unit
            and self._atomic_types == other._atomic_types
            and self.targets == other.targets
        )

    def copy(self) -> "DatasetInfo":
        """Return a shallow copy of the DatasetInfo."""
        return DatasetInfo(
            length_unit=self.length_unit,
            atomic_types=self.atomic_types.copy(),
            targets=self.targets.copy(),
        )

    def update(self, other: "DatasetInfo") -> None:
        """Update this instance with the union of itself and ``other``.

        :raises ValueError: If the ``length_unit`` of the two instances differs.
        """
        if self.length_unit != other.length_unit:
            raise ValueError(
                "Can't update DatasetInfo with a different `length_unit`: "
                f"({self.length_unit} != {other.length_unit})"
            )

        self.atomic_types = self.atomic_types + other.atomic_types

        intersecting_target_keys = self.targets.keys() & other.targets.keys()
        for key in intersecting_target_keys:
            if not self.targets[key].is_compatible_with(other.targets[key]):
                raise ValueError(
                    f"Can't update DatasetInfo with different target information for "
                    f"target '{key}': {self.targets[key]} is not compatible with "
                    f"{other.targets[key]}. If the units, quantity and keys of the two "
                    "targets are the same, this must be due to a mismatch in the "
                    "internal metadata of the layout."
                )
        self.targets.update(other.targets)

    def union(self, other: "DatasetInfo") -> "DatasetInfo":
        """Return the union of this instance with ``other``."""
        new = self.copy()
        new.update(other)
        return new
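
Example (hypothetical usage sketch, not part of the module source): building and merging ``DatasetInfo`` objects for two data sources that share an energy target. The ``TargetInfo`` value ``energy_info`` is an assumption of this sketch and stands in for an object built elsewhere, e.g. from the target section of the training options.

# ``energy_info`` is a TargetInfo instance built elsewhere (assumption of this sketch)
info_a = DatasetInfo(
    length_unit="angstrom", atomic_types=[1, 8], targets={"energy": energy_info}
)
info_b = DatasetInfo(
    length_unit="angstrom", atomic_types=[1, 6], targets={"energy": energy_info}
)
merged = info_a.union(info_b)
print(merged.atomic_types)  # [1, 6, 8]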


def get_stats(dataset: Union[Dataset, Subset], dataset_info: DatasetInfo) -> str:
    """Returns the statistics of a dataset or subset as a string."""

    dataset_len = len(dataset)
    stats = f"Dataset containing {dataset_len} structures"
    if dataset_len == 0:
        return stats

    # target_names will be used to store names of the targets,
    # along with their gradients
    target_names = []
    for key, tensor_map in dataset[0]._asdict().items():
        if key == "system":
            continue
        target_names.append(key)
        gradients_list = tensor_map.block(0).gradients_list()
        for gradient in gradients_list:
            target_names.append(f"{key}_{gradient}_gradients")

    sums = {key: 0.0 for key in target_names}
    sums_of_squares = {key: 0.0 for key in target_names}
    n_elements = {key: 0 for key in target_names}
    for sample in dataset:
        for key in target_names:
            if "_gradients" not in key:  # not a gradient
                tensors = [block.values for block in sample[key].blocks()]
            else:
                original_key = key.split("_")[0]
                gradient_name = key.replace(f"{original_key}_", "").replace(
                    "_gradients", ""
                )
                tensors = [
                    block.gradient(gradient_name).values
                    for block in sample[original_key].blocks()
                ]
            sums[key] += sum(tensor.sum() for tensor in tensors)
            sums_of_squares[key] += sum((tensor**2).sum() for tensor in tensors)
            n_elements[key] += sum(tensor.numel() for tensor in tensors)
    means = {key: sums[key] / n_elements[key] for key in target_names}
    means_of_squares = {
        key: sums_of_squares[key] / n_elements[key] for key in target_names
    }
    stds = {
        key: (means_of_squares[key] - means[key] ** 2) ** 0.5 for key in target_names
    }

    # Find units
    units = {}
    for key in target_names:
        # Gets the units of an output
        if key.endswith("_gradients"):
            # handling <base_name>_<gradient_name>_gradients
            base_name = key[:-10]
            gradient_name = base_name.split("_")[-1]
            base_name = base_name.replace(f"_{gradient_name}", "")
            base_unit = dataset_info.targets[base_name].unit
            unit = get_gradient_units(
                base_unit, gradient_name, dataset_info.length_unit
            )
        else:
            unit = dataset_info.targets[key].unit
        units[key] = unit

    stats += "\n    Mean and standard deviation of targets:"
    for key in target_names:
        stats += (
            f"\n    - {to_external_name(key, dataset_info.targets)}: "  # type: ignore
            + f"\n      - mean {means[key]:.4g}"
            + (f" {units[key]}" if units[key] != "" else "")
            + f"\n      - std {stds[key]:.4g}"
            + (f" {units[key]}" if units[key] != "" else "")
        )

    return stats
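
Example (hypothetical usage sketch): logging the statistics of a training and a validation subset. ``train_subset``, ``val_subset`` and ``dataset_info`` are assumptions of this sketch and would come from the surrounding training setup.

import logging

logging.info("Training set:\n" + get_stats(train_subset, dataset_info))
logging.info("Validation set:\n" + get_stats(val_subset, dataset_info))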


def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> List[int]:
    """List of all atomic types present in a dataset or list of datasets.

    :param datasets: the dataset, or list of datasets
    :returns: sorted list of all atomic types present in the datasets
    """
    if not isinstance(datasets, list):
        datasets = [datasets]

    types = set()
    for dataset in datasets:
        for index in range(len(dataset)):
            system = dataset[index]["system"]
            types.update(set(system.types.tolist()))

    return sorted(types)


def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]:
    """Sorted list of all unique targets present in a dataset or list of datasets.

    :param datasets: the dataset(s).
    :returns: Sorted list of all targets present in the dataset(s).
    """
    if not isinstance(datasets, list):
        datasets = [datasets]

    # The following does not work because the `dataset` can also
    # be a `Subset` object:
    # return list(dataset.targets.keys())

    # Iterate over all single instances of the dataset:
    target_names = []
    for dataset in datasets:
        for sample in dataset:
            # system not needed
            target_names += [key for key in sample._asdict().keys() if key != "system"]

    return sorted(set(target_names))
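
Example (hypothetical usage sketch): summarizing user-provided datasets with ``get_atomic_types`` and ``get_all_targets`` before building a ``DatasetInfo``. ``train_dataset`` and ``val_dataset`` are assumed to be ``metatensor.learn.data.Dataset`` objects with a ``"system"`` field and one ``TensorMap`` field per target, and ``target_infos`` a matching ``Dict[str, TargetInfo]``; all three are assumptions of this sketch.

atomic_types = get_atomic_types([train_dataset, val_dataset])
target_names = get_all_targets([train_dataset, val_dataset])
print(f"{len(atomic_types)} atomic types: {atomic_types}")
print(f"targets: {target_names}")

dataset_info = DatasetInfo(
    length_unit="angstrom", atomic_types=atomic_types, targets=target_infos
)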


def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List, Dict[str, TensorMap]]:
    """
    Wraps `group_and_join` to return the data fields as a list of systems,
    and a dictionary of named targets.
    """

    collated_targets = group_and_join(batch, join_kwargs={"remove_tensor_name": True})
    collated_targets = collated_targets._asdict()
    systems = collated_targets.pop("system")
    return systems, collated_targets
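
Example (hypothetical usage sketch): using ``collate_fn`` with a PyTorch ``DataLoader``. ``train_dataset`` is an assumption of this sketch and stands for a ``metatensor.learn.data.Dataset`` whose entries contain a ``"system"`` field plus one ``TensorMap`` per target.

from torch.utils.data import DataLoader

dataloader = DataLoader(
    train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn
)
for systems, targets in dataloader:
    # ``systems`` is a list of systems; ``targets`` maps target names to TensorMaps
    ...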


def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]):
    """Check that the training and validation sets are compatible with one another.

    Although these checks will not fit all use cases, most models would be expected
    to be able to use this function.

    :param train_datasets: A list of training datasets to check.
    :param val_datasets: A list of validation datasets to check.
    :raises TypeError: If the ``dtype`` within the datasets is inconsistent.
    :raises ValueError: If ``val_datasets`` contain a target that is not present in
        ``train_datasets``.
    :raises ValueError: If the validation set contains chemical species or targets
        that are not present in the training set.
    """
    # Check that system `dtypes` are consistent within datasets
    desired_dtype = None
    for train_dataset in train_datasets:
        if len(train_dataset) == 0:
            continue

        actual_dtype = train_dataset[0].system.positions.dtype
        if desired_dtype is None:
            desired_dtype = actual_dtype

        if actual_dtype != desired_dtype:
            raise TypeError(
                "`dtype` between datasets is inconsistent, "
                f"found {desired_dtype} and {actual_dtype} in training datasets"
            )

    for val_dataset in val_datasets:
        if len(val_dataset) == 0:
            continue

        actual_dtype = val_dataset[0].system.positions.dtype
        if desired_dtype is None:
            desired_dtype = actual_dtype

        if actual_dtype != desired_dtype:
            raise TypeError(
                "`dtype` between datasets is inconsistent, "
                f"found {desired_dtype} and {actual_dtype} in validation datasets"
            )

    # Get all targets in the training and validation sets:
    train_targets = get_all_targets(train_datasets)
    val_targets = get_all_targets(val_datasets)

    # Check that the validation sets do not have targets that are not in the
    # training sets:
    for target in val_targets:
        if target not in train_targets:
            raise ValueError(
                f"The validation dataset has a target ({target}) that is not present "
                "in the training dataset."
            )

    # Get all the species in the training and validation sets:
    all_train_species = get_atomic_types(train_datasets)
    all_val_species = get_atomic_types(val_datasets)

    # Check that the validation sets do not have species that are not in the
    # training sets:
    for species in all_val_species:
        if species not in all_train_species:
            raise ValueError(
                f"The validation dataset has a species ({species}) that is not in the "
                "training dataset. This could be a result of a random train/validation "
                "split. You can avoid this by providing a validation dataset manually."
            )
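
Example (hypothetical usage sketch): validating user-provided datasets before training and turning an incompatibility into a single, clearly attributed error. ``train_dataset`` and ``val_dataset`` are assumptions of this sketch.

try:
    check_datasets([train_dataset], [val_dataset])
except (TypeError, ValueError) as err:
    raise ValueError(f"Incompatible training and validation datasets: {err}") from err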


def _train_test_random_split(
    train_dataset: Dataset,
    train_size: float,
    test_size: float,
) -> List[Dataset]:
    if train_size <= 0:
        raise ValueError("Fraction of the train set is smaller or equal to 0!")

    # normalize the sizes
    size_sum = train_size + test_size
    train_size /= size_sum
    test_size /= size_sum

    # find number of samples in the train and test sets
    test_len = math.floor(len(train_dataset) * test_size)
    if test_len == 0:
        warnings.warn(
            "Requested dataset of zero length. This dataset will be empty.",
            UserWarning,
            stacklevel=2,
        )
    train_len = len(train_dataset) - test_len
    if train_len == 0:
        raise ValueError("No samples left in the training set.")

    # find train, test indices
    indices = list(range(len(train_dataset)))
    np.random.shuffle(indices)
    train_indices = indices[:train_len]
    test_indices = indices[train_len:]

    return [
        Subset(train_dataset, train_indices),
        Subset(train_dataset, test_indices),
    ]


def _save_indices(
    train_indices: List[Optional[List[int]]],
    val_indices: List[Optional[List[int]]],
    test_indices: List[Optional[List[int]]],
    checkpoint_dir: Union[str, Path],
) -> None:
    # Save the indices of the training, validation, and test sets to the checkpoint
    # directory. This is useful for plotting errors and similar.

    # case 1: all indices are None (i.e. all datasets were user-provided explicitly)
    if all(indices is None for indices in train_indices):
        pass

    # case 2: there is only one dataset
    elif len(train_indices) == 1:  # val and test index lists have the same length
        os.mkdir(os.path.join(checkpoint_dir, "indices/"))
        if train_indices[0] is not None:
            np.savetxt(
                os.path.join(checkpoint_dir, "indices/training.txt"),
                train_indices[0],
                fmt="%d",
            )
        if val_indices[0] is not None:
            np.savetxt(
                os.path.join(checkpoint_dir, "indices/validation.txt"),
                val_indices[0],
                fmt="%d",
            )
        if test_indices[0] is not None:
            np.savetxt(
                os.path.join(checkpoint_dir, "indices/test.txt"),
                test_indices[0],
                fmt="%d",
            )

    # case 3: there are multiple datasets
    else:
        os.mkdir(os.path.join(checkpoint_dir, "indices/"))
        for i, (train, val, test) in enumerate(
            zip(train_indices, val_indices, test_indices)
        ):
            if train is not None:
                np.savetxt(
                    os.path.join(checkpoint_dir, f"indices/training_{i}.txt"),
                    train,
                    fmt="%d",
                )
            if val is not None:
                np.savetxt(
                    os.path.join(checkpoint_dir, f"indices/validation_{i}.txt"),
                    val,
                    fmt="%d",
                )
            if test is not None:
                np.savetxt(
                    os.path.join(checkpoint_dir, f"indices/test_{i}.txt"),
                    test,
                    fmt="%d",
                )
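
Example (hypothetical usage sketch of the private helpers above): splitting a single dataset 90/10 at random and saving the resulting indices next to a checkpoint. ``dataset`` is assumed to be a ``metatensor.learn.data.Dataset`` and ``checkpoint_dir`` an existing directory; ``_save_indices`` only creates the ``indices/`` subdirectory inside it.

train_subset, test_subset = _train_test_random_split(
    dataset, train_size=0.9, test_size=0.1
)
_save_indices(
    train_indices=[list(train_subset.indices)],
    val_indices=[None],
    test_indices=[list(test_subset.indices)],
    checkpoint_dir=checkpoint_dir,
)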