Dataset
- class metatensor.models.utils.data.dataset.Dataset(dict: Dict)
Bases: object
A version of the metatensor.learn.Dataset class that allows mtm:: prefixes in the keys of the dictionary. See https://github.com/lab-cosmo/metatensor/issues/621.
Note that this class accepts and returns dictionaries instead of named tuples.
- Parameters:
dict (Dict) – A dictionary with the data to be stored in the dataset.
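The choice of dictionaries over named tuples can be illustrated with a short sketch. Named tuple field names must be valid Python identifiers, so a key such as `mtm::energy` cannot be a field name, while any string is a valid dictionary key. The `DictDataset` class, the `namedtuple_accepts` helper, and the `system` and `mtm::energy` keys are hypothetical stand-ins for illustration; they are not the library's implementation.

```python
from collections import namedtuple
from typing import Any, Dict, List


class DictDataset:
    """Hypothetical stand-in: stores equal-length fields, returns dict samples."""

    def __init__(self, data: Dict[str, List[Any]]):
        # Every field must provide one entry per sample.
        lengths = {len(values) for values in data.values()}
        if len(lengths) > 1:
            raise ValueError("all fields must have the same number of samples")
        self._data = data
        self._length = lengths.pop() if lengths else 0

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # A sample is a plain dictionary, so keys may contain "::".
        return {key: values[index] for key, values in self._data.items()}


def namedtuple_accepts(field: str) -> bool:
    """Return True if `field` can be used as a named tuple field name."""
    try:
        namedtuple("Sample", [field])
    except ValueError:
        return False
    return True


# Illustrative data: the key names are assumptions, not a required schema.
dataset = DictDataset({"system": ["s0", "s1"], "mtm::energy": [-1.0, -2.0]})
sample = dataset[1]
```

Here `namedtuple_accepts("mtm::energy")` is false, whereas `sample["mtm::energy"]` is an ordinary dictionary lookup.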
- class metatensor.models.utils.data.dataset.TargetInfo(quantity: str, unit: str = '', per_atom: bool = False, gradients: List[str] = <factory>)
Bases: object
A class that contains information about a target.
- Parameters:
- class metatensor.models.utils.data.dataset.DatasetInfo(length_unit: str, targets: Dict[str, TargetInfo])
Bases: object
A class that contains information about one or more datasets.
This dataclass is used to communicate additional dataset details to the training functions of the individual models.
- Parameters:
length_unit (str) – The unit of length used in the dataset.
targets (Dict[str, TargetInfo]) – The information about targets in the dataset.
- targets: Dict[str, TargetInfo]
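As a rough illustration of how these two dataclasses fit together, the sketch below defines local stand-ins that mirror the signatures shown on this page, so that it runs without the library installed. It assumes that the `<factory>` default for `gradients` produces an empty list; the unit strings and target names are illustrative values only.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class TargetInfo:  # local stand-in mirroring the documented signature
    quantity: str
    unit: str = ""
    per_atom: bool = False
    # Assumption: the documented <factory> default is an empty list.
    gradients: List[str] = field(default_factory=list)


@dataclass
class DatasetInfo:  # local stand-in mirroring the documented signature
    length_unit: str
    targets: Dict[str, TargetInfo]


# Illustrative values; the target names and units are not prescribed here.
info = DatasetInfo(
    length_unit="angstrom",
    targets={
        "energy": TargetInfo(quantity="energy", unit="eV", gradients=["positions"]),
        "mtm::dipole": TargetInfo(quantity="dipole"),
    },
)
```

Using a factory for `gradients` gives each `TargetInfo` its own list, so modifying the gradients of one target does not affect another.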
- metatensor.models.utils.data.dataset.get_all_species(datasets: Dataset | List[Dataset]) → List[int]
Returns the list of all species present in a dataset or list of datasets.
- metatensor.models.utils.data.dataset.get_all_targets(datasets: Dataset | List[Dataset]) → List[str]
Returns the list of all targets present in a dataset or list of datasets.
- metatensor.models.utils.data.dataset.collate_fn(batch: List[Dict[str, Any]]) → Tuple[List, Dict[str, TensorMap]]
Wraps group_and_join to return the data fields as a list of systems and a dictionary of named targets.
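The shape of the return value can be sketched in plain Python. This is not the library's implementation: the real function returns one joined TensorMap per target, whereas the hypothetical `collate_sketch` here simply collects the per-sample values into lists. It also assumes that each sample stores its system under the key `system`.

```python
from typing import Any, Dict, List, Tuple


def collate_sketch(
    batch: List[Dict[str, Any]],
) -> Tuple[List[Any], Dict[str, List[Any]]]:
    """Split a batch of dict samples into (systems, targets by name)."""
    # Assumption: the system of each sample lives under the "system" key.
    systems = [sample["system"] for sample in batch]

    # Every other key is treated as a named target.
    targets: Dict[str, List[Any]] = {}
    for sample in batch:
        for key, value in sample.items():
            if key != "system":
                targets.setdefault(key, []).append(value)
    return systems, targets


batch = [
    {"system": "s0", "mtm::energy": -1.0},
    {"system": "s1", "mtm::energy": -2.0},
]
systems, targets = collate_sketch(batch)
```

A function with this call signature is what a PyTorch `DataLoader` expects for its `collate_fn` argument.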
- metatensor.models.utils.data.dataset.check_datasets(train_datasets: List[Dataset], validation_datasets: List[Dataset])
Check that the training and validation sets are compatible with one another.
These checks will not fit every use case, but most models are expected to be able to use this function.
- Parameters:
- Raises:
TypeError – If the dtype within the datasets is inconsistent.
ValueError – If the validation_datasets has a target that is not present in the train_datasets.
ValueError – If the training or validation set contains chemical species or targets that are not present in the training set.
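The kind of compatibility check described here can be sketched as follows. This is a simplified illustration, not the library's code: each dataset is represented by a hypothetical summary dictionary with `dtype`, `targets`, and `species` entries, and `check_datasets_sketch` is an invented name.

```python
from typing import Any, Dict, List

# Hypothetical summary of a dataset: {"dtype": ..., "targets": ..., "species": ...}
DatasetSummary = Dict[str, Any]


def check_datasets_sketch(
    train: List[DatasetSummary], validation: List[DatasetSummary]
) -> None:
    """Raise if the training and validation summaries are incompatible."""
    # All datasets must agree on a single dtype.
    dtypes = {summary["dtype"] for summary in train + validation}
    if len(dtypes) > 1:
        raise TypeError(f"inconsistent dtypes: {sorted(dtypes)}")

    # Collect everything the training sets cover.
    train_targets = set().union(*(s["targets"] for s in train))
    train_species = set().union(*(s["species"] for s in train))

    # Validation data may not introduce unseen targets or species.
    for summary in validation:
        missing_targets = set(summary["targets"]) - train_targets
        if missing_targets:
            raise ValueError(f"targets not in training set: {sorted(missing_targets)}")
        missing_species = set(summary["species"]) - train_species
        if missing_species:
            raise ValueError(f"species not in training set: {sorted(missing_species)}")


train = [{"dtype": "float64", "targets": ["energy"], "species": [1, 6, 8]}]
valid = [{"dtype": "float64", "targets": ["energy"], "species": [1, 8]}]
check_datasets_sketch(train, valid)  # compatible, so nothing is raised
```

Running such a check before training starts surfaces a mismatch immediately, instead of part-way through a run.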