import argparse
import logging
from pathlib import Path
from typing import Any, Union

from metatensor.torch.atomistic import is_atomistic_model

from ..utils.io import check_file_extension, load_model
from .formatter import CustomHelpFormatter


logger = logging.getLogger(__name__)


def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
    """Register the ``export`` sub-command and its arguments on ``subparser``.

    The command description is the part of :func:`export_model`'s docstring that
    precedes its first ``:param`` field (or ``None`` if it has no docstring).
    """
    docstring = export_model.__doc__
    description = None if docstring is None else docstring.split(r":param")[0]

    # If you change the synopsis of these commands or add new ones adjust the
    # completion script at `src/metatrain/share/metatrain-completion.bash`.
    parser = subparser.add_parser(
        "export",
        description=description,
        formatter_class=CustomHelpFormatter,
    )
    # The top-level CLI dispatches on this value to pick the command to run.
    parser.set_defaults(callable="export_model")

    path_help = (
        "Saved model which should be exported. Path can be either a URL or a "
        "local file."
    )
    parser.add_argument("path", type=str, help=path_help)

    output_help = (
        "Filename of the exported model (default: <stem>.pt, "
        "where <stem> is the name of the checkpoint without the extension)."
    )
    parser.add_argument(
        "-o",
        "--output",
        dest="output",
        type=str,
        required=False,
        help=output_help,
    )

    parser.add_argument(
        "--huggingface_api_token",
        dest="huggingface_api_token",
        type=str,
        required=False,
        default="",
        help="API token to download a private model from HuggingFace.",
    )


def _prepare_export_model_args(args: argparse.Namespace) -> None:
    """Rewrite the parsed CLI namespace into the keyword arguments of
    :func:`export_model`.

    The namespace is modified in place:

    * ``path`` is removed and used to load the model into ``args.model``;
    * every entry other than ``model`` and ``output`` is dropped;
    * a missing ``output`` defaults to ``<stem of path>.pt``.
    """
    namespace = vars(args)  # same dict object as ``args.__dict__``
    source = namespace.pop("path")

    # All remaining namespace entries (e.g. ``output`` and
    # ``huggingface_api_token``) are forwarded to the loader as keywords.
    # NOTE(review): ``output`` is forwarded too; presumably ``load_model``
    # ignores keywords it does not use — confirm.
    args.model = load_model(path=source, **namespace)

    # `export_model` only accepts ``model`` and ``output``.
    wanted = ("model", "output")
    for key in [name for name in namespace if name not in wanted]:
        del namespace[key]

    if namespace.get("output") is None:
        namespace["output"] = Path(source).stem + ".pt"
[docs]defexport_model(model:Any,output:Union[Path,str])->None:"""Export a trained model allowing it to make predictions. This includes predictions within molecular simulation engines. Exported models will be saved with a ``.pt`` file ending. If ``path`` does not end with this file extensions ``.pt`` will be added and a warning emitted. :param model: model to be exported :param output: path to save the model """path=str(Path(check_file_extension(filename=output,extension=".pt")).absolute().resolve())extensions_path=str(Path("extensions/").absolute().resolve())ifnotis_atomistic_model(model):model=model.export()model.save(path,collect_extensions=extensions_path)logger.info(f"Model exported to '{path}' and extensions to '{extensions_path}'")