Skip to content

Commit

Permalink
fix: Potential load method
Browse files Browse the repository at this point in the history
  • Loading branch information
Xixian committed Dec 16, 2024
1 parent 92af875 commit 65c8bd6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/user_guide/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ a list of structures using the ``Potential`` class.
# load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running MatterSim on {device}")
potential = Potential.load(device=device)
potential = Potential.from_checkpoint(device=device)
# build the dataloader that is compatible with MatterSim
dataloader = build_dataloader(structures, only_inference=True)
Expand Down
28 changes: 20 additions & 8 deletions src/mattersim/forcefield/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,27 +965,39 @@ def load(
if model_name.lower() != "m3gnet":
raise NotImplementedError

current_dir = os.path.dirname(__file__)
checkpoint_folder = os.path.expanduser("~/.local/mattersim/pretrained_models")
os.makedirs(checkpoint_folder, exist_ok=True)
if (
load_path is None
or load_path.lower() == "mattersim-v1.0.0-1m.pth"
or load_path.lower() == "mattersim-v1.0.0-1m"
):
load_path = os.path.join(
current_dir, "..", "pretrained_models/mattersim-v1.0.0-1M.pth"
)
load_path = os.path.join(checkpoint_folder, "mattersim-v1.0.0-1M.pth")
if not os.path.exists(load_path):
logger.info(
"The pre-trained model is not found locally, "
"attempting to download it from the server."
)
download_checkpoint(
"mattersim-v1.0.0-1M.pth", save_folder=checkpoint_folder
)
logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model")
elif (
load_path.lower() == "mattersim-v1.0.0-5m.pth"
or load_path.lower() == "mattersim-v1.0.0-5m"
):
load_path = os.path.join(
current_dir, "..", "pretrained_models/mattersim-v1.0.0-5M.pth"
)
load_path = os.path.join(checkpoint_folder, "mattersim-v1.0.0-5M.pth")
if not os.path.exists(load_path):
logger.info(
"The pre-trained model is not found locally, "
"attempting to download it from the server."
)
download_checkpoint(
"mattersim-v1.0.0-5M.pth", save_folder=checkpoint_folder
)
logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model")
else:
logger.info("Loading the model from %s" % load_path)

assert os.path.exists(load_path), f"Model file {load_path} not found"

checkpoint = torch.load(load_path, map_location=device)
Expand Down

0 comments on commit 65c8bd6

Please sign in to comment.