From 65c8bd6081d33e01ce874c62d7591f9dfa3dd484 Mon Sep 17 00:00:00 2001 From: Xixian Date: Mon, 16 Dec 2024 23:20:30 +0800 Subject: [PATCH] fix: Potential load method --- docs/user_guide/getting_started.rst | 2 +- src/mattersim/forcefield/potential.py | 28 +++++++++++++++++++-------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/docs/user_guide/getting_started.rst b/docs/user_guide/getting_started.rst index 3703b26..56ca226 100644 --- a/docs/user_guide/getting_started.rst +++ b/docs/user_guide/getting_started.rst @@ -59,7 +59,7 @@ a list of structures using the ``Potential`` class. # load the model device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Running MatterSim on {device}") - potential = Potential.load(device=device) + potential = Potential.from_checkpoint(device=device) # build the dataloader that is compatible with MatterSim dataloader = build_dataloader(structures, only_inference=True) diff --git a/src/mattersim/forcefield/potential.py b/src/mattersim/forcefield/potential.py index d95fff3..1225690 100644 --- a/src/mattersim/forcefield/potential.py +++ b/src/mattersim/forcefield/potential.py @@ -965,27 +965,39 @@ def load( if model_name.lower() != "m3gnet": raise NotImplementedError - current_dir = os.path.dirname(__file__) + checkpoint_folder = os.path.expanduser("~/.local/mattersim/pretrained_models") + os.makedirs(checkpoint_folder, exist_ok=True) if ( load_path is None or load_path.lower() == "mattersim-v1.0.0-1m.pth" or load_path.lower() == "mattersim-v1.0.0-1m" ): - load_path = os.path.join( - current_dir, "..", "pretrained_models/mattersim-v1.0.0-1M.pth" - ) + load_path = os.path.join(checkpoint_folder, "mattersim-v1.0.0-1M.pth") + if not os.path.exists(load_path): + logger.info( + "The pre-trained model is not found locally, " + "attempting to download it from the server." + ) + download_checkpoint( + "mattersim-v1.0.0-1M.pth", save_folder=checkpoint_folder + ) logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model") elif ( load_path.lower() == "mattersim-v1.0.0-5m.pth" or load_path.lower() == "mattersim-v1.0.0-5m" ): - load_path = os.path.join( - current_dir, "..", "pretrained_models/mattersim-v1.0.0-5M.pth" - ) + load_path = os.path.join(checkpoint_folder, "mattersim-v1.0.0-5M.pth") + if not os.path.exists(load_path): + logger.info( + "The pre-trained model is not found locally, " + "attempting to download it from the server." + ) + download_checkpoint( + "mattersim-v1.0.0-5M.pth", save_folder=checkpoint_folder + ) logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model") else: logger.info("Loading the model from %s" % load_path) - assert os.path.exists(load_path), f"Model file {load_path} not found" checkpoint = torch.load(load_path, map_location=device)