Skip to content

Commit

Permalink
Add error for not finding models; change to 'get_latest_model_version'
Browse files Browse the repository at this point in the history
  • Loading branch information
priyappillai committed May 14, 2021
1 parent 886ffde commit cb6758d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
6 changes: 3 additions & 3 deletions adapt/utils/tests/test_predict_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from adapt import alignment
from adapt.utils import predict_activity
from adapt.utils.version import get_project_path, get_model_version
from adapt.utils.version import get_project_path, get_latest_model_version

__author__ = 'Hayden Metsky <hayden@mit.edu>'

Expand All @@ -22,8 +22,8 @@ def setUp(self):
dir_path = get_project_path()
cla_path_all = os.path.join(dir_path, 'models', 'classify', 'cas13a')
reg_path_all = os.path.join(dir_path, 'models', 'regress', 'cas13a')
cla_version = get_model_version(cla_path_all)
reg_version = get_model_version(reg_path_all)
cla_version = get_latest_model_version(cla_path_all)
reg_version = get_latest_model_version(reg_path_all)
cla_path = os.path.join(cla_path_all, cla_version)
reg_path = os.path.join(reg_path_all, reg_version)

Expand Down
25 changes: 21 additions & 4 deletions adapt/utils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,33 @@ def get_version():
return __version__


def get_model_version(model_path):
def get_latest_model_version(model_path):
"""Get latest model version, given the model path
"""
# List all model versions in path
model_versions = os.listdir(model_path)
# Get a list of the versions
# Each version is represented as a list of numbers
model_versions_numeric = [[int(i) for i in model_version[1:].split('_')]
for model_version in model_versions if
model_version.startswith('v')]
model_versions_numeric = []
for model_version in model_versions:
if model_version.startswith('v'):
model_version_numeric = []
skip = False
for i in model_version[1:].split('_'):
if not i.isdecimal():
skip = True
break
else:
model_version_numeric.append(int(i))
if not skip and len(model_version_numeric) > 0:
model_versions_numeric.append(model_version_numeric)

# If there were no models found on the path, raise an error
if len(model_versions_numeric) == 0:
raise ValueError("There are no appropriately formatted models in the "
"model path. Please make sure the models are in a folder with the "
"format 'v_#_#'")

# Remake the version string
latest_version = [str(i) for i in sorted(model_versions_numeric)[-1]]
return 'v' + '_'.join(latest_version)
Expand Down
6 changes: 3 additions & 3 deletions bin/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from adapt.utils import predict_activity
from adapt.utils import seq_io
from adapt.utils import year_cover
from adapt.utils.version import get_project_path, get_model_version
from adapt.utils.version import get_project_path, get_latest_model_version

try:
import boto3
Expand Down Expand Up @@ -649,12 +649,12 @@ def guide_is_suitable(guide):
"the classifier and the regressor must be set."))
if (len(args.predict_cas13a_activity_model) == 0 or
args.predict_cas13a_activity_model[0] == 'latest'):
cla_version = get_model_version(cla_path_all)
cla_version = get_latest_model_version(cla_path_all)
else:
cla_version = args.predict_cas13a_activity_model[0]
if (len(args.predict_cas13a_activity_model) == 0 or
args.predict_cas13a_activity_model[1] == 'latest'):
reg_version = get_model_version(reg_path_all)
reg_version = get_latest_model_version(reg_path_all)
else:
reg_version = args.predict_cas13a_activity_model[1]
cla_path = os.path.join(cla_path_all, cla_version)
Expand Down

0 comments on commit cb6758d

Please sign in to comment.