diff --git a/.travis.yml b/.travis.yml
index f3146d4..80bc0be 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -44,6 +44,7 @@ jobs:
       - pip install tensorboardX
       - python setup.py install
     script:
+      - python -c "import skimage" # fix failing test due to static tls loading
       - pytest --ignore=examples # use pytest instead of python -m pytest
   - stage: tf_example_tests
     name: "Tensorflow examples test"
diff --git a/README.md b/README.md
index f373cea..76ebac9 100644
--- a/README.md
+++ b/README.md
@@ -1,80 +1,422 @@
-# EDFlow - Evaluation Driven workFlow
+# edflow
 
-A small framework for training and evaluating tensorflow models by Mimo Tilbich.
+A framework-independent engine for training and evaluating in batches.
 
 ## Table of Contents
-1. [Setup](#Setup)
-2. [Workflow](#Workflow)
-3. [Example](#Example)
-4. [Other](#Other)
-    1. [Parameters](#Parameters)
-    2. [Known Issues](#Known-Issues)
-    3. [Compatibility](#Compatibility)
+1. [Installation](#Installation)
+2. [Getting Started](#Getting-Started)
+    1. [TensorFlow Eager](#TensorFlow-Eager)
+    2. [PyTorch](#PyTorch)
+    3. [TensorFlow Graph-Building](#TensorFlow-Graph-Building)
+3. [Documentation](#Documentation)
+4. [Command-Line Parameters](#Command-Line-Parameters)
 5. [Contributions](#Contributions)
 6. [LICENSE](#LICENSE)
 7. [Authors](#Authors)
 
-## Setup
-Clone this repository:
+## Installation
 
     git clone https://github.com/pesser/edflow.git
     cd edflow
+    pip install .
 
-We provide different [conda](https://conda.io) environments in the folder
-`environments`:
-- `edflow_tf_cu9.yaml`: Use if you have `CUDA>=9` available and
-  want to use tensorflow.
-- `edflow_pt_cu9.yaml`: Use if you have `CUDA>=9` available and
-  want to use pytorch.
-- `edflow_cpu`: Use if you don't have a `CUDA>=9` GPU available.
+## Getting Started
 
-Choose an appropriate environment and execute
+    cd examples
 
-    conda env create -f environments/<env>.yaml
-    conda activate <env>
-    pip install -e .
+### TensorFlow Eager
 
-where `<env>` is one of the `yaml` files described above.
+
+You provide an implementation of a model and an iterator and use `edflow` to
+train and evaluate your model. An example can be found in
+`template_tfe/edflow.py`:
+
+```python
+import functools
+import tensorflow as tf
+
+tf.enable_eager_execution()
+import tensorflow.keras as tfk
+import numpy as np
+from edflow import TemplateIterator, get_logger
+
+
+class Model(tfk.Model):
+    def __init__(self, config):
+        super().__init__()
+        self.conv1 = tfk.layers.Conv2D(filters=6, kernel_size=5)
+        self.pool = tfk.layers.MaxPool2D(pool_size=2, strides=2)
+        self.conv2 = tfk.layers.Conv2D(filters=16, kernel_size=5)
+        self.fc1 = tfk.layers.Dense(units=120)
+        self.fc2 = tfk.layers.Dense(units=84)
+        self.fc3 = tfk.layers.Dense(units=config["n_classes"])
+
+        input_shape = (config["batch_size"], 28, 28, 1)
+        self.build(input_shape)
+
+    def call(self, x):
+        x = self.pool(tf.nn.relu(self.conv1(x)))
+        x = self.pool(tf.nn.relu(self.conv2(x)))
+        x = tf.reshape(x, [tf.shape(x)[0], -1])
+        x = tf.nn.relu(self.fc1(x))
+        x = tf.nn.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+class Iterator(TemplateIterator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # loss and optimizer
+        self.criterion = functools.partial(
+            tfk.losses.sparse_categorical_crossentropy, from_logits=True
+        )
+        self.optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
+        # to save and restore
+        self.tfcheckpoint = tf.train.Checkpoint(
+            model=self.model, optimizer=self.optimizer
+        )
+
+    def save(self, checkpoint_path):
+        self.tfcheckpoint.write(checkpoint_path)
+
+    def restore(self, checkpoint_path):
+        self.tfcheckpoint.restore(checkpoint_path)
+
+    def step_op(self, model, **kwargs):
+        # get inputs
+        inputs, labels = kwargs["image"], kwargs["class"]
+
+        # compute loss
+        with tf.GradientTape() as tape:
+            outputs = model(inputs)
+            loss = self.criterion(y_true=labels, y_pred=outputs)
+            mean_loss = tf.reduce_mean(loss)
+
+        def train_op():
+            grads = tape.gradient(mean_loss, model.trainable_variables)
+            self.optimizer.apply_gradients(zip(grads, model.trainable_variables))
+
+        def log_op():
+            acc = np.mean(np.argmax(outputs, axis=1) == labels)
+            min_loss = np.min(loss)
+            max_loss = np.max(loss)
+            return {
+                "images": {"inputs": inputs},
+                "scalars": {
+                    "min_loss": min_loss,
+                    "max_loss": max_loss,
+                    "mean_loss": mean_loss,
+                    "acc": acc,
+                },
+            }
+
+        def eval_op():
+            return {"outputs": np.array(outputs), "loss": np.array(loss)[:, None]}
 
-## Workflow
+
+        return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op}
+```
 
-For more information, look into our [documentation](https://edflow.readthedocs.io/en/latest/).
+Specify your parameters in a `yaml` config file, e.g.
+`template_tfe/config.yaml`:
+
+```yaml
+dataset: edflow.datasets.fashionmnist.FashionMNIST
+model: template_tfe.edflow.Model
+iterator: template_tfe.edflow.Iterator
+batch_size: 4
+num_epochs: 2
+
+n_classes: 10
+```
+
+#### Train
+To start training, use the `-t/--train <config>` command-line option and,
+optionally, the `-n/--name <name>` option to more easily find your experiments
+later on:
+
+```
+$ edflow -t template_tfe/config.yaml -n hello_tfe
+[INFO] [train]: Starting Training.
+[INFO] [train]: Instantiating dataset.
+[INFO] [FashionMNIST]: Using split: train
+[INFO] [train]: Number of training samples: 60000
+[INFO] [train]: Warm up batches.
+[INFO] [train]: Reset batches.
+[INFO] [train]: Instantiating model.
+[INFO] [train]: Instantiating iterator.
+[INFO] [train]: Initializing model.
+[INFO] [train]: Starting Training with config:
+batch_size: 4
+dataset: edflow.datasets.fashionmnist.FashionMNIST
+hook_freq: 1
+iterator: template_tfe.edflow.Iterator
+model: template_tfe.edflow.Model
+n_classes: 10
+num_epochs: 2
+num_steps: 30000
+
+[INFO] [train]: Saved config at logs/2019-08-05T18:55:20_hello_tfe/configs/train_2019-08-05T18:55:26.yaml
+[INFO] [train]: Iterating.
+[INFO] [LoggingHook]: global_step: 0
+[INFO] [LoggingHook]: acc: 0.25
+[INFO] [LoggingHook]: max_loss: 2.3287339210510254
+[INFO] [LoggingHook]: mean_loss: 2.256807565689087
+[INFO] [LoggingHook]: min_loss: 2.2113394737243652
+[INFO] [LoggingHook]: project root: logs/2019-08-05T18:55:20_hello_tfe/train
+...
+```
+
+edflow shows the progress of your training and scalar logging values. The log
+file, log outputs and checkpoints can be found in the `train` folder of the
+project root at `logs/2019-08-05T18:55:20_hello_tfe/`. By default, checkpoints
+are written after each epoch, or when an exception is encountered, including
+a `KeyboardInterrupt`. The checkpoint frequency can be adjusted with a
+`ckpt_freq: <frequency>` entry in the config file. All config file entries can
+also be specified on the command line as, e.g., `--ckpt_freq <frequency>`.
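+
+For example, a config with a custom checkpoint frequency could look like this
+(a minimal sketch of `template_tfe/config.yaml`; the value `1000` is only
+illustrative):
+
+```yaml
+dataset: edflow.datasets.fashionmnist.FashionMNIST
+model: template_tfe.edflow.Model
+iterator: template_tfe.edflow.Iterator
+batch_size: 4
+num_epochs: 2
+ckpt_freq: 1000  # write a checkpoint every 1000 training steps
+
+n_classes: 10
+```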
+
+#### Interrupt and Resume
+Use `CTRL-C` to interrupt the training:
+
+    [INFO] [LambdaCheckpointHook]: Saved model to logs/2019-08-05T18:55:20_hello_tfe/train/checkpoints/model-1207.ckpt
+
+To resume training, run
+
+    edflow -t template_tfe/config.yaml -p logs/2019-08-05T18:55:20_hello_tfe/
+
+It will load the last checkpoint in the project folder and continue training
+and logging into the same folder.
+This lets you easily adjust parameters without having to start training from
+scratch, e.g.
+
+    edflow -t template_tfe/config.yaml -p logs/2019-08-05T18:55:20_hello_tfe/ --batch_size 32
+
+will continue with an increased batch size. Instead of loading the latest
+checkpoint, you can load a specific checkpoint by adding `-c <path to checkpoint>`:
+
+    edflow -t template_tfe/config.yaml -p logs/2019-08-05T18:55:20_hello_tfe/ -c logs/2019-08-05T18:55:20_hello_tfe/train/checkpoints/model-1207.ckpt
+
+#### Evaluate
+Evaluation mode will write all outputs of `eval_op` to disk and prepare them
+for consumption by your evaluation functions. Just replace `-t` by `-e`:
+
+    edflow -e template_tfe/config.yaml -p logs/2019-08-05T18:55:20_hello_tfe/ -c logs/2019-08-05T18:55:20_hello_tfe/train/checkpoints/model-1207.ckpt
+
+If `-c` is not specified, it will evaluate the latest checkpoint. The
+evaluation mode will finish with
+
+```
+[INFO] [EvalHook]: All data has been produced. You can now also run all callbacks using the following command:
+edeval -c logs/2019-08-05T18:55:20_hello_tfe/eval/2019-08-05T19:22:23/1207/model_output.csv -cb <your callback>
+```
+
+Your callbacks will get the path to the evaluation folder, the input dataset as
+seen by your model, an output dataset which contains the corresponding outputs
+of your model and the config used for evaluation.
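+
+Schematically, a callback is just a callable with this signature (a sketch of
+the contract described above; the name `my_callback` is hypothetical):
+
+```python
+def my_callback(root, data_in, data_out, config):
+    # root: path to the evaluation folder
+    # data_in: input dataset as seen by the model
+    # data_out: dataset of the corresponding eval_op outputs
+    # config: config used for evaluation
+    for i in range(len(data_in)):
+        inputs, outputs = data_in[i], data_out[i]
+```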
+`template_tfe/edflow.py` contains an example callback computing the average
+loss and accuracy:
+
+```python
+def acc_callback(root, data_in, data_out, config):
+    from tqdm import trange
+
+    logger = get_logger("acc_callback")
+    correct = 0
+    total_loss = 0.0
+    for i in trange(len(data_in)):
+        labels = data_in[i]["class"]
+        outputs = data_out[i]["outputs"]
+        # per-example loss as written to disk by eval_op
+        loss = data_out[i]["loss"].squeeze()
+
+        prediction = np.argmax(outputs, axis=0)
+        correct += labels == prediction
+        total_loss += loss
+    logger.info("Loss: {}".format(total_loss / len(data_in)))
+    logger.info("Accuracy: {}".format(correct / len(data_in)))
+```
+
+which can be executed with:
+
+```
+$ edeval -c logs/2019-08-05T18:55:20_hello_tfe/eval/2019-08-05T19:22:23/1207/model_output.csv -cb template_tfe.edflow.acc_callback
+...
+INFO:acc_callback:Loss: 0.00013115551471710204
+INFO:acc_callback:Accuracy: 0.7431
+```
+
+### PyTorch
+
+The same example as implemented by [TensorFlow Eager](#TensorFlow-Eager) can
+be found for PyTorch in `template_pytorch/edflow.py` and requires only slightly
+different syntax:
 
-## Example
+```python
+import functools
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
 
-### Tensorflow
+import numpy as np
+from edflow import TemplateIterator, get_logger
 
-    cd examples
-    edflow -t mnist_tf/train.yaml -n hello_tensorflow
+
+class Model(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 4 * 4, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, config["n_classes"])
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(x.shape[0], -1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
 
-### Pytorch
+
+class Iterator(TemplateIterator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # loss and optimizer
+        self.criterion = nn.CrossEntropyLoss(reduction="none")
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
+
+    def save(self, checkpoint_path):
+        state = {
+            "model": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+        }
+        torch.save(state, checkpoint_path)
+
+    def restore(self, checkpoint_path):
+        state = torch.load(checkpoint_path)
+        self.model.load_state_dict(state["model"])
+        self.optimizer.load_state_dict(state["optimizer"])
+
+    def step_op(self, model, **kwargs):
+        # get inputs
+        inputs, labels = kwargs["image"], kwargs["class"]
+        inputs = torch.tensor(inputs)
+        inputs = inputs.transpose(2, 3).transpose(1, 2)
+        labels = torch.tensor(labels, dtype=torch.long)
+
+        # compute loss
+        outputs = model(inputs)
+        loss = self.criterion(outputs, labels)
+        mean_loss = torch.mean(loss)
+
+        def train_op():
+            self.optimizer.zero_grad()
+            mean_loss.backward()
+            self.optimizer.step()
+
+        def log_op():
+            acc = np.mean(
+                np.argmax(outputs.detach().numpy(), axis=1) == labels.detach().numpy()
+            )
+            min_loss = np.min(loss.detach().numpy())
+            max_loss = np.max(loss.detach().numpy())
+            return {
+                "images": {"inputs": inputs.detach().numpy()},
+                "scalars": {
+                    "min_loss": min_loss,
+                    "max_loss": max_loss,
+                    "mean_loss": mean_loss,
+                    "acc": acc,
+                },
+            }
+
+        def eval_op():
+            return {
+                "outputs": np.array(outputs.detach().numpy()),
+                "loss": np.array(loss.detach().numpy())[:, None],
+            }
+
+        return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op}
+```
+
+You can experiment with it in the exact same way as [above](#TensorFlow-Eager).
+For example, to [start training](#Train) run:
+
+    edflow -t template_pytorch/config.yaml -n hello_pytorch
+
+See also [interrupt and resume](#interrupt-and-resume) and
+[evaluation](#Evaluate).
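+
+Since `template_pytorch/edflow.py` defines the same `acc_callback`, evaluation
+results can be checked in the same way; the project and checkpoint paths below
+are placeholders for your own run:
+
+    edeval -c logs/<your_project>/eval/<timestamp>/<step>/model_output.csv -cb template_pytorch.edflow.acc_callback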
+
+### TensorFlow Graph-Building
+
+edflow also supports graph-based execution, e.g.
 
     cd examples
-    edflow -t mnist_pytorch/mnist_config.yaml -n hello_pytorch
+    edflow -t mnist_tf/train.yaml -n hello_tensorflow
 
+With TensorFlow 2.x going eager by default and TensorFlow 1.x supporting eager
+execution, support for TensorFlow 1.x graph building will fade away.
 
-## Other
-### Parameters
-- `--config path/to/config`
-    yaml file with all information see [Workflow][#Workflow]
-- `--checkpoint path/to/checkpoint to restore`
+## Documentation
 
-- `--noeval`
-    only run training
+For more information, look into our [documentation](https://edflow.readthedocs.io/en/latest/).
 
-- `--retrain`
-    reset global step to zero
 
-### Known Issues
+## Command-Line Parameters
+
+```
+$ edflow --help
+usage: edflow [-h] [-n description]
+              [-b [base_config.yaml [base_config.yaml ...]]] [-t config.yaml]
+              [-e [config.yaml [config.yaml ...]]] [-p PROJECT]
+              [-c CHECKPOINT] [-r] [-log LEVEL]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -n description, --name description
+                        postfix of log directory.
+  -b [base_config.yaml [base_config.yaml ...]], --base [base_config.yaml [base_config.yaml ...]]
+                        Path to base config. Any parameter in here is
+                        overwritten by the train or eval config. Useful e.g.
+                        for model parameters, which stay constant between
+                        trainings and evaluations.
+  -t config.yaml, --train config.yaml
+                        path to training config
+  -e [config.yaml [config.yaml ...]], --eval [config.yaml [config.yaml ...]]
+                        path to evaluation configs
+  -p PROJECT, --project PROJECT
+                        path to existing project
+  -c CHECKPOINT, --checkpoint CHECKPOINT
+                        path to existing checkpoint
+  -r, --retrain         reset global step
+  -log LEVEL, --log-level LEVEL
+                        Set the std-out logging level.
+```
 
-### Compatibility
 
 ## Contributions
 [![GitHub-Commits][GitHub-Commits]](https://github.com/pesser/edflow/graphs/commit-activity)
diff --git a/edflow/__init__.py b/edflow/__init__.py
index e69de29..bccb4fc 100644
--- a/edflow/__init__.py
+++ b/edflow/__init__.py
@@ -0,0 +1,11 @@
+# project management
+from edflow.project_manager import ProjectManager
+from edflow.custom_logging import get_logger
+from edflow.main import get_obj_from_str
+
+# iterators
+from edflow.iterators.model_iterator import PyHookedModelIterator
+from edflow.iterators.template_iterator import TemplateIterator
+
+# hook
+from edflow.hooks.hook import Hook
diff --git a/edflow/eval/pipeline.py b/edflow/eval/pipeline.py
index e5c38ea..7277ed1 100644
--- a/edflow/eval/pipeline.py
+++ b/edflow/eval/pipeline.py
@@ -106,7 +106,7 @@ def eval_op(self, inputs):
 import re
 
 from edflow.data.util import adjust_support
-from edflow.util import walk
+from edflow.util import walk, retrieve
 from edflow.data.dataset import DatasetMixin, CsvDataset, ProcessedDataset
 from edflow.project_manager import ProjectManager as P
 from edflow.hooks.hook import Hook
@@ -127,6 +127,7 @@ def __init__(
         callbacks=[],
         meta=None,
         step_getter=None,
+        keypath="step_ops",
     ):
         """
         .. warning::
@@ -155,6 +156,8 @@
             ``yaml``. Usually the ``edflow`` config.
         step_getter : Callable
             Function which returns the global step as ``int``.
+        keypath : str
+            Path in result which will be stored.
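+            By default the outputs under ``step_ops`` are stored; a
+            ``TemplateIterator`` passes its ``eval_op`` path here
+            (``step_ops/eval_op`` unless configured otherwise).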
""" self.logger = get_logger(self) @@ -168,6 +171,7 @@ def __init__( self.meta = meta self.gs = step_getter + self.keypath = keypath def before_epoch(self, epoch): """ @@ -249,7 +253,13 @@ def after_step(self, step, last_results): for i, idx in enumerate(idxs): self.label_arrs[k][idx] = label_vals[k][i] - path_dicts = save_output(self.save_root, last_results, idxs, self.sdks) + path_dicts = save_output( + root=self.save_root, + example=last_results, + index=idxs, + sub_dir_keys=self.sdks, + keypath=self.keypath, + ) if self.data_frame is None: columns = sorted(path_dicts[list(path_dicts.keys())[0]]) @@ -328,6 +338,32 @@ def save_csv(self): ) +class TemplateEvalHook(EvalHook): + """EvalHook that disables itself when the eval op returns None.""" + + def before_epoch(self, *args, **kwargs): + self._active = True + super().before_epoch(*args, **kwargs) + + def before_step(self, *args, **kwargs): + if self._active: + super().before_step(*args, **kwargs) + + def after_step(self, step, last_results): + if retrieve(last_results, self.keypath) is None: + self._active = False + if self._active: + super().after_step(step, last_results) + + def after_epoch(self, *args, **kwargs): + if self._active: + super().after_epoch(*args, **kwargs) + + def at_exception(self, *args, **kwargs): + if self._active: + super().at_exception(*args, **kwargs) + + class EvalDataFolder(DatasetMixin): """ """ @@ -408,7 +444,7 @@ def load_labels(root): return labels -def save_output(root, example, index, sub_dir_keys=[]): +def save_output(root, example, index, sub_dir_keys=[], keypath="step_ops"): """Saves the ouput of some model contained in ``example`` in a reusable manner. @@ -442,7 +478,7 @@ def save_output(root, example, index, sub_dir_keys=[]): """ - example = example["step_ops"] + example = retrieve(example, keypath) sub_dirs = [""] * len(index) for subk in sub_dir_keys: @@ -551,8 +587,9 @@ def _delget(d, k): def save_example(savepath, datum): - """Manages the writing process of a single datum: (1) Determine type, - (2) Choos saver, (3) save. + """ + Manages the writing process of a single datum: (1) Determine type, + (2) Choose saver, (3) save. Parameters ---------- @@ -626,9 +663,7 @@ def load_by_heuristic(path): elif ext == ".txt": return txt_loader(path) else: - raise ValueError( - "Cannot load file with extenstion `{}` at {}".format(ext, path) - ) + raise ValueError("Cannot load file with extension `{}` at {}".format(ext, path)) def decompose_name(name): diff --git a/edflow/hooks/logging_hooks/minimal_logging_hook.py b/edflow/hooks/logging_hooks/minimal_logging_hook.py new file mode 100644 index 0000000..18852ce --- /dev/null +++ b/edflow/hooks/logging_hooks/minimal_logging_hook.py @@ -0,0 +1,50 @@ +from edflow.hooks.hook import Hook +from edflow.util import retrieve +from edflow.custom_logging import get_logger +from edflow.iterators.batches import plot_batch +import os + + +class LoggingHook(Hook): + """Minimal implementation of a logging hook. Can be easily extended by + adding handlers.""" + + def __init__(self, paths, interval, root_path): + """ + Parameters + ---------- + paths : list(str) + List of key-paths to logging outputs. Will be + expanded so they can be evaluated lazily. + interval : int + Intervall of training steps before logging. + root_path : str + Path at which the logs are stored. 
+ """ + self.paths = paths + self.interval = interval + self.root = root_path + self.logger = get_logger(self) + self.handlers = {"images": self.log_images, "scalars": self.log_scalars} + + def after_step(self, batch_index, last_results): + if batch_index % self.interval == 0: + self._step = last_results["global_step"] + self.logger.info("global_step: {}".format(self._step)) + for path in self.paths: + for k in self.handlers: + handler_results = retrieve( + last_results, path + "/" + k, default=dict() + ) + self.handlers[k](handler_results) + self.logger.info("project root: {}".format(self.root)) + + def log_scalars(self, results): + for name in sorted(results.keys()): + self.logger.info("{}: {}".format(name, results[name])) + + def log_images(self, results): + for name, image_batch in results.items(): + full_name = name + "_{:07}.png".format(self._step) + save_path = os.path.join(self.root, full_name) + plot_batch(image_batch, save_path) diff --git a/edflow/iterators/template_iterator.py b/edflow/iterators/template_iterator.py new file mode 100644 index 0000000..0f36025 --- /dev/null +++ b/edflow/iterators/template_iterator.py @@ -0,0 +1,94 @@ +from edflow.iterators.model_iterator import PyHookedModelIterator +from edflow.hooks.checkpoint_hooks.lambda_checkpoint_hook import LambdaCheckpointHook +from edflow.hooks.logging_hooks.minimal_logging_hook import LoggingHook +from edflow.hooks.util_hooks import IntervalHook +from edflow.eval.pipeline import TemplateEvalHook +from edflow.project_manager import ProjectManager +from edflow.util import retrieve +from edflow.main import get_obj_from_str + + +class TemplateIterator(PyHookedModelIterator): + """A specialization of PyHookedModelIterator which adds reasonable default + behaviour. Subclasses should implement `save`, `restore` and `step_op`.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # wrap save and restore into a LambdaCheckpointHook + self.ckpthook = LambdaCheckpointHook( + root_path=ProjectManager.checkpoints, + global_step_getter=self.get_global_step, + global_step_setter=self.set_global_step, + save=self.save, + restore=self.restore, + interval=self.config.get("ckpt_freq", None), + ) + if not self.config.get("test_mode", False): + # in training, excute train ops and add logginghook + self._train_ops = self.config.get("train_ops", ["step_ops/train_op"]) + self._log_ops = self.config.get("log_ops", ["step_ops/log_op"]) + # logging + self.loghook = LoggingHook( + paths=self._log_ops, root_path=ProjectManager.train, interval=1 + ) + # wrap it in interval hook + self.ihook = IntervalHook( + [self.loghook], + interval=self.config.get("start_log_freq", 1), + modify_each=1, + max_interval=self.config.get("log_freq", 1000), + get_step=self.get_global_step, + ) + self.hooks.append(self.ihook) + # write checkpoints after epoch or when interrupted + self.hooks.append(self.ckpthook) + else: + # evaluate + self._eval_op = self.config.get("eval_op", "step_ops/eval_op") + self._eval_callbacks = self.config.get("eval_callbacks", list()) + if not isinstance(self._eval_callbacks, list): + self._eval_callbacks = [self._eval_callbacks] + self._eval_callbacks = [ + get_obj_from_str(name) for name in self._eval_callbacks + ] + self.evalhook = TemplateEvalHook( + dataset=self.dataset, + step_getter=self.get_global_step, + keypath=self._eval_op, + meta=self.config, + callbacks=self._eval_callbacks, + ) + self.hooks.append(self.evalhook) + self._train_ops = [] + self._log_ops = [] + + def initialize(self, 
checkpoint_path=None): + if checkpoint_path is not None: + self.ckpthook(checkpoint_path) + + def step_ops(self): + return self.step_op + + def run(self, fetches, feed_dict): + results = super().run(fetches, feed_dict) + for train_op in self._train_ops: + retrieve(results, train_op) + return results + + def save(self, checkpoint_path): + """Save state to checkpoint path.""" + raise NotImplemented() + + def restore(self, checkpoint_path): + """Restore state from checkpoint path.""" + raise NotImplemented() + + def step_op(self, model, **kwargs): + """Actual step logic. By default, a dictionary with keys 'train_op', + 'log_op', 'eval_op' and callable values is expected. 'train_op' should + update the model's state as a side-effect, 'log_op' will be logged to + the project's train folder. It should be a dictionary with keys + 'images' and 'scalars'. Images are written as png's, scalars are + written to the log file and stdout. Outputs of 'eval_op' are written + into the project's eval folder to be evaluated with `edeval`.""" + raise NotImplemented() diff --git a/examples/template_pytorch/config.yaml b/examples/template_pytorch/config.yaml new file mode 100644 index 0000000..73e5b34 --- /dev/null +++ b/examples/template_pytorch/config.yaml @@ -0,0 +1,7 @@ +dataset: edflow.datasets.fashionmnist.FashionMNIST +model: template_pytorch.edflow.Model +iterator: template_pytorch.edflow.Iterator +batch_size: 4 +num_epochs: 2 + +n_classes: 10 diff --git a/examples/template_pytorch/edflow.py b/examples/template_pytorch/edflow.py new file mode 100644 index 0000000..cfe4fba --- /dev/null +++ b/examples/template_pytorch/edflow.py @@ -0,0 +1,108 @@ +import functools +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +import numpy as np +from edflow import TemplateIterator, get_logger + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 4 * 4, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, config["n_classes"]) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(x.shape[0], -1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class Iterator(TemplateIterator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # loss and optimizer + self.criterion = nn.CrossEntropyLoss(reduction="none") + self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) + + def save(self, checkpoint_path): + state = { + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + } + torch.save(state, checkpoint_path) + + def restore(self, checkpoint_path): + state = torch.load(checkpoint_path) + self.model.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + + def step_op(self, model, **kwargs): + # get inputs + inputs, labels = kwargs["image"], kwargs["class"] + inputs = torch.tensor(inputs) + inputs = inputs.transpose(2, 3).transpose(1, 2) + labels = torch.tensor(labels, dtype=torch.long) + + # compute loss + outputs = model(inputs) + loss = self.criterion(outputs, labels) + mean_loss = torch.mean(loss) + + def train_op(): + self.optimizer.zero_grad() + mean_loss.backward() + self.optimizer.step() + + def log_op(): + acc = np.mean( + np.argmax(outputs.detach().numpy(), axis=1) == labels.detach().numpy() + ) + min_loss 
= np.min(loss.detach().numpy()) + max_loss = np.max(loss.detach().numpy()) + return { + "images": {"inputs": inputs.detach().numpy()}, + "scalars": { + "min_loss": min_loss, + "max_loss": max_loss, + "mean_loss": mean_loss, + "acc": acc, + }, + } + + def eval_op(): + return { + "outputs": np.array(outputs.detach().numpy()), + "loss": np.array(loss.detach().numpy())[:, None], + } + + return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op} + + +def acc_callback(root, data_in, data_out, config): + from tqdm import trange + + logger = get_logger("acc_callback") + correct = 0 + seen = 0 + loss = 0.0 + for i in trange(len(data_in)): + labels = data_in[i]["class"] + outputs = data_out[i]["outputs"] + loss = data_out[i]["loss"].squeeze() + + prediction = np.argmax(outputs, axis=0) + correct += labels == prediction + loss += loss + logger.info("Loss: {}".format(loss / len(data_in))) + logger.info("Accuracy: {}".format(correct / len(data_in))) diff --git a/examples/template_tfe/config.yaml b/examples/template_tfe/config.yaml new file mode 100644 index 0000000..464e988 --- /dev/null +++ b/examples/template_tfe/config.yaml @@ -0,0 +1,7 @@ +dataset: edflow.datasets.fashionmnist.FashionMNIST +model: template_tfe.edflow.Model +iterator: template_tfe.edflow.Iterator +batch_size: 4 +num_epochs: 2 + +n_classes: 10 diff --git a/examples/template_tfe/edflow.py b/examples/template_tfe/edflow.py new file mode 100644 index 0000000..f213221 --- /dev/null +++ b/examples/template_tfe/edflow.py @@ -0,0 +1,102 @@ +import functools +import tensorflow as tf + +tf.enable_eager_execution() +import tensorflow.keras as tfk +import numpy as np +from edflow import TemplateIterator, get_logger + + +class Model(tfk.Model): + def __init__(self, config): + super().__init__() + self.conv1 = tfk.layers.Conv2D(filters=6, kernel_size=5) + self.pool = tfk.layers.MaxPool2D(pool_size=2, strides=2) + self.conv2 = tfk.layers.Conv2D(filters=16, kernel_size=5) + self.fc1 = tfk.layers.Dense(units=120) + self.fc2 = tfk.layers.Dense(units=84) + self.fc3 = tfk.layers.Dense(units=config["n_classes"]) + + input_shape = (config["batch_size"], 28, 28, 1) + self.build(input_shape) + + def call(self, x): + x = self.pool(tf.nn.relu(self.conv1(x))) + x = self.pool(tf.nn.relu(self.conv2(x))) + x = tf.reshape(x, [tf.shape(x)[0], -1]) + x = tf.nn.relu(self.fc1(x)) + x = tf.nn.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class Iterator(TemplateIterator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # loss and optimizer + self.criterion = functools.partial( + tfk.losses.sparse_categorical_crossentropy, from_logits=True + ) + self.optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9) + # to save and restore + self.tfcheckpoint = tf.train.Checkpoint( + model=self.model, optimizer=self.optimizer + ) + + def save(self, checkpoint_path): + self.tfcheckpoint.write(checkpoint_path) + + def restore(self, checkpoint_path): + self.tfcheckpoint.restore(checkpoint_path) + + def step_op(self, model, **kwargs): + # get inputs + inputs, labels = kwargs["image"], kwargs["class"] + + # compute loss + with tf.GradientTape() as tape: + outputs = model(inputs) + loss = self.criterion(y_true=labels, y_pred=outputs) + mean_loss = tf.reduce_mean(loss) + + def train_op(): + grads = tape.gradient(mean_loss, model.trainable_variables) + self.optimizer.apply_gradients(zip(grads, model.trainable_variables)) + + def log_op(): + acc = np.mean(np.argmax(outputs, axis=1) == labels) + min_loss = np.min(loss) + max_loss 
= np.max(loss) + return { + "images": {"inputs": inputs}, + "scalars": { + "min_loss": min_loss, + "max_loss": max_loss, + "mean_loss": mean_loss, + "acc": acc, + }, + } + + def eval_op(): + return {"outputs": np.array(outputs), "loss": np.array(loss)[:, None]} + + return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op} + + +def acc_callback(root, data_in, data_out, config): + from tqdm import trange + + logger = get_logger("acc_callback") + correct = 0 + seen = 0 + loss = 0.0 + for i in trange(len(data_in)): + labels = data_in[i]["class"] + outputs = data_out[i]["outputs"] + loss = data_out[i]["loss"].squeeze() + + prediction = np.argmax(outputs, axis=0) + correct += labels == prediction + loss += loss + logger.info("Loss: {}".format(loss / len(data_in))) + logger.info("Accuracy: {}".format(correct / len(data_in)))