From 623be1bc8b4b39b38405782219d2a8caae68c635 Mon Sep 17 00:00:00 2001 From: supermario94123 Date: Thu, 8 Aug 2019 15:54:00 +0200 Subject: [PATCH 1/2] cli interface test for -p and eval_all=True and eval_forever=True --- tests/test_edflow.py | 56 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/tests/test_edflow.py b/tests/test_edflow.py index 987beef..9f681d0 100644 --- a/tests/test_edflow.py +++ b/tests/test_edflow.py @@ -251,9 +251,11 @@ def test_3(self, tmpdir): assert any(list(filter(lambda x: "test_inference" in x, eval_dirs))) def test_4(self, tmpdir): - """ - Tests evaluation with providing a checkpoint and using eval_all=True and eval_forever=True. - This should load not load any checkpoint. + """Tests evaluation with + 1. providing a checkpoint + 2. and using eval_all=True and eval_forever=True. + + This should overwrite eval_all and eval_forever to ``False``, and then load the specified checkpoint effectively runs edflow -e config.yaml -b config.yaml -c logs/trained_model/train/checkpoints/model.ckpt-0 @@ -300,3 +302,51 @@ def test_4(self, tmpdir): # check if correct folder was created eval_dirs = os.listdir(os.path.join(tmpdir, "logs", "trained_model", "eval")) assert any(list(filter(lambda x: "test_inference" in x, eval_dirs))) + + def test_5(self, tmpdir): + """Tests evaluation with + 1. providing a project + 2. using eval_all=True and eval_forever=True. + + This should NOT load any checkpoint. + + effectively runs + edflow -e config.yaml -b config.yaml -p logs/trained_model -n test_inference + + and then checks if an evaluation folder "test_inference" was created in logs/trained_model/eval + ------- + """ + self.setup_tmpdir(tmpdir) + # command = "edflow -e eval.yaml -b train.yaml -n test" + config = dict() + config["model"] = "tests." + fullname(Model) + config["iterator"] = "tests." + fullname(Iterator_no_checkpoint) + config["dataset"] = "tests." 
+ fullname(Dataset) + config["batch_size"] = 16 + config["num_steps"] = 100 + config["eval_all"] = True + config["eval_forever"] = False + import yaml + + with open(os.path.join(tmpdir, "config.yaml"), "w") as outfile: + yaml.dump(config, outfile, default_flow_style=False) + import shutil + + shutil.copytree(os.path.split(__file__)[0], os.path.join(tmpdir, "tests")) + command = [ + "edflow", + "-e", + "config.yaml", + "-p", + os.path.join("logs", "trained_model"), + "-b", + "config.yaml", + "-n", + "test_inference", + ] + command = " ".join(command) + run_edflow_cmdline(command, cwd=tmpdir) + + # check if correct folder was created + eval_dirs = os.listdir(os.path.join(tmpdir, "logs", "trained_model", "eval")) + assert any(list(filter(lambda x: "test_inference" in x, eval_dirs))) From aa569d3745dec9bc36ae222cf70a7acbd159e3b3 Mon Sep 17 00:00:00 2001 From: supermario94123 Date: Thu, 8 Aug 2019 16:31:50 +0200 Subject: [PATCH 2/2] fixes -p with eval_all=True and eval_forever=True --- edflow/edflow | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/edflow/edflow b/edflow/edflow index cb1604b..b296a24 100644 --- a/edflow/edflow +++ b/edflow/edflow @@ -115,28 +115,43 @@ def main(opt, additional_kwargs): # Evaluation opt.eval = opt.eval or list() for eval_idx, eval_config in enumerate(opt.eval): - if opt.checkpoint is not None: - checkpoint = opt.checkpoint - elif opt.project is not None: - checkpoint = get_latest_checkpoint(P.checkpoints) - else: - checkpoint = None - base_config = {} if opt.base is not None: for base in opt.base: with open(base) as f: base_config.update(yaml.full_load(f)) - # get path to implementation with open(eval_config) as f: config = base_config config.update(yaml.full_load(f)) update_config(config, additional_kwargs) + + disable_eval_all = disable_eval_forever = False if opt.checkpoint is not None: - for k in ["eval_forever", "eval_all"]: - if k in config: - config[k] = False - 
logger.info("{} was disabled because you specified a checkpoint.".format(k)) + checkpoint = opt.checkpoint + disable_eval_all = disable_eval_forever = True + elif opt.project is not None: + if any([config.get("eval_all", False), config.get("eval_forever", False)]): + checkpoint = None + else: + checkpoint = get_latest_checkpoint(P.checkpoints) + disable_eval_all = disable_eval_forever = True + else: + checkpoint = None + + if disable_eval_all: + config.update({"eval_all": False}) + logger.info( + "{} was disabled because you specified a checkpoint.".format("eval_all") + ) + + if disable_eval_forever: + config.update({"eval_forever": False}) + logger.info( + "{} was disabled because you specified a checkpoint.".format( + "eval_forever" + ) + ) + logger.info( "Evaluation config: {}\n{}".format(eval_config, yaml.dump(eval_config)) )