Composition model
- class metatrain.utils.additive.composition.CompositionModel(model_hypers: Dict, dataset_info: DatasetInfo)
Bases: Module
A simple model that calculates the contributions to scalar targets based on the stoichiometry of a system, i.e. the number of atoms of each atomic type it contains.
- Parameters:
model_hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored; it is only present for consistency with the general model API.
dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.
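Read literally, "based on the stoichiometry" together with the "composition weights" and the "linear system" mentioned under train_model suggests a model that is linear in the atom counts. Under that assumption (the symbols below are illustrative, not taken from the metatrain source), the prediction for a scalar target of a system $A$ is

$$
y_A \approx \sum_{t} n_{A,t}\, w_t
$$

where $n_{A,t}$ is the number of atoms of atomic type $t$ in system $A$ and $w_t$ is the composition weight of type $t$. For water this reads $y_{\mathrm{H_2O}} \approx 2\,w_{\mathrm{H}} + w_{\mathrm{O}}$.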
- outputs: Dict[str, ModelOutput]
- train_model(datasets: List[Dataset | Subset], fixed_weights: Dict[str, Dict[int, str]] | None = None) → None
Fit the composition weights to the targets in the given datasets.
- Parameters:
datasets (List[Dataset | Subset])
fixed_weights (Dict[str, Dict[int, str]] | None)
- Raises:
ValueError – If the provided datasets contain unknown targets.
ValueError – If the provided datasets contain unknown atomic types.
RuntimeError – If the linear system to calculate the composition weights cannot be solved.
- Return type:
None
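The fitting step can be sketched as an ordinary least-squares problem on atom counts. This is a minimal illustration, not metatrain's implementation: the helper `fit_composition_weights` is hypothetical, systems are reduced to lists of atomic types, only a single scalar target is handled, and `fixed_weights` is simplified to a plain `{atomic_type: weight}` mapping with float weights whose types are excluded from the fit.

```python
from typing import Dict, List, Optional

import numpy as np


def fit_composition_weights(
    systems_types: List[List[int]],
    targets: List[float],
    atomic_types: List[int],
    fixed_weights: Optional[Dict[int, float]] = None,
) -> Dict[int, float]:
    """Fit one weight per atomic type by least squares (hypothetical helper)."""
    fixed_weights = fixed_weights or {}
    type_index = {t: j for j, t in enumerate(atomic_types)}

    # counts[i, j] = number of atoms of type atomic_types[j] in system i
    counts = np.zeros((len(systems_types), len(atomic_types)))
    for i, types in enumerate(systems_types):
        for t in types:
            if t not in type_index:
                raise ValueError(f"unknown atomic type: {t}")
            counts[i, type_index[t]] += 1

    # Remove the contribution of the fixed weights from the targets
    y = np.asarray(targets, dtype=float)
    for t, w in fixed_weights.items():
        if t not in type_index:
            raise ValueError(f"unknown atomic type in fixed weights: {t}")
        y = y - w * counts[:, type_index[t]]

    # Solve only for the remaining (free) weights
    free_types = [t for t in atomic_types if t not in fixed_weights]
    weights: Dict[int, float] = dict(fixed_weights)
    if free_types:
        free_cols = [type_index[t] for t in free_types]
        solution, _, _, _ = np.linalg.lstsq(counts[:, free_cols], y, rcond=None)
        weights.update({t: float(w) for t, w in zip(free_types, solution)})
    return weights


# Toy data: H2O, H2 and O2 with targets built from w_H = -0.5, w_O = -75.0
systems = [[1, 1, 8], [1, 1], [8, 8]]
targets = [-76.0, -1.0, -150.0]
print(fit_composition_weights(systems, targets, atomic_types=[1, 8]))
# approximately {1: -0.5, 8: -75.0}
```

Subtracting the fixed contributions before solving keeps the linear system small and guarantees that the fixed values are returned unchanged.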
- restart(dataset_info: DatasetInfo) → CompositionModel
- Parameters:
dataset_info (DatasetInfo)
- Return type:
CompositionModel
- forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) → Dict[str, TensorMap]
Compute the targets for each system based on the composition weights.
- Parameters:
systems (List[System])
outputs (Dict[str, ModelOutput])
selected_atoms (Labels | None)
- Returns:
A dictionary with the computed predictions for each system.
- Raises:
ValueError – If no weights have been computed or if the keys of outputs include unsupported targets.
- Return type:
Dict[str, TensorMap]
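The prediction step can be sketched with plain NumPy arrays instead of System and TensorMap objects. This is a hypothetical illustration: `predict_composition` is not a metatrain function, and the `per_atom` flag assumes that per-atom outputs simply assign each atom the weight of its type, with the per-system value being the sum over atoms.

```python
from typing import Dict, List, Union

import numpy as np


def predict_composition(
    systems_types: List[List[int]],
    weights: Dict[int, float],
    per_atom: bool = False,
) -> Union[np.ndarray, List[np.ndarray]]:
    """Composition prediction from fitted per-type weights (hypothetical helper)."""
    if not weights:
        raise ValueError("no composition weights have been computed")
    # Each atom contributes the weight of its atomic type
    atomic_contributions = [
        np.array([weights[t] for t in types], dtype=float) for types in systems_types
    ]
    if per_atom:
        return atomic_contributions
    # Per-system value: sum of the atomic contributions
    return np.array([c.sum() for c in atomic_contributions])


weights = {1: -0.5, 8: -75.0}
print(predict_composition([[1, 1, 8], [8, 8]], weights))  # values -76.0 and -150.0
```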