diff --git a/.gitignore b/.gitignore
index 14b061c..435d5df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,8 +4,8 @@
 # Playing around stuff
 **play**
-**logs**
-**log**
+logs/
+log/
 
 # Pictures
 **.png
 
@@ -27,3 +27,6 @@
 docs/build/*
 docs/_*
 .nfs*
+
+# Generated output
+examples/data
diff --git a/.travis.yml b/.travis.yml
index 85a937b..9c3d1fe 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,6 +1,5 @@
 language: python
 python:
-# - 2.7
   - 3.6
 notifications:
   email: false
@@ -18,18 +17,35 @@ before_install:
   - conda update -q conda
   - conda info -a
 
-install:
-  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION tensorflow=1.13.1
-  - source activate test-environment
-  - conda install pytorch-cpu torchvision-cpu -c pytorch
-  - pip install tensorboardX
-  - python setup.py install
+install: true
 
 # Run test
 jobs:
   include:
-    - stage: example_tests
+    - stage: formatting
+      name: "formatting test"
+      install:
+        - pip install black
+      script:
+        - black --check .
+    - stage: tf_example_tests
+      name: "Tensorflow examples test"
+      install:
+        - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION tensorflow=1.13.1
+        - source activate test-environment
+        - python setup.py install
+      script:
+        - cd examples
+        - python -m pytest -k tf
+    - stage: torch_example_tests
+      name: "Pytorch examples test"
+      install:
+        - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
+        - source activate test-environment
+        - conda install pytorch-cpu torchvision-cpu -c pytorch
+        - pip install tensorboardX
+        - python setup.py install
       script:
         - cd examples
-        - python -m pytest
+        - python -m pytest -k torch
diff --git a/edflow/applications/__init__.py b/edflow/applications/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/edflow/iterators/deeploss.py b/edflow/applications/tf_perceptual_loss.py
similarity index 100%
rename from edflow/iterators/deeploss.py
rename to edflow/applications/tf_perceptual_loss.py
diff --git a/edflow/edflow b/edflow/edflow
index 68c2690..aa5fb10 100644
--- a/edflow/edflow
+++ b/edflow/edflow
@@ -23,7 +23,7 @@ import yaml  # noqa
 from edflow.main import train, test  # noqa
 from edflow.custom_logging import init_project, use_project, get_logger  # noqa
 from edflow.custom_logging import set_global_stdout_level  # noqa
-from edflow.hooks.evaluation_hooks import get_latest_checkpoint  # noqa
+from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint  # noqa
 
 
 def update_config(config, options):
diff --git a/edflow/evaluate.py b/edflow/evaluate.py
deleted file mode 100644
index 1b74247..0000000
--- a/edflow/evaluate.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import tensorflow as tf
-
-import argparse
-import numpy as np
-import glob
-import os
-import yaml
-from tqdm import tqdm, trange
-
-import multiprocessing as mp
-
-from edflow.custom_logging import use_project, get_logger
-from edflow.main import test
-
-
-def main(opt):
-    with open(opt.config) as f:
-        config = yaml.load(f)
-
-    P = use_project(opt.project)
-    logger = get_logger("main_evaluate", "latest_eval")
-    logger.info(opt)
-    logger.info(P)
-    logger.info(yaml.dump(config))
-
-    test(config, P.latest_eval)
-    logger.info("Finished")
-    print("\n" * 5)
-
-
-if __name__ == "__main__":
-    default_log_dir = os.path.join(os.getcwd(), "log")
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--config", required=True, help="path to config")
-    parser.add_argument("--project", help="path to project root")
-
-    opt = parser.parse_args()
-    main(opt)
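The three CI stages above rely on pytest's -k keyword filter, which selects tests by substring match on their names. For `-k tf` and `-k torch` to partition the example tests cleanly, the test functions must carry those substrings in their names. A minimal sketch of what the examples directory is assumed to contain (the file and function names here are hypothetical, not taken from the repository):

    # examples/test_example.py (hypothetical names)
    def test_tf_mnist():
        pass  # collected by `python -m pytest -k tf`

    def test_torch_mnist():
        pass  # collected by `python -m pytest -k torch`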
diff --git a/edflow/hooks/__init__.py b/edflow/hooks/__init__.py
index ebc0ae2..e69de29 100644
--- a/edflow/hooks/__init__.py
+++ b/edflow/hooks/__init__.py
@@ -1,3 +0,0 @@
-from edflow.hooks.train_hooks import *
-from edflow.hooks.evaluation_hooks import *
-from edflow.hooks.hook import Hook, match_frequency
diff --git a/edflow/hooks/checkpoint_hooks/__init__.py b/edflow/hooks/checkpoint_hooks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/edflow/hooks/evaluation_hooks.py b/edflow/hooks/checkpoint_hooks/common.py
similarity index 63%
rename from edflow/hooks/evaluation_hooks.py
rename to edflow/hooks/checkpoint_hooks/common.py
index 84e3382..7b66955 100644
--- a/edflow/hooks/evaluation_hooks.py
+++ b/edflow/hooks/checkpoint_hooks/common.py
@@ -1,9 +1,3 @@
-import tensorflow as tf
-
-try:
-    import torch
-except ImportError:
-    print("Warning: Could not import torch.")
 import time
 import os
 import re
@@ -23,71 +17,6 @@
 P = ProjectManager()
 
 
-class WaitForCheckpointHook(Hook):
-    """Waits until a new checkpoint is created, then lets the Iterator
-    continue."""
-
-    def __init__(
-        self,
-        checkpoint_root,
-        filter_cond=lambda c: True,
-        interval=5,
-        add_sec=5,
-        callback=None,
-        eval_all=False,
-    ):
-        """Args:
-            checkpoint_root (str): Path to look for checkpoints.
-            filter_cond (Callable): A function used to filter files, to only
-                get the checkpoints that are wanted.
-            interval (float): Number of seconds after which to check for a new
-                checkpoint again.
-            add_sec (float): Number of seconds to wait, after a checkpoint is
-                found, to avoid race conditions, if the checkpoint is still
-                being written at the time it's meant to be read.
-            callback (Callable): Callback called with path of found
-                checkpoint.
-            eval_all (bool): Accept all instead of just latest checkpoint.
-        """
-
-        self.root = checkpoint_root
-        self._fcond = filter_cond
-        self.sleep_interval = interval
-        self.additional_wait = add_sec
-        self.callback = callback
-        self.eval_all = eval_all
-
-        self.logger = get_logger(self)
-
-        self.known_checkpoints = set()
-
-    def fcond(self, c):
-        cond = self._fcond(c)
-        if self.eval_all:
-            cond = cond and c not in self.known_checkpoints
-        return cond
-
-    def look(self):
-        """Loop until a new checkpoint is found."""
-        self.logger.info("Waiting for new checkpoint.")
-        while True:
-            latest_checkpoint = get_latest_checkpoint(self.root, self.fcond)
-            if (
-                latest_checkpoint is not None
-                and latest_checkpoint not in self.known_checkpoints
-            ):
-                self.known_checkpoints.add(latest_checkpoint)
-                time.sleep(self.additional_wait)
-                self.logger.info("Found new checkpoint: {}".format(latest_checkpoint))
-                if self.callback is not None:
-                    self.callback(latest_checkpoint)
-                break
-            time.sleep(self.sleep_interval)
-
-    def before_epoch(self, ep):
-        self.look()
-
-
 def get_latest_checkpoint(checkpoint_root, filter_cond=lambda c: True):
     """Return path to name of latest checkpoint in checkpoint_root dir.
@@ -143,123 +72,69 @@ def get_latest_checkpoint(checkpoint_root, filter_cond=lambda c: True):
     return latest
 
 
-class RestoreModelHook(Hook):
-    """Restores a TensorFlow model from a checkpoint at each epoch. Can also
-    be used as a functor."""
+class WaitForCheckpointHook(Hook):
+    """Waits until a new checkpoint is created, then lets the Iterator
+    continue."""
 
     def __init__(
         self,
-        variables,
-        checkpoint_path,
+        checkpoint_root,
         filter_cond=lambda c: True,
-        global_step_setter=None,
+        interval=5,
+        add_sec=5,
+        callback=None,
+        eval_all=False,
     ):
         """Args:
-            variables (list): tf.Variable to be loaded from the checkpoint.
-            checkpoint_path (str): Directory in which the checkpoints are
-                stored or explicit checkpoint. Ignored if used as functor.
+            checkpoint_root (str): Path to look for checkpoints.
             filter_cond (Callable): A function used to filter files, to only
-                get the checkpoints that are wanted. Ignored if used as
-                functor.
-            global_step_setter (Callable): Callback to set global_step.
+                get the checkpoints that are wanted.
+            interval (float): Number of seconds after which to check for a new
+                checkpoint again.
+            add_sec (float): Number of seconds to wait, after a checkpoint is
+                found, to avoid race conditions, if the checkpoint is still
+                being written at the time it's meant to be read.
+            callback (Callable): Callback called with path of found
+                checkpoint.
+            eval_all (bool): Accept all instead of just latest checkpoint.
         """
-        self.root = checkpoint_path
-        self.fcond = filter_cond
-        self.setstep = global_step_setter
-
-        self.logger = get_logger(self)
-
-        self.saver = tf.train.Saver(variables)
-
-    @property
-    def session(self):
-        if not hasattr(self, "_session"):
-            self._session = tf.get_default_session()
-        return self._session
-
-    def before_epoch(self, ep):
-        # checkpoint = tf.train.latest_checkpoint(self.root)
-        checkpoint = get_latest_checkpoint(self.root, self.fcond)
-        self(checkpoint)
-
-    def __call__(self, checkpoint):
-        self.saver.restore(self.session, checkpoint)
-        self.logger.info("Restored model from {}".format(checkpoint))
-        global_step = self.parse_global_step(checkpoint)
-        self.logger.info("Global step: {}".format(global_step))
-        if self.setstep is not None:
-            self.setstep(global_step)
-
-    @staticmethod
-    def parse_global_step(checkpoint):
-        global_step = int(checkpoint.rsplit("-", 1)[1])
-        return global_step
-
-
-# Simple renaming for consistency
-# Todo: Make the Restore op part of the model (issue #2)
-# https://bitbucket.org/jhaux/edflow/issues/2/make-a-general-model-restore-hook
-RestoreTFModelHook = RestoreModelHook
-
-
-class RestorePytorchModelHook(Hook):
-    """Restores a PyTorch model from a checkpoint at each epoch. Can also be
-    used as a functor."""
-
-    def __init__(
-        self,
-        model,
-        checkpoint_path,
-        filter_cond=lambda c: True,
-        global_step_setter=None,
-    ):
-        """Args:
-            model (torch.nn.Module): Model to initialize
-            checkpoint_path (str): Directory in which the checkpoints are
-                stored or explicit checkpoint. Ignored if used as functor.
-            filter_cond (Callable): A function used to filter files, to only
-                get the checkpoints that are wanted. Ignored if used as
-                functor.
-            global_step_setter (Callable): Function, that the retrieved global
-                step can be passed to.
- """ - self.root = checkpoint_path - self.fcond = filter_cond + self.root = checkpoint_root + self._fcond = filter_cond + self.sleep_interval = interval + self.additional_wait = add_sec + self.callback = callback + self.eval_all = eval_all self.logger = get_logger(self) - self.model = model - self.global_step_setter = global_step_setter - - def before_epoch(self, ep): - checkpoint = get_latest_checkpoint(self.root, self.fcond) - self(checkpoint) - - def __call__(self, checkpoint): - self.model.load_state_dict(torch.load(checkpoint)) - self.logger.info("Restored model from {}".format(checkpoint)) - - epoch, step = self.parse_checkpoint(checkpoint) - - if self.global_step_setter is not None: - self.global_step_setter(step) - self.logger.info("Epoch: {}, Global step: {}".format(epoch, step)) + self.known_checkpoints = set() - @staticmethod - def parse_global_step(checkpoint): - return RestorePytorchModelHook.parse_checkpoint(checkpoint)[1] + def fcond(self, c): + cond = self._fcond(c) + if self.eval_all: + cond = cond and c not in self.known_checkpoints + return cond - @staticmethod - def parse_checkpoint(checkpoint): - e_s = os.path.basename(checkpoint).split(".")[0].split("-") - if len(e_s) > 1: - epoch = e_s[0] - step = e_s[1].split("_")[0] - else: - epoch = 0 - step = e_s[0].split("_")[0] + def look(self): + """Loop until a new checkpoint is found.""" + self.logger.info("Waiting for new checkpoint.") + while True: + latest_checkpoint = get_latest_checkpoint(self.root, self.fcond) + if ( + latest_checkpoint is not None + and latest_checkpoint not in self.known_checkpoints + ): + self.known_checkpoints.add(latest_checkpoint) + time.sleep(self.additional_wait) + self.logger.info("Found new checkpoint: {}".format(latest_checkpoint)) + if self.callback is not None: + self.callback(latest_checkpoint) + break + time.sleep(self.sleep_interval) - return int(epoch), int(step) + def before_epoch(self, ep): + self.look() def strenumerate(*args, **kwargs): @@ -412,91 +287,22 @@ def test_valid_metrictuple(metric_tuple): # enough checking already :) -class MetricHook(Hook): - """Applies a set of given metrics to the calculated data.""" - - def __init__(self, metrics, save_root, consider_only_first=None): - """Args: - metrics (list): List of ``MetricTuple``s of the form - ``(input names, output names, metric, name)``. - - ``input names`` are the keys corresponding to the feeds of - interest, e.g. an original image. - - ``output names`` are the keys corresponding to the values - in the results dict. - - ``metric`` is a ``Callable`` that accepts all inputs and - outputs keys as keyword arguments - - ``name`` is a - If nested feeds or results are expected the names can be - passed as "path" like ``'key1_key2'`` returning - ``dict[key1][key2]``. - save_root (str): Path to where the results are stored. - consider_only_first (int): Metric is only evaluated on the first - `consider_only_first` examples. 
- """ - - self.metrics = metrics - - self.root = save_root - self.logger = get_logger(self, "latest_eval") - - self.max_step = consider_only_first - - self.storage_dict = {} - self.metric_results = {} - for m in metrics: - test_valid_metrictuple(m) - - self.tb_saver = tf.summary.FileWriter(self.root) - - def before_epoch(self, epoch): - self.count = 0 - for m in self.metrics: - self.metric_results[m.name] = [] - - def before_step(self, step, fetches, feeds, batch): - if self.max_step is not None and self.count >= self.max_step: - return - - for in_names, out_names, metric, m_name in self.metrics: - self.storage_dict[m_name] = {} - for kwargs_name, name in in_names.items(): - val = retrieve(name, batch) - self.storage_dict[m_name][kwargs_name] = val - - def after_step(self, step, results): - if self.max_step is not None and self.count >= self.max_step: - return +def torch_parse_global_step(checkpoint): + e_s = os.path.basename(checkpoint).split(".")[0].split("-") + if len(e_s) > 1: + epoch = e_s[0] + step = e_s[1].split("_")[0] + else: + epoch = 0 + step = e_s[0].split("_")[0] - for in_names, out_names, metric, m_name in self.metrics: - for kwargs_name, name in out_names.items(): - val = retrieve(name, results) - self.storage_dict[m_name][kwargs_name] = val - m_res = metric(**self.storage_dict[m_name]) - self.metric_results[m_name] += [m_res] + epoch, step = int(epoch), int(step) + return step - self.global_step = results["global_step"] - self.count += 1 - def after_epoch(self, epoch): - self.logger.info("Metrics at epoch {}:".format(epoch)) - - mean_results = {} - for name, result in self.metric_results.items(): - results = np.concatenate(result) - mean = np.mean(results, axis=0) - var = np.std(results, axis=0) - mean_results[name] = np.array([mean, var]) - self.logger.info("{}: {} +- {}".format(name, mean, var)) - - summary = tf.Summary() - summary_mean = mean if len(mean.shape) == 0 else mean[0] - summary.value.add(tag=name, simple_value=summary_mean) - self.tb_saver.add_summary(summary, self.global_step) - self.tb_saver.flush() - - name = "{:0>6d}_metrics".format(self.global_step) - name = os.path.join(self.root, name) - np.savez_compressed(name, **mean_results) +def tf_parse_global_step(checkpoint): + global_step = int(checkpoint.rsplit("-", 1)[1]) + return global_step def get_checkpoint_files(checkpoint_root): @@ -516,10 +322,10 @@ def get_checkpoint_files(checkpoint_root): name, ext = os.path.splitext(p) if not ext == ".ckpt": normalized = name - global_step = RestoreTFModelHook.parse_global_step(normalized) + global_step = tf_parse_global_step(normalized) else: normalized = p - global_step = RestorePytorchModelHook.parse_global_step(normalized) + global_step = torch_parse_global_step(normalized) files.append(p) checkpoints.append(normalized) global_steps.append(global_step) diff --git a/edflow/hooks/train_hooks.py b/edflow/hooks/checkpoint_hooks/tf_checkpoint_hook.py similarity index 63% rename from edflow/hooks/train_hooks.py rename to edflow/hooks/checkpoint_hooks/tf_checkpoint_hook.py index ab07dcb..2649de3 100644 --- a/edflow/hooks/train_hooks.py +++ b/edflow/hooks/checkpoint_hooks/tf_checkpoint_hook.py @@ -1,16 +1,68 @@ -import tensorflow as tf import os -import time - +import signal +import tensorflow as tf from edflow.hooks.hook import Hook -from edflow.hooks.evaluation_hooks import get_checkpoint_files +from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint from edflow.custom_logging import get_logger -from edflow.iterators.batches import plot_batch -import 
diff --git a/edflow/hooks/train_hooks.py b/edflow/hooks/checkpoint_hooks/tf_checkpoint_hook.py
similarity index 63%
rename from edflow/hooks/train_hooks.py
rename to edflow/hooks/checkpoint_hooks/tf_checkpoint_hook.py
index ab07dcb..2649de3 100644
--- a/edflow/hooks/train_hooks.py
+++ b/edflow/hooks/checkpoint_hooks/tf_checkpoint_hook.py
@@ -1,16 +1,68 @@
-import tensorflow as tf
 import os
-import time
-
+import signal
+import tensorflow as tf
 
 from edflow.hooks.hook import Hook
-from edflow.hooks.evaluation_hooks import get_checkpoint_files
+from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint
 from edflow.custom_logging import get_logger
-from edflow.iterators.batches import plot_batch
 
-import signal
-import sys
-
-"""TensorFlow hooks useful during training."""
+
+class RestoreModelHook(Hook):
+    """Restores a TensorFlow model from a checkpoint at each epoch. Can also
+    be used as a functor."""
+
+    def __init__(
+        self,
+        variables,
+        checkpoint_path,
+        filter_cond=lambda c: True,
+        global_step_setter=None,
+    ):
+        """Args:
+            variables (list): tf.Variable to be loaded from the checkpoint.
+            checkpoint_path (str): Directory in which the checkpoints are
+                stored or explicit checkpoint. Ignored if used as functor.
+            filter_cond (Callable): A function used to filter files, to only
+                get the checkpoints that are wanted. Ignored if used as
+                functor.
+            global_step_setter (Callable): Callback to set global_step.
+        """
+        self.root = checkpoint_path
+        self.fcond = filter_cond
+        self.setstep = global_step_setter
+
+        self.logger = get_logger(self)
+
+        self.saver = tf.train.Saver(variables)
+
+    @property
+    def session(self):
+        if not hasattr(self, "_session"):
+            self._session = tf.get_default_session()
+        return self._session
+
+    def before_epoch(self, ep):
+        # checkpoint = tf.train.latest_checkpoint(self.root)
+        checkpoint = get_latest_checkpoint(self.root, self.fcond)
+        self(checkpoint)
+
+    def __call__(self, checkpoint):
+        self.saver.restore(self.session, checkpoint)
+        self.logger.info("Restored model from {}".format(checkpoint))
+        global_step = self.parse_global_step(checkpoint)
+        self.logger.info("Global step: {}".format(global_step))
+        if self.setstep is not None:
+            self.setstep(global_step)
+
+    @staticmethod
+    def parse_global_step(checkpoint):
+        global_step = int(checkpoint.rsplit("-", 1)[1])
+        return global_step
+
+
+# Simple renaming for consistency
+# Todo: Make the Restore op part of the model (issue #2)
+# https://bitbucket.org/jhaux/edflow/issues/2/make-a-general-model-restore-hook
+RestoreTFModelHook = RestoreModelHook
 
 
 class CheckpointHook(Hook):
- """ - - scalars = [tf.summary.scalar(n, s) for n, s in scalars.items()] - histograms = [tf.summary.histogram(n, h) for n, h in histograms.items()] - - self._has_summary = len(scalars + histograms) > 0 - if self._has_summary: - summary_op = tf.summary.merge(scalars + histograms) - else: - summary_op = tf.no_op() - - self.fetch_dict = {"summaries": summary_op, "logs": logs, "images": images} - - self.interval = interval - - self.graph = graph - self.root = root_path - self.logger = get_logger(self) - - def before_epoch(self, ep): - if ep == 0: - if self.graph is None: - self.graph = tf.get_default_graph() - - self.writer = tf.summary.FileWriter(self.root, self.graph) - - def before_step(self, batch_index, fetches, feeds, batch): - if batch_index % self.interval == 0: - fetches["logging"] = self.fetch_dict - - def after_step(self, batch_index, last_results): - if batch_index % self.interval == 0: - step = last_results["global_step"] - last_results = last_results["logging"] - if self._has_summary: - summary = last_results["summaries"] - self.writer.add_summary(summary, step) - - logs = last_results["logs"] - for name in sorted(logs.keys()): - self.logger.info("{}: {}".format(name, logs[name])) - - for name, image_batch in last_results["images"].items(): - full_name = name + "_{:07}.png".format(step) - save_path = os.path.join(self.root, full_name) - plot_batch(image_batch, save_path) - - self.logger.info("project root: {}".format(self.root)) - - class RetrainHook(Hook): """Restes the global step at the beginning of training.""" diff --git a/edflow/hooks/checkpoint_hooks/torch_checkpoint_hook.py b/edflow/hooks/checkpoint_hooks/torch_checkpoint_hook.py new file mode 100644 index 0000000..28ad192 --- /dev/null +++ b/edflow/hooks/checkpoint_hooks/torch_checkpoint_hook.py @@ -0,0 +1,61 @@ +import torch + + +class RestorePytorchModelHook(Hook): + """Restores a PyTorch model from a checkpoint at each epoch. Can also be + used as a functor.""" + + def __init__( + self, + model, + checkpoint_path, + filter_cond=lambda c: True, + global_step_setter=None, + ): + """Args: + model (torch.nn.Module): Model to initialize + checkpoint_path (str): Directory in which the checkpoints are + stored or explicit checkpoint. Ignored if used as functor. + filter_cond (Callable): A function used to filter files, to only + get the checkpoints that are wanted. Ignored if used as + functor. + global_step_setter (Callable): Function, that the retrieved global + step can be passed to. 
+ """ + self.root = checkpoint_path + self.fcond = filter_cond + + self.logger = get_logger(self) + + self.model = model + self.global_step_setter = global_step_setter + + def before_epoch(self, ep): + checkpoint = get_latest_checkpoint(self.root, self.fcond) + self(checkpoint) + + def __call__(self, checkpoint): + self.model.load_state_dict(torch.load(checkpoint)) + self.logger.info("Restored model from {}".format(checkpoint)) + + epoch, step = self.parse_checkpoint(checkpoint) + + if self.global_step_setter is not None: + self.global_step_setter(step) + self.logger.info("Epoch: {}, Global step: {}".format(epoch, step)) + + @staticmethod + def parse_global_step(checkpoint): + return RestorePytorchModelHook.parse_checkpoint(checkpoint)[1] + + @staticmethod + def parse_checkpoint(checkpoint): + e_s = os.path.basename(checkpoint).split(".")[0].split("-") + if len(e_s) > 1: + epoch = e_s[0] + step = e_s[1].split("_")[0] + else: + epoch = 0 + step = e_s[0].split("_")[0] + + return int(epoch), int(step) diff --git a/edflow/hooks/hook.py b/edflow/hooks/hook.py index c27df41..d47da1f 100644 --- a/edflow/hooks/hook.py +++ b/edflow/hooks/hook.py @@ -1,5 +1,3 @@ -import tensorflow as tf - from edflow.custom_logging import get_default_logger @@ -86,71 +84,6 @@ def at_exception(self, exception): pass -class Hooker(object): - """Probably should rename that...""" - - def __init__(self, hooks, index, batch=None, session=None, logger=None): - """Args: - hooks (list): All :class:`Hook`s to be run before and after - this :class:`Hooker`. - index (int): step or epoch. - batch (list or dict): Feed dict when calling the hook. - session (tf.Session): Session object to run the :class:`Hook`s - with. - logger (logging.Logger): Logging log log logs. - """ - - if session is not None: - self.session = session - else: - self.session = tf.get_default_session() - - self.hooks = hooks - self.index = index - self.mode = "epoch" if batch is None else "step" - self.feeds = batch - - self.logger = logger or get_default_logger() - - self.step_op_results = None - - def __enter__(self): - """Run before-hooks.""" - self.last_results = [None] * len(self.hooks) - - for i, hook in enumerate(self.hooks): - method = getattr(hook, "before_{}".format(self.mode)) - - fetch_args = [self.index] - if self.mode == "step": - fetch_args += [self.feeds] - - fetches = method(*fetch_args) - - if fetches is not None: - self.last_results[i] = self.session.run(fetches, feed_dict=self.feeds) - return self - - def __exit__(self, *args, **kwargs): - """Run after-hooks.""" - for i, hook in enumerate(self.hooks): - method = getattr(hook, "after_{}".format(self.mode)) - - fetch_args = [self.index] - if self.mode == "step": - fetch_args += [self.feeds, self.step_op_results] - fetch_args += [self.last_results[i]] - - fetches = method(*fetch_args) - - if fetches is not None: - self.session.run(fetches, feed_dict=self.feeds) - - def set_step_op_results(self, results): - """Enter results for bookkeeping.""" - self.step_op_results = results - - def match_frequency(global_hook_frequency, local_hook_frequency): r"""Given the global frequency at which hooks are evaluated matches the local frequency at which a hook wants to be evaluated s.t. 
 def match_frequency(global_hook_frequency, local_hook_frequency):
     r"""Given the global frequency at which hooks are evaluated matches
     the local frequency at which a hook wants to be evaluated s.t. it will
diff --git a/edflow/hooks/hook.py b/edflow/hooks/hook.py
index c27df41..d47da1f 100644
--- a/edflow/hooks/hook.py
+++ b/edflow/hooks/hook.py
@@ -1,5 +1,3 @@
-import tensorflow as tf
-
 from edflow.custom_logging import get_default_logger
 
 
@@ -86,71 +84,6 @@ def at_exception(self, exception):
         pass
 
 
-class Hooker(object):
-    """Probably should rename that..."""
-
-    def __init__(self, hooks, index, batch=None, session=None, logger=None):
-        """Args:
-            hooks (list): All :class:`Hook`s to be run before and after
-                this :class:`Hooker`.
-            index (int): step or epoch.
-            batch (list or dict): Feed dict when calling the hook.
-            session (tf.Session): Session object to run the :class:`Hook`s
-                with.
-            logger (logging.Logger): Logging log log logs.
-        """
-
-        if session is not None:
-            self.session = session
-        else:
-            self.session = tf.get_default_session()
-
-        self.hooks = hooks
-        self.index = index
-        self.mode = "epoch" if batch is None else "step"
-        self.feeds = batch
-
-        self.logger = logger or get_default_logger()
-
-        self.step_op_results = None
-
-    def __enter__(self):
-        """Run before-hooks."""
-        self.last_results = [None] * len(self.hooks)
-
-        for i, hook in enumerate(self.hooks):
-            method = getattr(hook, "before_{}".format(self.mode))
-
-            fetch_args = [self.index]
-            if self.mode == "step":
-                fetch_args += [self.feeds]
-
-            fetches = method(*fetch_args)
-
-            if fetches is not None:
-                self.last_results[i] = self.session.run(fetches, feed_dict=self.feeds)
-        return self
-
-    def __exit__(self, *args, **kwargs):
-        """Run after-hooks."""
-        for i, hook in enumerate(self.hooks):
-            method = getattr(hook, "after_{}".format(self.mode))
-
-            fetch_args = [self.index]
-            if self.mode == "step":
-                fetch_args += [self.feeds, self.step_op_results]
-            fetch_args += [self.last_results[i]]
-
-            fetches = method(*fetch_args)
-
-            if fetches is not None:
-                self.session.run(fetches, feed_dict=self.feeds)
-
-    def set_step_op_results(self, results):
-        """Enter results for bookkeeping."""
-        self.step_op_results = results
-
-
 def match_frequency(global_hook_frequency, local_hook_frequency):
     r"""Given the global frequency at which hooks are evaluated matches
     the local frequency at which a hook wants to be evaluated s.t. it will
diff --git a/edflow/hooks/logging_hooks/__init__.py b/edflow/hooks/logging_hooks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/edflow/hooks/logging_hooks/tf_logging_hook.py b/edflow/hooks/logging_hooks/tf_logging_hook.py
new file mode 100644
index 0000000..87dfebc
--- /dev/null
+++ b/edflow/hooks/logging_hooks/tf_logging_hook.py
@@ -0,0 +1,85 @@
+import tensorflow as tf
+import os
+import time
+
+from edflow.hooks.hook import Hook
+from edflow.hooks.checkpoint_hooks.common import get_checkpoint_files
+from edflow.custom_logging import get_logger
+from edflow.iterators.batches import plot_batch
+
+import signal
+import sys
+
+"""TensorFlow hooks useful during training."""
+
+
+class LoggingHook(Hook):
+    """Supply and evaluate logging ops at an interval of training steps."""
+
+    def __init__(
+        self,
+        scalars={},
+        histograms={},
+        images={},
+        logs={},
+        graph=None,
+        interval=100,
+        root_path="logs",
+    ):
+        """Args:
+            scalars (dict): Scalar ops.
+            histograms (dict): Histogram ops.
+            images (dict): Image ops. Note that for these no
+                tensorboard logging is used but a custom image saver.
+            logs (dict): Logs to std out via logger.
+            graph (tf.Graph): Current graph.
+            interval (int): Interval of training steps before logging.
+            root_path (str): Path at which the logs are stored.
+        """
+
+        scalars = [tf.summary.scalar(n, s) for n, s in scalars.items()]
+        histograms = [tf.summary.histogram(n, h) for n, h in histograms.items()]
+
+        self._has_summary = len(scalars + histograms) > 0
+        if self._has_summary:
+            summary_op = tf.summary.merge(scalars + histograms)
+        else:
+            summary_op = tf.no_op()
+
+        self.fetch_dict = {"summaries": summary_op, "logs": logs, "images": images}
+
+        self.interval = interval
+
+        self.graph = graph
+        self.root = root_path
+        self.logger = get_logger(self)
+
+    def before_epoch(self, ep):
+        if ep == 0:
+            if self.graph is None:
+                self.graph = tf.get_default_graph()
+
+            self.writer = tf.summary.FileWriter(self.root, self.graph)
+
+    def before_step(self, batch_index, fetches, feeds, batch):
+        if batch_index % self.interval == 0:
+            fetches["logging"] = self.fetch_dict
+
+    def after_step(self, batch_index, last_results):
+        if batch_index % self.interval == 0:
+            step = last_results["global_step"]
+            last_results = last_results["logging"]
+            if self._has_summary:
+                summary = last_results["summaries"]
+                self.writer.add_summary(summary, step)
+
+            logs = last_results["logs"]
+            for name in sorted(logs.keys()):
+                self.logger.info("{}: {}".format(name, logs[name]))
+
+            for name, image_batch in last_results["images"].items():
+                full_name = name + "_{:07}.png".format(step)
+                save_path = os.path.join(self.root, full_name)
+                plot_batch(image_batch, save_path)
+
+            self.logger.info("project root: {}".format(self.root))
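LoggingHook merges the scalar and histogram ops it receives into a single summary op and adds it, together with the raw logs and images ops, to the fetches every `interval` steps; image batches bypass tensorboard and are written to disk through plot_batch. A construction sketch in which the loss op stands in for whatever the model actually defines:

    import tensorflow as tf
    from edflow.hooks.logging_hooks.tf_logging_hook import LoggingHook

    loss = tf.reduce_mean(tf.random_normal([4]))  # stand-in for a real loss

    hook = LoggingHook(
        scalars={"loss": loss},  # wrapped in tf.summary.scalar
        logs={"loss": loss},     # printed through the logger
        interval=100,
        root_path="logs",
    )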
diff --git a/edflow/hooks/metric_hooks/__init__.py b/edflow/hooks/metric_hooks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/edflow/hooks/metric_hooks/tf_metric_hook.py b/edflow/hooks/metric_hooks/tf_metric_hook.py
new file mode 100644
index 0000000..4ba523d
--- /dev/null
+++ b/edflow/hooks/metric_hooks/tf_metric_hook.py
@@ -0,0 +1,96 @@
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from edflow.hooks.hook import Hook
+from edflow.hooks.checkpoint_hooks.common import test_valid_metrictuple
+from edflow.custom_logging import get_logger
+from edflow.util import retrieve
+
+
+class MetricHook(Hook):
+    """Applies a set of given metrics to the calculated data."""
+
+    def __init__(self, metrics, save_root, consider_only_first=None):
+        """Args:
+            metrics (list): List of ``MetricTuple``s of the form
+                ``(input names, output names, metric, name)``.
+                - ``input names`` are the keys corresponding to the feeds of
+                    interest, e.g. an original image.
+                - ``output names`` are the keys corresponding to the values
+                    in the results dict.
+                - ``metric`` is a ``Callable`` that accepts all inputs and
+                    outputs keys as keyword arguments.
+                - ``name`` is the name under which the results are stored.
+                If nested feeds or results are expected the names can be
+                passed as "path" like ``'key1_key2'`` returning
+                ``dict[key1][key2]``.
+            save_root (str): Path to where the results are stored.
+            consider_only_first (int): Metric is only evaluated on the first
+                `consider_only_first` examples.
+        """
+
+        self.metrics = metrics
+
+        self.root = save_root
+        self.logger = get_logger(self, "latest_eval")
+
+        self.max_step = consider_only_first
+
+        self.storage_dict = {}
+        self.metric_results = {}
+        for m in metrics:
+            test_valid_metrictuple(m)
+
+        self.tb_saver = tf.summary.FileWriter(self.root)
+
+    def before_epoch(self, epoch):
+        self.count = 0
+        for m in self.metrics:
+            self.metric_results[m.name] = []
+
+    def before_step(self, step, fetches, feeds, batch):
+        if self.max_step is not None and self.count >= self.max_step:
+            return
+
+        for in_names, out_names, metric, m_name in self.metrics:
+            self.storage_dict[m_name] = {}
+            for kwargs_name, name in in_names.items():
+                val = retrieve(name, batch)
+                self.storage_dict[m_name][kwargs_name] = val
+
+    def after_step(self, step, results):
+        if self.max_step is not None and self.count >= self.max_step:
+            return
+
+        for in_names, out_names, metric, m_name in self.metrics:
+            for kwargs_name, name in out_names.items():
+                val = retrieve(name, results)
+                self.storage_dict[m_name][kwargs_name] = val
+            m_res = metric(**self.storage_dict[m_name])
+            self.metric_results[m_name] += [m_res]
+
+        self.global_step = results["global_step"]
+        self.count += 1
+
+    def after_epoch(self, epoch):
+        self.logger.info("Metrics at epoch {}:".format(epoch))
+
+        mean_results = {}
+        for name, result in self.metric_results.items():
+            results = np.concatenate(result)
+            mean = np.mean(results, axis=0)
+            var = np.std(results, axis=0)
+            mean_results[name] = np.array([mean, var])
+            self.logger.info("{}: {} +- {}".format(name, mean, var))
+
+            summary = tf.Summary()
+            summary_mean = mean if len(mean.shape) == 0 else mean[0]
+            summary.value.add(tag=name, simple_value=summary_mean)
+            self.tb_saver.add_summary(summary, self.global_step)
+        self.tb_saver.flush()
+
+        name = "{:0>6d}_metrics".format(self.global_step)
+        name = os.path.join(self.root, name)
+        np.savez_compressed(name, **mean_results)
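MetricHook unpacks each entry of metrics as (input names, output names, metric, name) and also reads the name attribute, so the expected MetricTuple behaves like a namedtuple with those four fields; its definition is not part of this diff, and the sketch below assumes it. Each metric callable should return one value per example so that the per-epoch results can be concatenated and averaged:

    from collections import namedtuple
    import numpy as np

    # Assumed field layout, matching how MetricHook unpacks its entries.
    MetricTuple = namedtuple(
        "MetricTuple", ["input_names", "output_names", "metric", "name"]
    )

    def mse(image, reconstruction):
        # one value per example, assuming NHWC image batches
        return np.mean((image - reconstruction) ** 2, axis=(1, 2, 3))

    m = MetricTuple(
        input_names={"image": "inputs"},             # kwarg -> key in the batch
        output_names={"reconstruction": "outputs"},  # kwarg -> key in the results
        metric=mse,
        name="mse",
    )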
diff --git a/edflow/iterators/evaluator.py b/edflow/iterators/tf_evaluator.py
similarity index 100%
rename from edflow/iterators/evaluator.py
rename to edflow/iterators/tf_evaluator.py
diff --git a/edflow/iterators/tf_trainer.py b/edflow/iterators/tf_trainer.py
index d6a72c2..fa8b9ac 100644
--- a/edflow/iterators/tf_trainer.py
+++ b/edflow/iterators/tf_trainer.py
@@ -1,13 +1,18 @@
 import tensorflow as tf
 
 from edflow.iterators.tf_iterator import HookedModelIterator, TFHookedModelIterator
-from edflow.hooks import LoggingHook, CheckpointHook, RetrainHook
-from edflow.hooks import match_frequency
-from edflow.project_manager import ProjectManager
-from edflow.util import make_linear_var
+from edflow.hooks.hook import match_frequency
+from edflow.hooks.logging_hooks.tf_logging_hook import LoggingHook
+from edflow.hooks.checkpoint_hooks.tf_checkpoint_hook import (
+    CheckpointHook,
+    RetrainHook,
+    RestoreTFModelHook,
+)
+from edflow.tf_util import make_linear_var
+
+from edflow.project_manager import ProjectManager
 from edflow.hooks.util_hooks import IntervalHook
-from edflow.hooks.evaluation_hooks import RestoreTFModelHook
 
 P = ProjectManager()
diff --git a/edflow/model.py b/edflow/model.py
deleted file mode 100644
index 68c5f39..0000000
--- a/edflow/model.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import tensorflow as tf
-
-
-class Model(object):
-    """Base class defining all neccessary model methods."""
-
-    def __init__(self, name):
-        self.model_name = name
-
-    @property
-    def inputs(self):
-        """Input placeholders"""
-        raise NotImplementedError()
-
-    @property
-    def outputs(self):
-        """Output tensors."""
-        raise NotImplementedError()
-
-    @property
-    def variables(self):
-        """Variables."""
-        return [v for v in tf.global_variables() if self.model_name in v.name]
diff --git a/edflow/tf_util.py b/edflow/tf_util.py
new file mode 100644
index 0000000..ce5b927
--- /dev/null
+++ b/edflow/tf_util.py
@@ -0,0 +1,53 @@
+import numpy as np
+import tensorflow as tf
+
+
+def make_linear_var(
+    step, start, end, start_value, end_value, clip_min=None, clip_max=None
+):
+    r"""Linear from :math:`(a, \alpha)` to :math:`(b, \beta)`, i.e.
+    :math:`y = (\beta - \alpha)/(b - a) * (x - a) + \alpha`
+
+    Args:
+        step (tf.Tensor): :math:`x`
+        start: :math:`a`
+        end: :math:`b`
+        start_value: :math:`\alpha`
+        end_value: :math:`\beta`
+        clip_min: Minimal value returned.
+        clip_max: Maximum value returned.
+
+    Returns:
+        tf.Tensor: :math:`y`
+    """
+
+    if clip_min is None:
+        clip_min = min(start_value, end_value)
+    if clip_max is None:
+        clip_max = max(start_value, end_value)
+    linear = (end_value - start_value) / (end - start) * (
+        tf.cast(step, tf.float32) - start
+    ) + start_value
+    return tf.clip_by_value(linear, clip_min, clip_max)
+
+
+def make_exponential_var(step, start, end, start_value, end_value, decay):
+    r"""Exponential from :math:`(a, \alpha)` to :math:`(b, \beta)` with decay
+    rate decay.
+
+    Args:
+        step (tf.Tensor): :math:`x`
+        start: :math:`a`
+        end: :math:`b`
+        start_value: :math:`\alpha`
+        end_value: :math:`\beta`
+        decay: Decay rate
+
+    Returns:
+        tf.Tensor: :math:`y`
+    """
+
+    startstep = start
+    endstep = (np.log(end_value) - np.log(start_value)) / np.log(decay)
+    stepper = make_linear_var(step, start, end, startstep, endstep)
+    return tf.math.pow(decay, stepper) * start_value
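make_linear_var interpolates linearly between (start, start_value) and (end, end_value) and clips the result, by default to the interval spanned by the two values. For start=0, end=100, start_value=0.0, end_value=1.0 this gives (1.0 - 0.0) / (100 - 0) * (50 - 0) + 0.0 = 0.5 at step 50, and every step past 100 stays clipped at 1.0. A usage sketch for a warm-up weight:

    import tensorflow as tf
    from edflow.tf_util import make_linear_var

    global_step = tf.placeholder(tf.int64, shape=())
    # Ramps from 0 to 1 over the first 100 steps, constant afterwards.
    warmup = make_linear_var(
        global_step, start=0, end=100, start_value=0.0, end_value=1.0
    )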
help="path to checkpoint to restore") - parser.add_argument( - "--doeval", action="store_true", default=False, help="only run training" - ) - parser.add_argument( - "--retrain", - action="store_true", - default=False, - help="reset global_step to zero", - ) - - opt = parser.parse_args() - main(opt) diff --git a/edflow/util.py b/edflow/util.py index 5706ad6..bf82e91 100644 --- a/edflow/util.py +++ b/edflow/util.py @@ -6,36 +6,6 @@ import pickle -def make_linear_var( - step, start, end, start_value, end_value, clip_min=None, clip_max=None -): - r"""Linear from :math:`(a, \alpha)` to :math:`(b, \beta)`, i.e. - :math:`y = (\beta - \alpha)/(b - a) * (x - a) + \alpha` - - Args: - step (tf.Tensor): :math:`x` - start: :math:`a` - end: :math:`b` - start_value: :math:`\alpha` - end_value: :math:`\beta` - clip_min: Minimal value returned. - clip_max: Maximum value returned. - - Returns: - tf.Tensor: :math:`y` - """ - import tensorflow as tf - - if clip_min is None: - clip_min = min(start_value, end_value) - if clip_max is None: - clip_max = max(start_value, end_value) - linear = (end_value - start_value) / (end - start) * ( - tf.cast(step, tf.float32) - start - ) + start_value - return tf.clip_by_value(linear, clip_min, clip_max) - - def linear_var(step, start, end, start_value, end_value, clip_min=0.0, clip_max=1.0): r"""Linear from :math:`(a, \alpha)` to :math:`(b, \beta)`, i.e. :math:`y = (\beta - \alpha)/(b - a) * (x - a) + \alpha` @@ -58,29 +28,6 @@ def linear_var(step, start, end, start_value, end_value, clip_min=0.0, clip_max= return float(np.clip(linear, clip_min, clip_max)) -def make_exponential_var(step, start, end, start_value, end_value, decay): - r"""Exponential from :math:`(a, \alpha)` to :math:`(b, \beta)` with decay - rate decay. - - Args: - step (tf.Tensor): :math:`x` - start: :math:`a` - end: :math:`b` - start_value: :math:`\alpha` - end_value: :math:`\beta` - decay: Decay rate - - Returns: - tf.Tensor: :math:`y` - """ - import tensorflow as tf - - startstep = start - endstep = (np.log(end_value) - np.log(start_value)) / np.log(decay) - stepper = make_linear_var(step, start, end, startstep, endstep) - return tf.math.pow(decay, stepper) * start_value - - def walk(dict_or_list, fn, inplace=False, pass_key=False, prev_key=""): # noqa """Walk a nested list and/or dict recursively and call fn on all non list or dict objects. diff --git a/environments/edflow_cpu.yaml b/environments/edflow_cpu.yaml index 0db4267..baad35c 100644 --- a/environments/edflow_cpu.yaml +++ b/environments/edflow_cpu.yaml @@ -6,6 +6,5 @@ dependencies: - python=3.6 - pytorch-cpu=1.0.1 - torchvision-cpu=0.2.2 - - pip: - - tensorflow==1.12 + - tensorflow==1.12 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9c4aa77 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +exclude = 'logs' \ No newline at end of file