Dataset

class metatensor.models.utils.data.dataset.Dataset(dict: Dict)

Bases: object

A version of the metatensor.learn.Dataset class that allows for the use of mtm:: prefixes in the keys of the dictionary. See https://github.com/lab-cosmo/metatensor/issues/621.

Note that this class accepts and returns dictionaries rather than named tuples.

Parameters:

dict (Dict) – A dictionary with the data to be stored in the dataset.
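
A minimal sketch of the idea follows. It assumes, rather than takes from the library, that the data is a dictionary mapping each field name to a list with one entry per sample; the class name DictDataset is hypothetical. The last lines show one plausible reason the dictionary form helps with mtm:: keys: collections.namedtuple rejects field names that are not valid Python identifiers, and "mtm::energy" is not one.

```python
from collections import namedtuple
from typing import Any, Dict, List


class DictDataset:
    """Hypothetical dict-based dataset; a sketch, not the metatensor-models class."""

    def __init__(self, data: Dict[str, List[Any]]):
        # Assumption: each field is a list holding one entry per sample.
        lengths = {len(values) for values in data.values()}
        if len(lengths) > 1:
            raise ValueError("all fields must contain the same number of samples")
        self._data = data
        self._length = lengths.pop() if lengths else 0

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Each sample is returned as a plain dictionary, not as a named tuple.
        return {key: values[index] for key, values in self._data.items()}


dataset = DictDataset({"system": ["H2O", "CH4"], "mtm::energy": [-1.0, -2.0]})
print(len(dataset))  # 2
print(dataset[1])  # {'system': 'CH4', 'mtm::energy': -2.0}

# A named tuple refuses the same key, because "mtm::energy" is not a valid identifier.
try:
    namedtuple("Sample", ["system", "mtm::energy"])
    namedtuple_rejected = False
except ValueError:
    namedtuple_rejected = True
print(namedtuple_rejected)  # True
```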

class metatensor.models.utils.data.dataset.TargetInfo(quantity: str, unit: str = '', per_atom: bool = False, gradients: List[str] = <factory>)

Bases: object

A class that contains information about a target.

Parameters:
  • quantity (str) – The physical quantity represented by the target.

  • unit (str) – The unit of the target.

  • per_atom (bool) – Whether the target is a per-atom quantity.

  • gradients (List[str]) – Gradients of the target that are defined in the current dataset.

quantity: str
unit: str = ''
per_atom: bool = False
gradients: List[str]
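
The signature above reads like a Python dataclass, and <factory> is how Python renders the default of a field declared with field(default_factory=...). The hypothetical mirror below, TargetInfoSketch, assumes gradients uses default_factory=list, so each instance gets its own list; the values "energy", "eV" and "positions" are illustrative only.

```python
import inspect
from dataclasses import dataclass, field
from typing import List


@dataclass
class TargetInfoSketch:
    """Hypothetical mirror of the documented TargetInfo fields."""

    quantity: str
    unit: str = ""
    per_atom: bool = False
    # default_factory builds a fresh list per instance; signatures show it as <factory>.
    gradients: List[str] = field(default_factory=list)


energy = TargetInfoSketch(quantity="energy", unit="eV", gradients=["positions"])
first = TargetInfoSketch(quantity="energy")
second = TargetInfoSketch(quantity="energy")
first.gradients.append("positions")  # does not leak into `second`

print(second.gradients)  # []
print("<factory>" in str(inspect.signature(TargetInfoSketch)))  # True
```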
class metatensor.models.utils.data.dataset.DatasetInfo(length_unit: str, targets: Dict[str, TargetInfo])

Bases: object

A class that contains information about one or more datasets.

This dataclass is used to communicate additional dataset details to the training functions of the individual models.

Parameters:
  • length_unit (str) – The unit of length used in the dataset.

  • targets (Dict[str, TargetInfo]) – The information about targets in the dataset.

length_unit: str
targets: Dict[str, TargetInfo]
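
A sketch of how a training function might consume this information, using hypothetical dataclass mirrors (TargetInfoSketch, DatasetInfoSketch) and a hypothetical helper, targets_needing_gradient. The unit strings and target names are illustrative assumptions, not values the library prescribes.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class TargetInfoSketch:
    """Hypothetical mirror of TargetInfo."""

    quantity: str
    unit: str = ""
    per_atom: bool = False
    gradients: List[str] = field(default_factory=list)


@dataclass
class DatasetInfoSketch:
    """Hypothetical mirror of DatasetInfo."""

    length_unit: str
    targets: Dict[str, TargetInfoSketch]


def targets_needing_gradient(info: DatasetInfoSketch, gradient: str) -> List[str]:
    """Hypothetical helper: sorted names of targets that define the given gradient."""
    return sorted(
        name for name, target in info.targets.items() if gradient in target.gradients
    )


info = DatasetInfoSketch(
    length_unit="angstrom",
    targets={
        "energy": TargetInfoSketch(quantity="energy", unit="eV", gradients=["positions"]),
        "dipole": TargetInfoSketch(quantity="dipole"),
    },
)
print(targets_needing_gradient(info, "positions"))  # ['energy']
print(targets_needing_gradient(info, "strain"))  # []
```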
metatensor.models.utils.data.dataset.get_all_species(datasets: Dataset | List[Dataset]) → List[int]

Returns the list of all species present in a dataset or list of datasets.

Parameters:

datasets (Dataset | List[Dataset]) – The dataset or list of datasets.

Returns:

The sorted list of species present in the datasets.

Return type:

List[int]
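
A plain-Python sketch of the documented behaviour (collect, de-duplicate, sort). It assumes a hypothetical layout in which each sample is a dictionary whose "system" entry is a list of atomic numbers; real datasets store system objects, which this does not model. The name get_all_species_sketch is hypothetical.

```python
from typing import Any, Dict, List, Sequence, Union

# Hypothetical sample layout: sample["system"] is a list of atomic numbers.
Sample = Dict[str, Any]
DatasetLike = Sequence[Sample]


def get_all_species_sketch(datasets: Union[DatasetLike, List[DatasetLike]]) -> List[int]:
    # Accept a single dataset or a list of datasets, as in the documented signature.
    if len(datasets) > 0 and isinstance(datasets[0], dict):
        datasets = [datasets]
    species = set()
    for dataset in datasets:
        for sample in dataset:
            species.update(sample["system"])
    # De-duplicated and sorted, as the Returns section describes.
    return sorted(species)


water = [{"system": [8, 1, 1], "energy": -1.0}]
methane = [{"system": [6, 1, 1, 1, 1], "energy": -2.0}]
print(get_all_species_sketch(water))  # [1, 8]
print(get_all_species_sketch([water, methane]))  # [1, 6, 8]
```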

metatensor.models.utils.data.dataset.get_all_targets(datasets: Dataset | List[Dataset]) → List[str]

Returns the list of all targets present in a dataset or list of datasets.

Parameters:

datasets (Dataset | List[Dataset]) – The dataset or list of datasets.

Returns:

The list of targets present in the dataset(s), sorted with the sort() method of Python lists.

Return type:

List[str]
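
A comparable sketch for targets, under the assumption (not confirmed by this page) that every field other than "system" is a target name. The name get_all_targets_sketch is hypothetical; the final step uses list.sort(), matching the ordering stated above.

```python
from typing import Any, Dict, List, Sequence, Union

Sample = Dict[str, Any]
DatasetLike = Sequence[Sample]


def get_all_targets_sketch(datasets: Union[DatasetLike, List[DatasetLike]]) -> List[str]:
    # Accept a single dataset or a list of datasets.
    if len(datasets) > 0 and isinstance(datasets[0], dict):
        datasets = [datasets]
    names = set()
    for dataset in datasets:
        for sample in dataset:
            # Assumption: every field other than "system" is a target.
            names.update(key for key in sample if key != "system")
    result = list(names)
    result.sort()  # ordering given by Python's list.sort()
    return result


energies = [{"system": "H2O", "mtm::energy": -1.0}]
dipoles = [{"system": "H2O", "mtm::energy": -1.0, "dipole": [0.0, 0.0, 0.4]}]
print(get_all_targets_sketch(energies))  # ['mtm::energy']
print(get_all_targets_sketch([energies, dipoles]))  # ['dipole', 'mtm::energy']
```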

metatensor.models.utils.data.dataset.collate_fn(batch: List[Dict[str, Any]]) → Tuple[List, Dict[str, TensorMap]]

Wraps group_and_join to return the data fields as a list of systems and a dictionary of named targets.

Parameters:

batch (List[Dict[str, Any]]) – A list of dictionaries, each containing the data for a single sample.

Return type:

Tuple[List, Dict[str, TensorMap]]
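
A sketch of the documented split into (systems, targets), with hypothetical helpers group_and_join_sketch and collate_fn_sketch. It assumes systems live under the "system" key and that every sample in a batch has the same keys; targets here are plain lists rather than joined TensorMaps.

```python
from typing import Any, Dict, List, Tuple


def group_and_join_sketch(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    # Collect each field across samples (a real version would also join TensorMaps).
    # Assumption: all samples share the keys of the first sample.
    return {key: [sample[key] for sample in batch] for key in batch[0]}


def collate_fn_sketch(
    batch: List[Dict[str, Any]],
) -> Tuple[List[Any], Dict[str, List[Any]]]:
    joined = group_and_join_sketch(batch)
    # Assumption: systems are stored under "system"; every other field is a target.
    systems = joined.pop("system")
    return systems, joined


batch = [
    {"system": "H2O", "mtm::energy": -1.0},
    {"system": "CH4", "mtm::energy": -2.0},
]
systems, targets = collate_fn_sketch(batch)
print(systems)  # ['H2O', 'CH4']
print(targets)  # {'mtm::energy': [-1.0, -2.0]}
```

Given its name and signature, collate_fn is presumably meant to be passed as the collate_fn argument of torch.utils.data.DataLoader; that pairing is an inference, not something stated on this page.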

metatensor.models.utils.data.dataset.check_datasets(train_datasets: List[Dataset], validation_datasets: List[Dataset])

Check that the training and validation sets are compatible with one another.

Although these checks will not fit all use cases, most models would be expected to be able to use this function.

Parameters:
  • train_datasets (List[Dataset]) – A list of training datasets to check.

  • validation_datasets (List[Dataset]) – A list of validation datasets to check.

Raises:
  • TypeError – If the dtypes within the datasets are inconsistent.

  • ValueError – If validation_datasets contain a target that is not present in train_datasets.

  • ValueError – If the validation set contains chemical species that are not present in the training set.
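
The three checks listed under Raises can be sketched with sets. The sample layout (sample["system"] holding "species" and "dtype" entries) and the function check_datasets_sketch are hypothetical; only the kinds of errors and their order follow the description above.

```python
from typing import Any, Dict, List, Sequence, Set

# Hypothetical layout: sample["system"] holds {"species": [...], "dtype": "..."}.
Sample = Dict[str, Any]


def _targets(datasets: List[Sequence[Sample]]) -> Set[str]:
    return {key for dataset in datasets for sample in dataset for key in sample if key != "system"}


def _species(datasets: List[Sequence[Sample]]) -> Set[int]:
    return {z for dataset in datasets for sample in dataset for z in sample["system"]["species"]}


def check_datasets_sketch(
    train_datasets: List[Sequence[Sample]], validation_datasets: List[Sequence[Sample]]
) -> None:
    # 1. Every system in every dataset must share one dtype.
    dtypes = {
        sample["system"]["dtype"]
        for dataset in train_datasets + validation_datasets
        for sample in dataset
    }
    if len(dtypes) > 1:
        raise TypeError(f"inconsistent dtypes across datasets: {sorted(dtypes)}")

    # 2. Validation sets may not introduce targets unknown to the training sets.
    extra_targets = _targets(validation_datasets) - _targets(train_datasets)
    if extra_targets:
        raise ValueError(f"targets missing from the training set: {sorted(extra_targets)}")

    # 3. Validation sets may not introduce chemical species either.
    extra_species = _species(validation_datasets) - _species(train_datasets)
    if extra_species:
        raise ValueError(f"species missing from the training set: {sorted(extra_species)}")


train = [[{"system": {"species": [8, 1, 1], "dtype": "float64"}, "energy": -1.0}]]
valid = [[{"system": {"species": [1, 1], "dtype": "float64"}, "energy": -0.5}]]
check_datasets_sketch(train, valid)  # passes silently

bad = [[{"system": {"species": [6, 1, 1, 1, 1], "dtype": "float64"}, "energy": -2.0}]]
try:
    check_datasets_sketch(train, bad)
except ValueError as error:
    print(error)  # species missing from the training set: [6]
```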

metatensor.models.utils.data.dataset.group_and_join(batch: List[Dict[str, Any]]) → Dict[str, Any]

Same as metatensor.learn.data.group_and_join, but joins dictionaries rather than named tuples.

Parameters:

batch (List[Dict[str, Any]]) – A list of dictionaries, each containing the data for a single sample.

Returns:

A single dictionary with the data fields joined together among all samples.

Return type:

Dict[str, Any]
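
A sketch of grouping a batch of dictionaries into a single dictionary. Assumptions: lists stand in for TensorMaps, and "joining" a field means concatenating its per-sample lists, whereas the real function presumably joins TensorMaps; fields that are not lists are simply collected. The name group_and_join_sketch is hypothetical.

```python
from typing import Any, Dict, List


def group_and_join_sketch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Group: gather each field's value from every sample, keeping sample order.
    grouped: Dict[str, List[Any]] = {}
    for sample in batch:
        for key, value in sample.items():
            grouped.setdefault(key, []).append(value)

    # Join: concatenate list-valued fields (stand-in for joining TensorMaps).
    joined: Dict[str, Any] = {}
    for key, values in grouped.items():
        if all(isinstance(value, list) for value in values):
            joined[key] = [item for value in values for item in value]
        else:
            joined[key] = values
    return joined


batch = [
    {"system": "H2O", "mtm::forces": [0.1, 0.2, 0.3]},
    {"system": "H2", "mtm::forces": [0.4, 0.5]},
]
result = group_and_join_sketch(batch)
print(result)  # {'system': ['H2O', 'H2'], 'mtm::forces': [0.1, 0.2, 0.3, 0.4, 0.5]}
```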