Skip to content

Commit

Permalink
Merge pull request #137 from pesser/test_cli_interface
Browse files Browse the repository at this point in the history
cli interface test for -p and eval_all=True and eval_forever=True
  • Loading branch information
pesser authored Aug 8, 2019
2 parents a31fa87 + 5a655a7 commit 596893e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 15 deletions.
39 changes: 27 additions & 12 deletions edflow/edflow
Original file line number Diff line number Diff line change
Expand Up @@ -115,28 +115,43 @@ def main(opt, additional_kwargs):
# Evaluation
opt.eval = opt.eval or list()
for eval_idx, eval_config in enumerate(opt.eval):
if opt.checkpoint is not None:
checkpoint = opt.checkpoint
elif opt.project is not None:
checkpoint = get_latest_checkpoint(P.checkpoints)
else:
checkpoint = None

base_config = {}
if opt.base is not None:
for base in opt.base:
with open(base) as f:
base_config.update(yaml.full_load(f))
# get path to implementation
with open(eval_config) as f:
config = base_config
config.update(yaml.full_load(f))
update_config(config, additional_kwargs)

disable_eval_all = disable_eval_forever = False
if opt.checkpoint is not None:
for k in ["eval_forever", "eval_all"]:
if k in config:
config[k] = False
logger.info("{} was disabled because you specified a checkpoint.".format(k))
checkpoint = opt.checkpoint
disable_eval_all = disable_eval_forever = True
elif opt.project is not None:
if any([config.get("eval_all", False), config.get("eval_forever", False)]):
checkpoint = None
else:
checkpoint = get_latest_checkpoint(P.checkpoints)
disable_eval_all = disable_eval_forever = True
else:
checkpoint = None

if disable_eval_all:
config.update({"eval_all": False})
logger.info(
"{} was disabled because you specified a checkpoint.".format("eval_all")
)

if disable_eval_forever:
config.update({"eval_forever": False})
logger.info(
"{} was disabled because you specified a checkpoint.".format(
"eval_forever"
)
)

logger.info(
"Evaluation config: {}\n{}".format(eval_config, yaml.dump(eval_config))
)
Expand Down
56 changes: 53 additions & 3 deletions tests/test_edflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,11 @@ def test_3(self, tmpdir):
assert any(list(filter(lambda x: "test_inference" in x, eval_dirs)))

def test_4(self, tmpdir):
"""
Tests evaluation with providing a checkpoint and using eval_all=True and eval_forever=True.
This should not load any checkpoint.
"""Tests evaluation with
1. providing a checkpoint
2. and using eval_all=True and eval_forever=True.
This should overwrite eval_all and eval_forever to ``False``, and then load the specified checkpoint.
effectively runs
edflow -e config.yaml -b config.yaml -c logs/trained_model/train/checkpoints/model.ckpt-0
Expand Down Expand Up @@ -300,3 +302,51 @@ def test_4(self, tmpdir):
# check if correct folder was created
eval_dirs = os.listdir(os.path.join(tmpdir, "logs", "trained_model", "eval"))
assert any(list(filter(lambda x: "test_inference" in x, eval_dirs)))

def test_5(self, tmpdir):
    """Run an evaluation given only a project directory.

    Scenario:
      1. a project is supplied via ``-p`` (no explicit checkpoint), and
      2. the config sets ``eval_all=True`` (with ``eval_forever=False``).
    In this case edflow should NOT load any checkpoint itself.

    Effectively runs::

        edflow -e config.yaml -b config.yaml -p logs/trained_model -n test_inference

    and then checks that an evaluation folder matching "test_inference"
    was created in ``logs/trained_model/eval``.
    -------
    """
    import shutil
    import yaml

    self.setup_tmpdir(tmpdir)

    # Minimal config pointing at the test model/iterator/dataset.
    config = {
        "model": "tests." + fullname(Model),
        "iterator": "tests." + fullname(Iterator_no_checkpoint),
        "dataset": "tests." + fullname(Dataset),
        "batch_size": 16,
        "num_steps": 100,
        "eval_all": True,
        "eval_forever": False,
    }

    with open(os.path.join(tmpdir, "config.yaml"), "w") as outfile:
        yaml.dump(config, outfile, default_flow_style=False)

    # Make the local test package importable from inside the tmpdir.
    shutil.copytree(os.path.split(__file__)[0], os.path.join(tmpdir, "tests"))

    command = " ".join(
        [
            "edflow",
            "-e",
            "config.yaml",
            "-p",
            os.path.join("logs", "trained_model"),
            "-b",
            "config.yaml",
            "-n",
            "test_inference",
        ]
    )
    run_edflow_cmdline(command, cwd=tmpdir)

    # An eval directory whose name contains "test_inference" must exist.
    eval_dirs = os.listdir(os.path.join(tmpdir, "logs", "trained_model", "eval"))
    assert any("test_inference" in d for d in eval_dirs)

0 comments on commit 596893e

Please sign in to comment.