From 677fcd4d69af4a17f9e7d615acb367547028eaaa Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Mon, 17 Apr 2023 16:37:32 -0400
Subject: [PATCH 01/19] Adding demo directory with CycleGAN training

---
 demo/01_cycleGAN/HowTo_render.md |  0
 demo/01_cycleGAN/HowTo_train.md  | 59 ++++++++++++++++++++++++++++++++
 demo/01_cycleGAN/README.md       |  0
 3 files changed, 59 insertions(+)
 create mode 100644 demo/01_cycleGAN/HowTo_render.md
 create mode 100644 demo/01_cycleGAN/HowTo_train.md
 create mode 100644 demo/01_cycleGAN/README.md

diff --git a/demo/01_cycleGAN/HowTo_render.md b/demo/01_cycleGAN/HowTo_render.md
new file mode 100644
index 00000000..e69de29b
diff --git a/demo/01_cycleGAN/HowTo_train.md b/demo/01_cycleGAN/HowTo_train.md
new file mode 100644
index 00000000..0f6e2c29
--- /dev/null
+++ b/demo/01_cycleGAN/HowTo_train.md
@@ -0,0 +1,59 @@
+## How to Train a CycleGAN
+
+This is a basic outline of how to train a CycleGAN using the provided train script.
+
+### Prerequisites
+
+1. Python environment with the necessary dependencies installed.
+2. Image datasets for both source and target domains.
+3. `raygun` repository cloned to your local machine.
+4. Configuration file (`train_conf.json`) specifying training parameters.
+
+### Configuration JSON Parameters
+The configuration JSON file contains several parameters that you can modify to customize your CycleGAN training. Here are the key parameters:
+
+- "framework": Specifies the deep learning framework to use (e.g. "torch", "tensorflow").
+- "system": Specifies the type of system to train; in this case, "CycleGAN".
+- "job_command": Specifies the job command for running the training script (e.g. "bsub", "-n 16", "-gpu "num=1"", "-q gpu_a100").
+- "sources": Specifies the source domains and their corresponding paths, real names, and mask names.
+- "common_voxel_size": Specifies the voxel size to cast all data into.
+- "ndims": Specifies the number of dimensions for the input data.
+- "batch_size": Specifies the batch size for training.
+- "num_workers": Specifies the number of workers for data loading.
+- "cache_size": Specifies the cache size for data loading.
+- "scheduler": Specifies the scheduler type for adjusting the learning rate during training.
+- "scheduler_kwargs": Specifies the arguments for the scheduler.
+- "g_optim_type": Specifies the optimizer type for the generator.
+- "g_optim_kwargs": Specifies the arguments for the generator optimizer.
+- "d_optim_type": Specifies the optimizer type for the discriminator.
+- "d_optim_kwargs": Specifies the arguments for the discriminator optimizer.
+- "loss_kwargs": Specifies the arguments for the loss functions.
+- "gnet_type": Specifies the type of generator network architecture.
+- "gnet_kwargs": Specifies the arguments for the generator network architecture.
+- "pretrain_gnet": Specifies whether to pretrain the generator network.
+- "dnet_type": Specifies the type of discriminator network architecture.
+- "dnet_kwargs": Specifies the arguments for the discriminator network architecture.
+- "spawn_subprocess": Specifies whether to spawn subprocesses for training.
+- "side_length": Specifies the side length of the input image.
+- "num_epochs": Specifies the number of training epochs.
+- "log_every": Specifies the frequency of logging during training.
+- "save_every": Specifies the frequency of saving models during training.
+- "snapshot_every": Specifies the frequency of taking snapshots during training.
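+
+As a quick orientation, a pared-down configuration might look like the sketch below. The values shown are illustrative assumptions only, not tested defaults; see the linked example file below for a complete, working configuration:
+
+```json
+{
+    "framework": "torch",
+    "system": "CycleGAN",
+    "sources": {
+        "A": {"path": "/path/to/domain_A.zarr", "name": "raw"},
+        "B": {"path": "/path/to/domain_B.zarr", "name": "raw"}
+    },
+    "common_voxel_size": [8, 8, 8],
+    "ndims": 2,
+    "batch_size": 1,
+    "num_epochs": 200000,
+    "save_every": 10000
+}
+```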
+
+Here's an example of a CycleGAN [configuration file](../../experiments/ieee-isbi-2023/01_cycleGAN/train_conf.json).
+
+### Training Methods
+
+#### General training
+- From the repository directory, run the following command: `raygun-train CONFIG_FILE_LOCATION`, where `CONFIG_FILE_LOCATION` is the relative path to the JSON configuration file for your training objective.
+
+#### Batch training
+- From the repository directory, run the following command: `raygun-train-batch CONFIG_FILE_LOCATION`, where `CONFIG_FILE_LOCATION` is the relative path to the JSON configuration file.
+
+#### Cluster training
+- From the repository directory, run the following command: `raygun-train-cluster CONFIG_FILE_LOCATION`, where `CONFIG_FILE_LOCATION` is the relative path to the JSON configuration file.
+
+The CycleGAN training will start and progress will be displayed in the console.
+
+Once training is complete, the trained models will be saved in the output directory specified in the configuration file.
diff --git a/demo/01_cycleGAN/README.md b/demo/01_cycleGAN/README.md
new file mode 100644
index 00000000..e69de29b

From 519025aae3ca31bfe89fdbf65b855dc326b9a8e0 Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Wed, 3 May 2023 13:23:56 -0400
Subject: [PATCH 02/19] CycleGAN system docstrings

---
 notes/naming_conventions.txt         |  7 ++
 src/raygun/torch/systems/CycleGAN.py | 98 +++++++++++++++++++++-------
 2 files changed, 81 insertions(+), 24 deletions(-)
 create mode 100644 notes/naming_conventions.txt

diff --git a/notes/naming_conventions.txt b/notes/naming_conventions.txt
new file mode 100644
index 00000000..5d3921b5
--- /dev/null
+++ b/notes/naming_conventions.txt
@@ -0,0 +1,7 @@
+CycleGAN System:
+
+    - default_config
+    - config
+    - common_voxel_size
+    - ndims
+    -
\ No newline at end of file
diff --git a/src/raygun/torch/systems/CycleGAN.py b/src/raygun/torch/systems/CycleGAN.py
index ce1ac1e1..2cb40181 100644
--- a/src/raygun/torch/systems/CycleGAN.py
+++ b/src/raygun/torch/systems/CycleGAN.py
@@ -16,16 +16,25 @@
 from raygun.torch.systems import BaseSystem
 
+logger: logging.Logger = logging.Logger(__name__, "INFO")
+
 
 class CycleGAN(BaseSystem):
-    def __init__(self, config=None):
+    """CycleGAN implementation of :class:`raygun.torch.systems.BaseSystem`.
+
+    Args:
+        config (``string``, optional):
+            An optional path to CycleGAN configuration parameters for your system.
+            These update/modify the default system parameters in ``default_cycleGAN_conf.json``.
+    """
+
+    def __init__(self, config=None) -> None:
         super().__init__(
             default_config="../default_configs/default_cycleGAN_conf.json",
             config=config,
         )
-        self.logger = logging.Logger(__name__, "INFO")
 
         if self.common_voxel_size is None:
-            self.common_voxel_size = gp.Coordinate(
+            self.common_voxel_size: gp.Coordinate = gp.Coordinate(
                 daisy.open_ds(
                     self.sources["B"]["path"], self.sources["B"]["name"]
                 ).voxel_size
@@ -33,11 +42,24 @@
         else:
             self.common_voxel_size = gp.Coordinate(self.common_voxel_size)
         if self.ndims is None:
-            self.ndims = sum(
+            self.ndims: int = sum(
                 np.array(self.common_voxel_size) == np.min(self.common_voxel_size)
             )
 
-    def batch_show(self, batch=None, i=0, show_mask=False):
+    def batch_show(self, batch=None, i=0, show_mask=False) -> None:
+        """Convenience method to display an output batch of a training process.
+
+        Args:
+            batch (gp.BatchRequest, optional):
+                Batch of training data to process. Defaults to None.
+
+            i (``int``):
+                Index of the element in the batch to display. Defaults to 0.
+
+            show_mask (``bool``):
+                Toggle to add a mask on the displayed batch. Defaults to False.
+        """
+
         if batch is None:
             batch = self.batch
         if not hasattr(self, "col_dict"):
@@ -81,7 +103,14 @@
             )
             ax.set_title(label)
 
-    def write_tBoard_graph(self, batch=None):
+    def write_tBoard_graph(self, batch=None) -> None:
+        """Writes the model graph to the TensorBoard summary writer.
+
+        Args:
+            batch (gp.BatchRequest, optional):
+                Batch of training data. Defaults to None.
+        """
+
         if batch is None:
             batch = self.trainer.batch
 
@@ -105,7 +134,19 @@
         except:
             self.logger.warning("Failed to add model graph to tensorboard.")
 
-    def get_extents(self, side_length=None, array_name=None):
+    def get_extents(self, side_length=None, array_name=None) -> gp.Coordinate:
+        """Returns the extent of the data in the spatial dimensions.
+
+        Args:
+            side_length (``int``, optional):
+                Side length of the array. Defaults to None.
+            array_name (``string``, optional):
+                Name of the array. Defaults to None.
+
+        Returns:
+            gp.Coordinate:
+                A tuple containing the extent of the data in the spatial dimensions.
+        """
 
         if side_length is None:
             side_length = self.side_length
@@ -130,24 +171,24 @@
         ] = side_length  # assumes first dimension is z (i.e. the dimension breaking isotropy)
         return gp.Coordinate(extents)
 
-    def setup_networks(self):
+    def setup_networks(self) -> None:
        self.netG1 = self.get_network(self.gnet_type, self.gnet_kwargs)
        self.netG2 = self.get_network(self.gnet_type, self.gnet_kwargs)

        self.netD1 = self.get_network(self.dnet_type, self.dnet_kwargs)
        self.netD2 = self.get_network(self.dnet_type, self.dnet_kwargs)
 
-    def setup_model(self):
+    def setup_model(self) -> None:
         if not hasattr(self, "netG1"):
             self.setup_networks()
 
         if self.sampling_bottleneck:
-            scale_factor_A = tuple(
+            scale_factor_A: tuple = tuple(
                 np.divide(self.common_voxel_size, self.A_voxel_size)[-self.ndims :]
             )
             if not any([s < 1 for s in scale_factor_A]):
                 scale_factor_A = None
-            scale_factor_B = tuple(
+            scale_factor_B: tuple = tuple(
                 np.divide(self.common_voxel_size, self.B_voxel_size)[-self.ndims :]
             )
             if not any([s < 1 for s in scale_factor_B]):
@@ -155,7 +196,7 @@
         else:
             scale_factor_A, scale_factor_B = None, None
 
-        self.model = CycleModel(
+        self.model: CycleModel = CycleModel(
             self.netG1,
             self.netG2,
             scale_factor_A,
@@ -164,7 +205,7 @@
             freeze_norms_at=self.freeze_norms_at,
         )
 
-    def setup_optimization(self):
+    def setup_optimization(self) -> None:
         self.optimizer_D = get_base_optimizer(self.d_optim_type)(
             itertools.chain(self.netD1.parameters(), self.netD2.parameters()),
             **self.d_optim_kwargs
@@ -175,21 +216,21 @@
             scheduler_kwargs = self.scheduler_kwargs
         else:
             scheduler = None
-            scheduler_kwargs = {}
+            scheduler_kwargs: dict = {}
 
         if self.loss_type.lower() == "link":
             self.optimizer_G = get_base_optimizer(self.g_optim_type)(
                 itertools.chain(self.netG1.parameters(), self.netG2.parameters()),
                 **self.g_optim_kwargs
             )
-            self.optimizer = BaseDummyOptimizer(
+            self.optimizer: BaseDummyOptimizer = BaseDummyOptimizer(
                 optimizer_G=self.optimizer_G,
                 optimizer_D=self.optimizer_D,
                 scheduler=scheduler,
                 scheduler_kwargs=scheduler_kwargs,
             )
 
-            self.loss = LinkCycleLoss(
+            self.loss: LinkCycleLoss = LinkCycleLoss(
                 self.netD1,
                 self.netG1,
                 self.netD2,
@@ -232,9 +273,9 @@
                 "Unexpected Loss Style. Accepted options are 'cycle' or 'split'"
             )
 
-    def setup_datapipes(self):
-        self.arrays = {}
-        self.datapipes = {}
+    def setup_datapipes(self) -> None:
+        self.arrays: dict = {}
+        self.datapipes: dict = {}
         for id, src in self.sources.items():
             self.datapipes[id] = CycleDataPipe(
                 id,
                 src,
@@ -246,16 +287,25 @@
             )
             self.arrays.update(self.datapipes[id].arrays)
 
-    def make_request(self, mode: str = "train"):
+    def make_request(self, mode: str = "train") -> gp.BatchRequest:
+        """Creates a BatchRequest object for the specified mode.
+
+        Args:
+            mode (``string``):
+                The processing mode to create the BatchRequest for. Can be one of "train", "val", "test", or "predict".
+
+        Returns:
+            A BatchRequest object specifying the desired output arrays and their requested extents and voxel sizes.
+        """
+
         # create request
-        request = gp.BatchRequest()
+        request: gp.BatchRequest = gp.BatchRequest()
         for array_name, array in self.arrays.items():
             if (
                 mode == "prenet" and ("real" in array_name or "mask" in array_name)
             ) or (
                 mode != "prenet" and (mode != "predict" or "cycle" not in array_name)
             ):
-                extents = self.get_extents(array_name=array.identifier)
+                extents: gp.Coordinate = self.get_extents(array_name=array.identifier)
                 request.add(
                     array, self.common_voxel_size * extents, self.common_voxel_size
                 )
@@ -263,7 +313,7 @@
 
 
 if __name__ == "__main__":
-    system = CycleGAN(config="./train_conf.json")
+    system: CycleGAN = CycleGAN(config="./train_conf.json")
     system.logger.info("CycleGAN system loaded. Training...")
-    _ = system.train()
+    _: None = system.train()
     system.logger.info("Done training!")

From a43a37c9aed45a4e5a545a6c662ed06645c2cb53 Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Wed, 3 May 2023 14:29:44 -0400
Subject: [PATCH 03/19] CycleModel docstrings

---
 src/raygun/torch/models/CycleModel.py | 101 ++++++++++++++++++++------
 src/raygun/torch/systems/CycleGAN.py  |   5 +-
 src/raygun/utils.py                   |   4 +-
 3 files changed, 86 insertions(+), 24 deletions(-)

diff --git a/src/raygun/torch/models/CycleModel.py b/src/raygun/torch/models/CycleModel.py
index 0fb888fc..ddd39eda 100644
--- a/src/raygun/torch/models/CycleModel.py
+++ b/src/raygun/torch/models/CycleModel.py
@@ -1,43 +1,102 @@
+import torch
 from raygun.torch.models import FreezableModel
 from raygun.utils import passing_locals
 import torch.nn.functional as F
-import torch
 
 
 class CycleModel(FreezableModel):
+    """A CycleGAN model for image-to-image translation.
+
+    Args:
+        netG1 (``torch.nn.Module``):
+            A generator network that maps from domain A to domain B.
+
+        netG2 (``torch.nn.Module``):
+            A generator network that maps from domain B to domain A.
+
+        scale_factor_A (``tuple[int]``, optional):
+            The downsampling factor for domain A images.
+
+        scale_factor_B (``tuple[int]``, optional):
+            The downsampling factor for domain B images.
+
+        split (``bool``, optional):
+            Whether to split the cycle loss into two parts (forward and backward).
+
+        **kwargs: Additional arguments to be passed to the superclass constructor.
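+
+    Example (an illustrative sketch only; any paired ``torch.nn.Module`` generators
+    work, and the single convolutions below are stand-ins for real generator networks):
+
+        >>> import torch
+        >>> netG1 = torch.nn.Conv2d(1, 1, 3, padding=1)  # maps domain A -> B
+        >>> netG2 = torch.nn.Conv2d(1, 1, 3, padding=1)  # maps domain B -> A
+        >>> model = CycleModel(netG1, netG2)
+        >>> outputs = model(real_A=torch.rand(1, 1, 64, 64))  # fake_B, cycled_B, fake_A, cycled_A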
+ """ def __init__( - self, - netG1, - netG2, - scale_factor_A=None, - scale_factor_B=None, - split=False, - **kwargs - ): - output_arrays = ["fake_B", "cycled_B", "fake_A", "cycled_A"] - nets = [netG1, netG2] + self, + netG1, + netG2, + scale_factor_A=None, + scale_factor_B=None, + split=False, + **kwargs + ) -> None: + output_arrays: list[str] = ["fake_B", "cycled_B", "fake_A", "cycled_A"] + nets:list = [netG1, netG2] super().__init__(**passing_locals(locals())) - self.cycle = True - self.crop_pad = None # TODO: Determine if this is depracated + self.cycle:bool = True + self.crop_pad:tuple = None # TODO: Determine if this is depreciated + + def sampling_bottleneck(self, array:torch.Tensor, scale_factor:tuple) -> torch.Tensor: + """Performs sampling bottleneck operation on the input tensor to avoid checkerboard artifacts. - def sampling_bottleneck(self, array, scale_factor): - size = array.shape[-len(scale_factor) :] - mode = {2: "bilinear", 3: "trilinear"}[len(size)] + Args: + array (``torch.Tensor``): + A tensor of shape (batch_size, channels, height, width) or + (batch_size, channels, depth, height, width) depending on the dimensions of the input data. + scale_factor (``tuple[int]``): + A tuple of scale factor for downsampling and upsampling the tensor. + + Returns: + ``torch.Tensor``: + A tensor of the same shape as the input tensor with applied sampling bottleneck operation. + """ + size:torch.Size = array.shape[-len(scale_factor) :] + mode: str = {2: "bilinear", 3: "trilinear"}[len(size)] down = F.interpolate( array, scale_factor=scale_factor, mode=mode, align_corners=True ) return F.interpolate(down, size=size, mode=mode, align_corners=True) - def set_crop_pad(self, crop_pad, ndims): - self.crop_pad = (slice(None, None, None),) * 2 + ( + def set_crop_pad(self, crop_pad:int, ndims:int) -> None: + """Set crop pad for the model. + + Args: + crop_pad (``integer``): + The amount to crop the input by. + ndims (``integer``): + The number of dimensions to apply crop pad. + """ + self.crop_pad:tuple = (slice(None, None, None),) * 2 + ( slice(crop_pad, -crop_pad), ) * ndims - def forward(self, real_A=None, real_B=None): + def forward(self, real_A=None, real_B=None) -> tuple: + """Forward pass of the CycleGAN model. + + Args: + real_A (``torch.Tensor``, optional): + Input tensor for domain A. Default: None. + real_B (``torch.Tensor``, optional): + Input tensor for domain B. Default: None. + + Returns: + Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + A tuple containing: + - fake_B (Tensor or None): The generated output for domain B. None if real_A is None. + - cycled_B (Tensor or None): The generated output for domain A from the fake_B image. None if real_A is None or cycle is False. + - fake_A (Tensor or None): The generated output for domain A. None if real_B is None. + - cycled_A (Tensor or None): The generated output for domain B from the fake_A image. None if real_B is None or cycle is False. + """ + assert ( real_A is not None or real_B is not None ), "Must have some real input to generate outputs)" + if ( real_A is not None ): # allow calling for single direction pass (i.e. 
             fake_B = self.netG1(real_A)
             if self.crop_pad is not None:
                 fake_B = fake_B[self.crop_pad]
             if self.scale_factor_B:
-                fake_B = self.sampling_bottleneck(
+                fake_B: torch.Tensor = self.sampling_bottleneck(
                     fake_B, self.scale_factor_B
                 )  # apply sampling bottleneck
             if self.cycle:
@@ -68,7 +127,7 @@
             if self.crop_pad is not None:
                 fake_A = fake_A[self.crop_pad]
             if self.scale_factor_A:
-                fake_A = self.sampling_bottleneck(
+                fake_A: torch.Tensor = self.sampling_bottleneck(
                     fake_A, self.scale_factor_A
                 )  # apply sampling bottleneck
             if self.cycle:
diff --git a/src/raygun/torch/systems/CycleGAN.py b/src/raygun/torch/systems/CycleGAN.py
index 2cb40181..e0e19585 100644
--- a/src/raygun/torch/systems/CycleGAN.py
+++ b/src/raygun/torch/systems/CycleGAN.py
@@ -19,7 +19,10 @@
 
 
 class CycleGAN(BaseSystem):
-    """CycleGAN implementation of :class:`raygun.torch.systems.BaseSystem`.
+    """Implementation of a CycleGAN system for image-to-image translation using PyTorch.
+
+    This class extends the `BaseSystem` class and implements the training and inference
+    pipelines for a CycleGAN model.
 
     Args:
         config (``string``, optional):
diff --git a/src/raygun/utils.py b/src/raygun/utils.py
index c5c162b8..8dd8f1c9 100644
--- a/src/raygun/utils.py
+++ b/src/raygun/utils.py
@@ -7,8 +7,8 @@
 import gunpowder as gp
 
 
-def passing_locals(local_dict):
-    kwargs = {}
+def passing_locals(local_dict) -> dict:
+    kwargs: dict = {}
     for k, v in local_dict.items():
         if k[0] != "_" and k != "self":
             if k == "kwargs":

From 80df2c021ed7611a76043a4a0661d7f96d9981b8 Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Wed, 3 May 2023 14:48:36 -0400
Subject: [PATCH 04/19] Utils file docstrings

---
 src/raygun/utils.py | 120 +++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 107 insertions(+), 13 deletions(-)

diff --git a/src/raygun/utils.py b/src/raygun/utils.py
index 8dd8f1c9..b7def66e 100644
--- a/src/raygun/utils.py
+++ b/src/raygun/utils.py
@@ -7,7 +7,18 @@
 import gunpowder as gp
 
 
-def passing_locals(local_dict) -> dict:
+def passing_locals(local_dict: dict) -> dict:
+    """Extracts the local variables from a given dictionary and returns them as a keyword argument dictionary.
+
+    Args:
+        local_dict (``dict``):
+            A dictionary containing local variables.
+
+    Returns:
+        ``dict``:
+            A dictionary containing the extracted keyword arguments.
+    """
+
     kwargs: dict = {}
     for k, v in local_dict.items():
         if k[0] != "_" and k != "self":
@@ -18,10 +29,23 @@
     return kwargs
 
 
-def get_config_name(config_path, base_folder):
-    config_name = os.path.dirname(config_path)
-    config_name = config_name.replace(base_folder, "")
-    config_name = "_".join(config_name.split("/"))[1:]
+def get_config_name(config_path: str, base_folder: str) -> str:
+    """Returns a string containing the configuration name, given a configuration file path and a base folder.
+
+    Args:
+        config_path (``string``):
+            Path of the configuration file.
+        base_folder (``string``):
+            Base folder of the configuration file.
+
+    Returns:
+        ``string``:
+            Configuration name.
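+
+    Example (illustrative; the paths are hypothetical)::
+
+        >>> get_config_name("ieee-isbi-2023/01_cycleGAN/train_conf.json", "ieee-isbi-2023")
+        '01_cycleGAN'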
+ """ + + config_name: str = os.path.dirname(config_path) + config_name: str = config_name.replace(base_folder, "") + config_name: str = "_".join(config_name.split("/"))[1:] return config_name @@ -29,13 +53,35 @@ def get_config_name(config_path, base_folder): def calc_max_padding( output_size, voxel_size, neighborhood=None, sigma=None, mode="shrink" ): + """Calculate the maximum padding for an output size given the voxel size and optional parameters. + + Args: + output_size (Tuple[int, int, int]): + The size of the output. + + voxel_size (Tuple[float, float, float]): + The size of the voxels. + + neighborhood (Tuple[Tuple[int, int, int], ...]], optional): + A tuple of 3x3x3 neighborhood values. + + sigma (``float``, optional): + The sigma value for Gaussian padding. + + mode (``string``, optional): + The mode for snapping the output to the grid. + + Returns: + Tuple[int, int, int]: + The maximum padding for the output size. + """ if neighborhood is not None: if len(neighborhood) > 3: neighborhood = neighborhood[9:12] - max_affinity = gp.Coordinate( + max_affinity: gp.Coordinate = gp.Coordinate( [np.abs(aff) for val in neighborhood for aff in val if aff != 0] ) @@ -43,11 +89,11 @@ def calc_max_padding( if sigma: - method_padding = gp.Coordinate((sigma * 3,) * 3) + method_padding: gp.Coordinate = gp.Coordinate((sigma * 3,) * 3) diag = np.sqrt(output_size[1] ** 2 + output_size[2] ** 2) - max_padding = gp.Roi( + max_padding: gp.Roi = gp.Roi( (gp.Coordinate([i / 2 for i in [output_size[0], diag, diag]]) + method_padding), (0,) * 3, ).snap_to_grid(voxel_size, mode=mode) @@ -56,8 +102,19 @@ def calc_max_padding( def serialize(obj): + """Serialize a Python object into a JSON-compatible format. + + Args: + obj (``dict``, ``np.ndarray``, ``np.int64``, ``obj``, other): + A Python object to be serialized. + + Returns: + (``integer``, ``string``, ``dict``) + The serialized object in a JSON-compatible format. + """ + if isinstance(obj, dict): - out = {} + out: dict = {} for key, value in obj.items(): out[key] = serialize(value) return out @@ -75,20 +132,57 @@ def serialize(obj): return f"#{repr(obj)}#" -def to_json(obj, file, indent=3): +def to_json(obj, file:str, indent=3) -> None: + """Serializes the given object to JSON and writes it to a file. + + Args: + obj (``dict``, ``np.ndarray``, ``np.int64``, ``obj``, other): + The object to serialize to JSON. + + file (``string``): + The name of the file to write the serialized JSON to. + + indent (``integer``): + The number of spaces to use for indentation in the JSON file. + """ + out = serialize(obj) with open(file, "w") as f: - json.dump(out, f, indent=3) + json.dump(out, f, indent=indent) + + +def load_json_file(fin:str) -> dict: + """Loads a JSON file from disk and parses it into a Python dictionary. + Args: + fin (``string``): + The name of the file to load the JSON data from. + + Returns: + ``dict``: + A Python dictionary containing the parsed JSON data. + """ -def load_json_file(fin): with open(fin, "r") as f: config = json.load(StringIO(jsmin(f.read()))) return config -def merge_dicts(from_dict, to_dict): +def merge_dicts(from_dict:dict, to_dict:dict) -> dict: + """Merges two dictionaries together, with keys in the `from_dict` dictionary taking + precedence over keys in the `to_dict` dictionary. + + Args: + from_dict (``dict``): + The dictionary to merge into `to_dict`. + to_dict (``dict``): + The dictionary to merge `from_dict` into. + Returns: + ``dict``: + A new dictionary containing the merged data. 
+ """ + # merge first level for k in from_dict: if k not in to_dict: From ebe8066922ee4c65c60085217b1fefb7a8da8506 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Wed, 3 May 2023 16:50:41 -0400 Subject: [PATCH 05/19] Partial LinkCycleLoss docs --- src/raygun/torch/losses/LinkCycleLoss.py | 87 ++++++++++++++++++++---- 1 file changed, 73 insertions(+), 14 deletions(-) diff --git a/src/raygun/torch/losses/LinkCycleLoss.py b/src/raygun/torch/losses/LinkCycleLoss.py index 6f44a85e..a502933f 100644 --- a/src/raygun/torch/losses/LinkCycleLoss.py +++ b/src/raygun/torch/losses/LinkCycleLoss.py @@ -4,21 +4,80 @@ import logging -logger = logging.Logger(__name__, "INFO") +logger: logging.Logger = logging.Logger(__name__, "INFO") class LinkCycleLoss(BaseCompetentLoss): - """CycleGAN loss function""" + """Linked CycleGAN loss function, implemented in PyTorch. + Args: + netD1 (``nn.Module``): + A discriminator module that differentiates between fake and real ``B``s. + + netG1 (``nn.Module``): + A generator module that turns ``A``s into ``B``s. + + netD2 (``nn.Module``): + A discriminator module that differentiates between fake and real ``A``s. + + netG2 (``nn.Module``): + A generator module that turns ``B``s into ``A``s. + + optimizer_G (``optim.Optimizer``): + An instance of PyTorch optimizer to optimize the generator modules. + + optimizer_D (``optim.Optimizer``): + An instance of PyTorch optimizer to optimize the discriminator modules. + + dims (``int``): + Number of dimensions of image tensor, typically ``2`` for grayscale and ``3`` for RGB. + + l1_loss (``callable``, optional): + A callable loss function. Default is ``torch.nn.SmoothL1Loss()``. + + g_lambda_dict (``dict``, optional): + A dictionary with keys ``A`` and ``B``, each with a dictionary of keys ``l1_loss`` and ``gan_loss``. + The value of ``l1_loss`` is itself a dictionary with keys ``cycled`` and ``identity``, and the value + of ``gan_loss`` is a dictionary with keys ``fake`` and ``cycled``. The values of these keys correspond + to the weights for the corresponding losses. Default is as follows: + ``` + { + "A": { + "l1_loss": {"cycled": 10, "identity": 0}, + "gan_loss": {"fake": 1, "cycled": 0}, + }, + "B": { + "l1_loss": {"cycled": 10, "identity": 0}, + "gan_loss": {"fake": 1, "cycled": 0}, + }, + } + ``` + + d_lambda_dict (``dict``, optional): + A dictionary with keys ``A`` and ``B``, each with a dictionary of keys ``real``, ``fake``, and ``cycled``. + The values of these keys correspond to the weights for the corresponding losses. Default is as follows: + ``` + { + "A": {"real": 1, "fake": 1, "cycled": 0}, + "B": {"real": 1, "fake": 1, "cycled": 0}, + } + ``` + + gan_mode (``str``, optional): + The type of GAN loss to use. Options are ``lsgan`` and ``wgangp``. Default is ``lsgan``. + + **kwargs: + Optional keyword arguments. 
+ """ def __init__( self, - netD1, # differentiates between fake and real Bs - netG1, # turns As into Bs - netD2, # differentiates between fake and real As - netG2, # turns Bs into As - optimizer_G, - optimizer_D, - dims, + netD1:torch.nn.Module, # differentiates between fake and real Bs + netG1:torch.nn.Module, # turns As into Bs + netD2:torch.nn.Module, # differentiates between fake and real As + netG2:torch.nn.Module, # turns Bs into As + optimizer_G:torch.optim.Optimizer, + optimizer_D:torch.optim.Optimizer, + dims:int, l1_loss=torch.nn.SmoothL1Loss(), g_lambda_dict={ "A": { @@ -36,13 +95,13 @@ def __init__( }, gan_mode="lsgan", **kwargs, - ): + ) -> None: super().__init__(**passing_locals(locals())) - self.data_dict = {} + self.data_dict:dict = {} - def backward_D(self, side, dnet, data_dict): + def backward_D(self, side, dnet, data_dict) -> float: """Calculate losses for a discriminator""" - loss = 0 + loss:float = 0. for key, lambda_ in self.d_lambda_dict[side].items(): if lambda_ != 0: # if key == 'identity': # TODO: ADD IDENTITY SUPPORT @@ -58,7 +117,7 @@ def backward_D(self, side, dnet, data_dict): loss.backward() return loss - def backward_Ds(self, data_dict, n_loop=5): + def backward_Ds(self, data_dict, n_loop=5) -> tuple: self.set_requires_grad( [self.netG1, self.netG2], False ) # G does not require gradients when optimizing D From 9b98317ce2ee4a80165f7348fee7de138459c7f5 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Fri, 5 May 2023 14:34:52 -0400 Subject: [PATCH 06/19] LinkCycleLoss documentation and type checking --- src/raygun/torch/losses/LinkCycleLoss.py | 122 +++++++++++++++++++---- 1 file changed, 104 insertions(+), 18 deletions(-) diff --git a/src/raygun/torch/losses/LinkCycleLoss.py b/src/raygun/torch/losses/LinkCycleLoss.py index a502933f..4aa07ce4 100644 --- a/src/raygun/torch/losses/LinkCycleLoss.py +++ b/src/raygun/torch/losses/LinkCycleLoss.py @@ -99,8 +99,24 @@ def __init__( super().__init__(**passing_locals(locals())) self.data_dict:dict = {} - def backward_D(self, side, dnet, data_dict) -> float: - """Calculate losses for a discriminator""" + def backward_D(self, side:str, dnet: torch.nn.Module, data_dict:dict) -> float: + """Calculate losses for a discriminator. + Args: + side (``nn.Module``): + The side of interest of the CycleGAN (A or B type data). + + dnet (``nn.Module``): + The discriminator to backpropagate upon & calculate losses. + + data_dict (``dict``): + The training data dictionary with labels for A/B data. + + Returns: + ``float``: + The calcualted discriminator loss. + + """ + loss:float = 0. for key, lambda_ in self.d_lambda_dict[side].items(): if lambda_ != 0: @@ -117,7 +133,21 @@ def backward_D(self, side, dnet, data_dict) -> float: loss.backward() return loss - def backward_Ds(self, data_dict, n_loop=5) -> tuple: + def backward_Ds(self, data_dict, n_loop=5) -> tuple(float, float): + """Calculate losses for all discriminators in the system. + + Args: + data_dict (``dict``): + The training data dictionary with labels for A/B data. + + n_loop (``integer``): + Number of loop iterations. + + Returns: + ``tuple(float, float)``: + The calcualted losses for each discriminator. 
+ """ + self.set_requires_grad( [self.netG1, self.netG2], False ) # G does not require gradients when optimizing D @@ -126,14 +156,14 @@ def backward_Ds(self, data_dict, n_loop=5) -> tuple: if self.gan_mode.lower() == "wgangp": # Wasserstein Loss for _ in range(n_loop): - loss_D1 = self.backward_D("B", self.netD1, data_dict["B"]) - loss_D2 = self.backward_D("A", self.netD2, data_dict["A"]) + loss_D1: float = self.backward_D("B", self.netD1, data_dict["B"]) + loss_D2: float = self.backward_D("A", self.netD2, data_dict["A"]) self.optimizer_D.step() # update D's weights self.clamp_weights(self.netD1) self.clamp_weights(self.netD2) else: - loss_D1 = self.backward_D("B", self.netD1, data_dict["B"]) - loss_D2 = self.backward_D("A", self.netD2, data_dict["A"]) + loss_D1: float = self.backward_D("B", self.netD1, data_dict["B"]) + loss_D2: float = self.backward_D("A", self.netD2, data_dict["A"]) self.optimizer_D.step() # update D's weights self.set_requires_grad( @@ -143,9 +173,28 @@ def backward_Ds(self, data_dict, n_loop=5) -> tuple: # return losses return loss_D1, loss_D2 - def backward_G(self, side, gnet, dnet, data_dict): - """Calculate losses for a generator""" - loss = 0 + def backward_G(self, side:str, gnet:torch.nn.Module, dnet:torch.nn.Module, data_dict:dict) -> float: + """Calculate losses for a generator. + Args: + side (``nn.Module``): + The side of interest of the CycleGAN (A or B type data). + + gnet (``nn.Module``): + The generator to backpropagate upon & calculate losses. + + dnet (``nn.Module``): + The discriminator to refrence for loss calculations. + + data_dict (``dict``): + The training data dictionary with labels for A/B data. + + Returns: + ``float``: + The calcualted generator loss. + + """ + + loss: float = 0. real = data_dict["real"] for fcn_name, lambdas in self.g_lambda_dict[side].items(): loss_fcn = getattr(self, fcn_name) @@ -172,16 +221,27 @@ def backward_G(self, side, gnet, dnet, data_dict): loss.backward(retain_graph=True) return loss - def backward_Gs(self, data_dict): + def backward_Gs(self, data_dict) -> tuple(float, float): + """Calculate losses for all generators in the system. + + Args: + data_dict (``dict``): + The training data dictionary with labels for A/B data. + + Returns: + ``tuple(float, float)``: + The calcualted losses for each generator. + """ + self.set_requires_grad( [self.netD1, self.netD2], False ) # D requires no gradients when optimizing G self.optimizer_G.zero_grad(set_to_none=True) # set G1's gradients to zero - loss_G1 = self.backward_G( + loss_G1: float = self.backward_G( "B", self.netG1, self.netD1, data_dict["B"] ) # calculate gradient for G - loss_G2 = self.backward_G( + loss_G2: float = self.backward_G( "A", self.netG2, self.netD2, data_dict["A"] ) # calculate gradient for G self.optimizer_G.step() # udpate G1's weights @@ -193,7 +253,33 @@ def backward_Gs(self, data_dict): # return losses return loss_G1, loss_G2 - def forward(self, real_A, fake_A, cycled_A, real_B, fake_B, cycled_B): + def forward(self, real_A:torch.Tensor, fake_A:torch.Tensor, cycled_A:torch.Tensor, real_B:torch.Tensor, fake_B:torch.Tensor, cycled_B:torch.Tensor) -> float: + """Forward pass for the Linked CycleGAN system. + + Args: + real_A (``torch.Tensor``): + A-style training data. + + fake_A (``torch.Tensor``): + A-style generated data. + + cycled_A (``torch.Tensor``): + A-style recycled generated data. + + real_B (``torch.Tensor``): + B-style training data + + fake_B (``torch.Tensor``): + B-style generated data. 
+
+            cycled_B (``torch.Tensor``):
+                B-style recycled generated data.
+
+        Returns:
+            ``float``:
+                The total loss for the system.
+        """
+
         self.data_dict.update({
             "real_A": real_A,
             "fake_A": fake_A,
             "cycled_A": cycled_A,
             "real_B": real_B,
             "fake_B": fake_B,
             "cycled_B": cycled_B,
         })
 
         # crop if necessary
         if real_A.size()[-self.dims :] != fake_B.size()[-self.dims :]:
-            real_A = self.crop(real_A, fake_A.size()[-self.dims :])
-            real_B = self.crop(real_B, fake_B.size()[-self.dims :])
+            real_A: torch.Tensor = self.crop(real_A, fake_A.size()[-self.dims :])
+            real_B: torch.Tensor = self.crop(real_B, fake_B.size()[-self.dims :])
 
-        data_dict = {
+        data_dict: dict = {
             "A": {"real": real_A, "fake": fake_A, "cycled": cycled_A},
             "B": {"real": real_B, "fake": fake_B, "cycled": cycled_B},
         }
 
             }
         )
 
-        total_loss = loss_G1 + loss_G2 + loss_D1 + loss_D2
+        total_loss: float = loss_G1 + loss_G2 + loss_D1 + loss_D2
 
         # define dummy backward pass to disable Gunpowder's Train node loss.backward() call
         total_loss.backward = lambda: None

From b62e6e07a0fee0562b990a14c31fd75e6e4ffa0b Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Fri, 5 May 2023 14:48:58 -0400
Subject: [PATCH 07/19] SplitCycleLoss docs and type checking

---
 src/raygun/torch/losses/LinkCycleLoss.py  |   9 +-
 src/raygun/torch/losses/SplitCycleLoss.py | 192 +++++++++++++++++++---
 2 files changed, 170 insertions(+), 31 deletions(-)

diff --git a/src/raygun/torch/losses/LinkCycleLoss.py b/src/raygun/torch/losses/LinkCycleLoss.py
index 4aa07ce4..04df112d 100644
--- a/src/raygun/torch/losses/LinkCycleLoss.py
+++ b/src/raygun/torch/losses/LinkCycleLoss.py
@@ -69,6 +69,7 @@
         **kwargs:
             Optional keyword arguments.
     """
+
     def __init__(
         self,
@@ -125,7 +126,7 @@
 
-                this_loss = self.gan_loss(dnet(data_dict[key].detach()), key == "real")
+                this_loss: float = self.gan_loss(dnet(data_dict[key].detach()), key == "real")
 
diff --git a/src/raygun/torch/losses/SplitCycleLoss.py b/src/raygun/torch/losses/SplitCycleLoss.py
index 5c10eef8..dc01f4a7 100644
--- a/src/raygun/torch/losses/SplitCycleLoss.py
+++ b/src/raygun/torch/losses/SplitCycleLoss.py
@@ -4,11 +4,74 @@
 import logging
 
-logger = logging.Logger(__name__, "INFO")
+logger: logging.Logger = logging.Logger(__name__, "INFO")
 
 
 class SplitCycleLoss(BaseCompetentLoss):
-    """CycleGAN loss function"""
+    """Split CycleGAN loss function, implemented in PyTorch.
+
+    Args:
+        netD1 (``nn.Module``):
+            A discriminator module that differentiates between fake and real ``B``s.
+
+        netG1 (``nn.Module``):
+            A generator module that turns ``A``s into ``B``s.
+
+        netD2 (``nn.Module``):
+            A discriminator module that differentiates between fake and real ``A``s.
+
+        netG2 (``nn.Module``):
+            A generator module that turns ``B``s into ``A``s.
+
+        optimizer_G1 (``optim.Optimizer``):
+            An instance of a PyTorch optimizer to optimize the G1 module.
+
+        optimizer_G2 (``optim.Optimizer``):
+            An instance of a PyTorch optimizer to optimize the G2 module.
+
+        optimizer_D (``optim.Optimizer``):
+            An instance of a PyTorch optimizer to optimize the discriminator modules.
+
+        dims (``int``):
+            Number of spatial dimensions of the image tensors, typically ``2`` for 2D images and ``3`` for volumetric data.
+
+        l1_loss (``callable``, optional):
+            A callable loss function. Default is ``torch.nn.SmoothL1Loss()``.
+
+        g_lambda_dict (``dict``, optional):
+            A dictionary with keys ``A`` and ``B``, each with a dictionary of keys ``l1_loss`` and ``gan_loss``.
+            The value of ``l1_loss`` is itself a dictionary with keys ``cycled`` and ``identity``, and the value
+            of ``gan_loss`` is a dictionary with keys ``fake`` and ``cycled``. The values of these keys correspond
+            to the weights for the corresponding losses. Default is as follows:
+            ```
+            {
+                "A": {
+                    "l1_loss": {"cycled": 10, "identity": 0},
+                    "gan_loss": {"fake": 1, "cycled": 0},
+                },
+                "B": {
+                    "l1_loss": {"cycled": 10, "identity": 0},
+                    "gan_loss": {"fake": 1, "cycled": 0},
+                },
+            }
+            ```
+
+        d_lambda_dict (``dict``, optional):
+            A dictionary with keys ``A`` and ``B``, each with a dictionary of keys ``real``, ``fake``, and ``cycled``.
+            The values of these keys correspond to the weights for the corresponding losses. Default is as follows:
+            ```
+            {
+                "A": {"real": 1, "fake": 1, "cycled": 0},
+                "B": {"real": 1, "fake": 1, "cycled": 0},
+            }
+            ```
+
+        gan_mode (``str``, optional):
+            The type of GAN loss to use. Options are ``lsgan`` and ``wgangp``. Default is ``lsgan``.
+
+        **kwargs:
+            Optional keyword arguments.
+    """
 
     def __init__(
         self,
@@ -37,13 +100,29 @@
         },
         gan_mode="lsgan",
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**passing_locals(locals()))
-        self.data_dict = {}
+        self.data_dict: dict = {}
 
-    def backward_D(self, side, dnet, data_dict):
-        """Calculate losses for a discriminator"""
-        loss = 0
+    def backward_D(self, side: str, dnet: torch.nn.Module, data_dict: dict) -> float:
+        """Calculate losses for a discriminator.
+
+        Args:
+            side (``string``):
+                The side of interest of the CycleGAN ("A" or "B" type data).
+
+            dnet (``nn.Module``):
+                The discriminator to backpropagate upon & calculate losses.
+
+            data_dict (``dict``):
+                The training data dictionary with labels for A/B data.
+
+        Returns:
+            ``float``:
+                The calculated discriminator loss.
+        """
+
+        loss: float = 0.0
         for key, lambda_ in self.d_lambda_dict[side].items():
             if lambda_ != 0:
                 # if key == 'identity':  # TODO: ADD IDENTITY SUPPORT
                 # else:
                 #     pred = data_dict[key]
 
-                this_loss = self.gan_loss(dnet(data_dict[key].detach()), key == "real")
+                this_loss: float = self.gan_loss(dnet(data_dict[key].detach()), key == "real")
 
                 self.loss_dict.update({f"Discriminator_{side}/{key}": this_loss})
                 loss += lambda_ * this_loss
 
         loss.backward()
         return loss
 
-    def backward_Ds(self, data_dict, n_loop=5):
+    def backward_Ds(self, data_dict: dict, n_loop=5) -> tuple[float, float]:
+        """Calculate losses for all discriminators in the system.
+
+        Args:
+            data_dict (``dict``):
+                The training data dictionary with labels for A/B data.
+
+            n_loop (``integer``):
+                Number of loop iterations.
+
+        Returns:
+            ``tuple[float, float]``:
+                The calculated losses for each discriminator.
+        """
+
         self.set_requires_grad(
             [self.netG1, self.netG2], False
         )  # G does not require gradients when optimizing D
 
         if self.gan_mode.lower() == "wgangp":  # Wasserstein Loss
             for _ in range(n_loop):
-                loss_D1 = self.backward_D("B", self.netD1, data_dict["B"])
-                loss_D2 = self.backward_D("A", self.netD2, data_dict["A"])
+                loss_D1: float = self.backward_D("B", self.netD1, data_dict["B"])
+                loss_D2: float = self.backward_D("A", self.netD2, data_dict["A"])
                 self.optimizer_D.step()  # update D's weights
                 self.clamp_weights(self.netD1)
                 self.clamp_weights(self.netD2)
         else:
-            loss_D1 = self.backward_D("B", self.netD1, data_dict["B"])
-            loss_D2 = self.backward_D("A", self.netD2, data_dict["A"])
+            loss_D1: float = self.backward_D("B", self.netD1, data_dict["B"])
+            loss_D2: float = self.backward_D("A", self.netD2, data_dict["A"])
             self.optimizer_D.step()  # update D's weights
 
         self.set_requires_grad(
             [self.netG1, self.netG2], True
         )
 
         # return losses
         return loss_D1, loss_D2
 
-    def backward_G(self, side, gnet, dnet, data_dict):
-        """Calculate losses for a generator"""
-        loss = 0
+    def backward_G(self, side: str, gnet: torch.nn.Module, dnet: torch.nn.Module, data_dict: dict) -> float:
+        """Calculate losses for a generator.
+
+        Args:
+            side (``string``):
+                The side of interest of the CycleGAN ("A" or "B" type data).
+
+            gnet (``nn.Module``):
+                The generator to backpropagate upon & calculate losses.
+
+            dnet (``nn.Module``):
+                The discriminator to reference for loss calculations.
+
+            data_dict (``dict``):
+                The training data dictionary with labels for A/B data.
+
+        Returns:
+            ``float``:
+                The calculated generator loss.
+        """
+
+        loss: float = 0.0
         real = data_dict["real"]
         for fcn_name, lambdas in self.g_lambda_dict[side].items():
             loss_fcn = getattr(self, fcn_name)
@@ -98,13 +210,13 @@
             if fcn_name == "l1_loss":
                 if real.size()[-self.dims :] != pred.size()[-self.dims :]:
-                    this_loss = loss_fcn(
+                    this_loss: float = loss_fcn(
                         self.crop(real, pred.size()[-self.dims :]), pred
                     )
                 else:
-                    this_loss = loss_fcn(real, pred)
+                    this_loss: float = loss_fcn(real, pred)
             elif fcn_name == "gan_loss":
-                this_loss = loss_fcn(dnet(pred), True)
+                this_loss: float = loss_fcn(dnet(pred), True)
 
             self.loss_dict.update({f"{fcn_name}/{key}_{side}": this_loss})
             loss += lambda_ * this_loss
 
         loss.backward()
         return loss
 
-    def backward_Gs(self, data_dict):
+    def backward_Gs(self, data_dict: dict) -> tuple[float, float]:
         self.set_requires_grad(
             [self.netD1, self.netD2], False
         )  # D requires no gradients when optimizing G
 
         self.set_requires_grad(
             [self.netG2], False
         )  # G2 requires no gradients when optimizing G1
         self.optimizer_G1.zero_grad(set_to_none=True)  # set G1's gradients to zero
-        loss_G1 = self.backward_G(
+        loss_G1: float = self.backward_G(
             "B", self.netG1, self.netD1, data_dict["B"]
         )  # calculate gradient for G
         self.optimizer_G1.step()  # update G1's weights
 
         self.set_requires_grad(
             [self.netG1], False
         )  # G1 requires no gradients when optimizing G2
         self.optimizer_G2.zero_grad(set_to_none=True)  # set G2's gradients to zero
-        loss_G2 = self.backward_G(
+        loss_G2: float = self.backward_G(
             "A", self.netG2, self.netD2, data_dict["A"]
         )  # calculate gradient for G
         self.optimizer_G2.step()  # update G2's weights
 
         # return losses
         return loss_G1, loss_G2
 
-    def forward(self, real_A, fake_A, cycled_A, real_B, fake_B, cycled_B):
+    def forward(self, real_A, fake_A, cycled_A, real_B, fake_B, cycled_B) -> float:
+        """Forward pass for the Split CycleGAN system.
+
+        Args:
+            real_A (``torch.Tensor``):
+                A-style training data.
+
+            fake_A (``torch.Tensor``):
+                A-style generated data.
+
+            cycled_A (``torch.Tensor``):
+                A-style recycled generated data.
+
+            real_B (``torch.Tensor``):
+                B-style training data.
+
+            fake_B (``torch.Tensor``):
+                B-style generated data.
+
+            cycled_B (``torch.Tensor``):
+                B-style recycled generated data.
+
+        Returns:
+            ``float``:
+                The total loss for the system.
+ """ + self.data_dict.update( { "real_A": real_A, @@ -164,10 +302,10 @@ def forward(self, real_A, fake_A, cycled_A, real_B, fake_B, cycled_B): # crop if necessary if real_A.size()[-self.dims :] != fake_B.size()[-self.dims :]: - real_A = self.crop(real_A, fake_A.size()[-self.dims :]) - real_B = self.crop(real_B, fake_B.size()[-self.dims :]) + real_A: torch.Tensor = self.crop(real_A, fake_A.size()[-self.dims :]) + real_B: torch.Tensor = self.crop(real_B, fake_B.size()[-self.dims :]) - data_dict = { + data_dict: dict = { "A": {"real": real_A, "fake": fake_A, "cycled": cycled_A}, "B": {"real": real_B, "fake": fake_B, "cycled": cycled_B}, } @@ -186,7 +324,7 @@ def forward(self, real_A, fake_A, cycled_A, real_B, fake_B, cycled_B): } ) - total_loss = loss_G1 + loss_G2 + loss_D1 + loss_D2 + total_loss: float = loss_G1 + loss_G2 + loss_D1 + loss_D2 # define dummy backward pass to disable Gunpowder's Train node loss.backward() call total_loss.backward = lambda: None From a3d0ade9120012829af176d822daca709e930aad Mon Sep 17 00:00:00 2001 From: brianreicher Date: Fri, 5 May 2023 15:13:37 -0400 Subject: [PATCH 08/19] BaseCompetentLoss docs and type checking (initial) --- src/raygun/torch/losses/BaseCompetentLoss.py | 100 ++++++++++++++----- 1 file changed, 74 insertions(+), 26 deletions(-) diff --git a/src/raygun/torch/losses/BaseCompetentLoss.py b/src/raygun/torch/losses/BaseCompetentLoss.py index 47e60ef7..a4ead5ff 100644 --- a/src/raygun/torch/losses/BaseCompetentLoss.py +++ b/src/raygun/torch/losses/BaseCompetentLoss.py @@ -6,48 +6,96 @@ class BaseCompetentLoss(torch.nn.Module): - def __init__(self, **kwargs): - super().__init__() - kwargs = passing_locals(locals()) - for key, value in kwargs.items(): - setattr(self, key, value) - - if hasattr(self, "gan_mode"): - self.gan_loss = GANLoss(gan_mode=self.gan_mode) - - self.loss_dict = {} - - def set_requires_grad(self, nets, requires_grad=False): - """Set requies_grad=False for all the networks to avoid unnecessary computations - Parameters: - nets (network list) -- a list of networks - requires_grad (bool) -- whether the networks require gradients or not + """Base loss function, implemented in PyTorch. + + Args: + **kwargs: + Optional keyword arguments. + """ + + def __init__(self, **kwargs) -> None: + super().__init__() + kwargs: dict = passing_locals(locals()) + for key, value in kwargs.items(): + setattr(self, key, value) + + if hasattr(self, "gan_mode"): + self.gan_loss: GANLoss = GANLoss(gan_mode=self.gan_mode) + + self.loss_dict: dict = {} + + def set_requires_grad(self, nets:list, requires_grad=False) -> None: + """Sets requies_grad=False for all the networks to avoid unnecessary computations. + + Args: + nets (``list[torch.nn.Module, ...]``): + A list of networks. + + requires_grad (``bool``): + Whether the networks require gradients or not. """ - if not isinstance(nets, list): - nets = [nets] + + if not isinstance(nets, list): # TODO: remove, this should be redudant with type checking enforced + nets: list = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad - def crop(self, x, shape): - """Center-crop x to match spatial dimensions given by shape.""" + def crop(self, x:torch.Tensor, shape:tuple) -> torch.Tensor: + """Center-crop x to match spatial dimensions given by shape. + + Args: + x (``torch.Tensor``): + The tensor to center-crop. + + shape (``tuple``): + The shape to match the crop to. 
+
+        Returns:
+            ``torch.Tensor``:
+                The tensor center-cropped to the given spatial dimensions.
+        """
 
-        x_target_size = x.size()[: -self.dims] + shape
+        x_target_size: tuple = x.size()[: -self.dims] + shape
 
-        offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))
+        offset: tuple = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))
 
-        slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))
+        slices: tuple = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))
 
         return x[slices]
 
-    def clamp_weights(self, net, min=-0.01, max=0.01):
+    def clamp_weights(self, net: torch.nn.Module, min=-0.01, max=0.01) -> None:
+        """Clamp the weights of a given network.
+
+        Args:
+            net (``torch.nn.Module``):
+                The network to clamp.
+
+            min (``float``, optional):
+                The minimum value to clamp network weights to.
+
+            max (``float``, optional):
+                The maximum value to clamp network weights to.
+        """
+
         for module in net.model:
             if hasattr(module, "weight") and hasattr(module.weight, "data"):
                 temp = module.weight.data
                 module.weight.data = temp.clamp(min, max)
 
-    def add_log(self, writer, step):
+    def add_log(self, writer, step) -> None:
+        """Add an additional log to the writer, containing loss values and image examples.
+
+        Args:
+            writer (``TODO``):
+                The display writer to append the losses & images to.
+
+            step (``int``):
+                TODO.
+        """
+
         # add loss values
         for key, loss in self.loss_dict.items():
             writer.add_scalar(key, loss, step)
@@ -69,7 +117,7 @@
             img = (img * 0.5) + 0.5
             writer.add_image(name, img, global_step=step, dataformats="HW")
 
-    def update_status(self, step):
+    def update_status(self, step) -> None:
         if hasattr(self, "validation_config") and (
             step % self.validation_config["validate_every"] == 0
         ):

From 975b7b149853c52b4578a29fc1879b50f2 Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Mon, 8 May 2023 13:31:20 -0400
Subject: [PATCH 09/19] NLayerDiscriminator docs & type checking

---
 .../torch/networks/NLayerDiscriminator.py | 157 +++++++++++------
 1 file changed, 100 insertions(+), 57 deletions(-)

diff --git a/src/raygun/torch/networks/NLayerDiscriminator.py b/src/raygun/torch/networks/NLayerDiscriminator.py
index 38ab54b3..a6a0f6cb 100644
--- a/src/raygun/torch/networks/NLayerDiscriminator.py
+++ b/src/raygun/torch/networks/NLayerDiscriminator.py
@@ -3,7 +3,28 @@
 
 
 class NLayerDiscriminator2D(torch.nn.Module):
-    """Defines a PatchGAN discriminator"""
+    """Defines a 2D PatchGAN discriminator.
+
+    Args:
+        input_nc (``integer``, optional):
+            Number of channels in the input images, with a default of 1.
+
+        ngf (``integer``, optional):
+            Number of filters in the last convolutional layer, with a default of 64.
+
+        n_layers (``integer``, optional):
+            Number of convolution layers in the discriminator, with a default of 3.
+
+        norm_layer (callable, optional):
+            Normalization layer to use, with the default being 2-dimensional batch normalization.
+
+        kw (``integer``, optional):
+            Kernel size for convolutional layers, with the default being 4.
+
+        downsampling_kw (optional):
+            Kernel size for downsampling convolutional layers. If not provided, defaults to the same value as kw.
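+
+    Example (illustrative)::
+
+        >>> import torch
+        >>> net = NLayerDiscriminator2D(input_nc=1, ngf=64, n_layers=3)
+        >>> logits = net(torch.rand(1, 1, 256, 256))  # map of per-patch real/fake logits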
+ """ def __init__( self, @@ -13,33 +34,27 @@ def __init__( norm_layer=torch.nn.BatchNorm2d, kw=4, downsampling_kw=None, - ): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ngf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ + ) -> None: + super().__init__() if ( type(norm_layer) == functools.partial ): # no need to use bias as BatchNorm2d has affine parameters - use_bias = norm_layer.func == torch.nn.InstanceNorm2d + use_bias: bool = norm_layer.func == torch.nn.InstanceNorm2d else: - use_bias = norm_layer == torch.nn.InstanceNorm2d + use_bias: bool = norm_layer == torch.nn.InstanceNorm2d if downsampling_kw is None: - downsampling_kw = kw + downsampling_kw: int = kw - padw = 1 - ds_kw = downsampling_kw - sequence = [ + padw:int = 1 + ds_kw:int = downsampling_kw + sequence: list = [ torch.nn.Conv2d(input_nc, ngf, kernel_size=ds_kw, stride=2, padding=padw), torch.nn.LeakyReLU(0.2, True), ] - nf_mult = 1 - nf_mult_prev = 1 + nf_mult: int = 1 + nf_mult_prev: int = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2**n, 8) @@ -74,36 +89,61 @@ def __init__( sequence += [ torch.nn.Conv2d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) ] # output 1 channel prediction map - self.model = torch.nn.Sequential(*sequence) + self.model: torch.nn.Sequential = torch.nn.Sequential(*sequence) @property - def FOV(self): + def FOV(self) -> float: # Returns the receptive field of one output neuron for a network (written for patch discriminators) # See https://distill.pub/2019/computing-receptive-fields/#solving-receptive-field-region for formula derivation - L = 0 # num of layers - k = [] # [kernel width at layer l] - s = [] # [stride at layer i] + L: int = 0 # num of layers + k: list = [] # [kernel width at layer l] + s: list = [] # [stride at layer i] for layer in self.model: if hasattr(layer, "kernel_size"): L += 1 k += [layer.kernel_size[-1]] s += [layer.stride[-1]] - r = 1 + r: float = 1. for l in range(L - 1, 0, -1): - r = s[l] * r + (k[l] - s[l]) + r: float = s[l] * r + (k[l] - s[l]) return r - def forward(self, input): - """Standard forward.""" + def forward(self, input:torch.Tensor) -> torch.Tensor: + """Standard forward discriminator pass. + + Args: + + input (``torch.Tensor``): + Image tensor to pass through the network. + """ return self.model(input) class NLayerDiscriminator3D(torch.nn.Module): - """Defines a PatchGAN discriminator""" - + """Defines a 3D PatchGAN discriminator. + + Args: + input_nc (``integer``, optional): + Number of channels in the input images, with a default of 1. + + ngf (``integer``, optional): + Number of filters in the last convolutional layer, with a default of 64. + + n_layers (``integer``, optional): + Number of convolution layers in the discriminator, with a default of 3. + + norm_layer (callable, optional): + Normalization layer to use, with the default being 3-dimensional batch normalization. + + kw (``integer``, optional): + Kernel size for convolutional layers, with the default being 4. + + downsampling_kw (optional): + Kernel size for downsampling convolutional layers. If not provided, defaults to the same value as kw. 
+ """ def __init__( self, input_nc=1, @@ -112,33 +152,27 @@ def __init__( norm_layer=torch.nn.BatchNorm3d, kw=4, downsampling_kw=None, - ): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ngf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ + ) -> None: + super().__init__() if ( type(norm_layer) == functools.partial ): # no need to use bias as BatchNorm3d has affine parameters - use_bias = norm_layer.func == torch.nn.InstanceNorm3d + use_bias: bool = norm_layer.func == torch.nn.InstanceNorm3d else: - use_bias = norm_layer == torch.nn.InstanceNorm3d + use_bias: bool = norm_layer == torch.nn.InstanceNorm3d if downsampling_kw is None: - downsampling_kw = kw + downsampling_kw: int = kw - padw = 1 - ds_kw = downsampling_kw - sequence = [ + padw: int = 1 + ds_kw: int = downsampling_kw + sequence: list = [ torch.nn.Conv3d(input_nc, ngf, kernel_size=ds_kw, stride=2, padding=padw), torch.nn.LeakyReLU(0.2, True), ] - nf_mult = 1 - nf_mult_prev = 1 + nf_mult: int = 1 + nf_mult_prev: int = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2**n, 8) @@ -173,24 +207,33 @@ def __init__( sequence += [ torch.nn.Conv3d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) ] # output 1 channel prediction map - self.model = torch.nn.Sequential(*sequence) + self.model: torch.nn.Sequential = torch.nn.Sequential(*sequence) + + def forward(self, input:torch.Tensor) -> torch.Tensor: + """Standard forward discriminator pass. + + Args: + + input (``torch.Tensor``): + Image tensor to pass through the network. + """ - def forward(self, input): - """Standard forward.""" return self.model(input) class NLayerDiscriminator(NLayerDiscriminator2D, NLayerDiscriminator3D): - """Defines a PatchGAN discriminator""" - - def __init__(self, ndims, **kwargs): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ngf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ + """Interface for a PatchGAN discriminator. + + Args: + ndims (``integer``): + Number of image dimensions, to create a 2D or 3D PatchGAN discriminator. + + **kwargs: + Optional keyword arguments defining network hyper parameters. + """ + + def __init__(self, ndims:int, **kwargs) -> None: + if ndims == 2: NLayerDiscriminator2D.__init__(self, **kwargs) elif ndims == 3: From c76edd3cf8c378b35fdd45161cb430c45e298740 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Mon, 8 May 2023 14:18:05 -0400 Subject: [PATCH 10/19] GAN Loss docs & type checking --- src/raygun/torch/losses/GANLoss.py | 66 ++++++++++++++++++------------ 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/src/raygun/torch/losses/GANLoss.py b/src/raygun/torch/losses/GANLoss.py index afa7eabc..e57d3f16 100644 --- a/src/raygun/torch/losses/GANLoss.py +++ b/src/raygun/torch/losses/GANLoss.py @@ -3,40 +3,49 @@ class GANLoss(torch.nn.Module): - """Define different GAN objectives. - The GANLoss class abstracts away the need to create the target label tensor - that has the same size as the input. - """ + """Define different GAN objectives. The GANLoss class abstracts away the need to create the target label tensor that has the same size as the input. 
+
+    Args:
+        gan_mode (``string``):
+            The type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+
+        target_real_label (``float``, optional):
+            Label for a real image, with a default of 1.0.
+
+        target_fake_label (``float``, optional):
+            Label for a fake image, with a default of 0.0.

-    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
-        """Initialize the GANLoss class.
-        Parameters:
-            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
-            target_real_label (bool) - - label for a real image
-            target_fake_label (bool) - - label of a fake image

     Note: Do not use sigmoid as the last layer of Discriminator.
     LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
-        """
+    """
+
+    def __init__(self, gan_mode: str, target_real_label=1.0, target_fake_label=0.0) -> None:
         super(GANLoss, self).__init__()
         self.register_buffer("real_label", torch.tensor(target_real_label))
         self.register_buffer("fake_label", torch.tensor(target_fake_label))
-        self.gan_mode = gan_mode
+        self.gan_mode: str = gan_mode
         if gan_mode == "lsgan":
-            self.loss = torch.nn.MSELoss()
+            self.loss: torch.nn.MSELoss = torch.nn.MSELoss()
         elif gan_mode == "vanilla":
-            self.loss = torch.nn.BCEWithLogitsLoss()
+            self.loss: torch.nn.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
         elif gan_mode in ["wgangp"]:
             self.loss = None
         else:
             raise NotImplementedError("gan mode %s not implemented" % gan_mode)

-    def get_target_tensor(self, prediction, target_is_real):
+    def get_target_tensor(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
         """Create label tensors with the same size as the input.
-        Parameters:
-            prediction (tensor) - - typically the prediction from a discriminator
-            target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+        Args:
+            prediction (``torch.Tensor``):
+                Typically the prediction from a discriminator.
+
+            target_is_real (``bool``):
+                Whether the ground truth label is for real images or fake images.
+
         Returns:
-            A label tensor filled with ground truth label, and with the size of the input
+            ``torch.Tensor``:
+                A label tensor filled with the ground truth label, with the same size as the input.
         """

         if target_is_real:
@@ -45,17 +54,22 @@ def get_target_tensor(self, prediction, target_is_real):
             target_tensor = self.fake_label
         return target_tensor.expand_as(prediction)

-    def __call__(self, prediction, target_is_real):
+    def __call__(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
         """Calculate loss given the Discriminator's output and ground truth labels.
-        Parameters:
-            prediction (tensor) - - typically the prediction output from a discriminator
-            target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+        Args:
+            prediction (``torch.Tensor``):
+                Typically the prediction output from a discriminator.
+
+            target_is_real (``bool``):
+                Whether the ground truth label is for real images or fake images.
+
         Returns:
-            the calculated loss.
+            ``torch.Tensor``:
+                The calculated loss.
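+
+        Example:
+            A minimal sketch; ``pred_fake`` stands in for a discriminator's
+            output on generated images::
+
+                >>> import torch
+                >>> criterion = GANLoss("lsgan")
+                >>> pred_fake = torch.randn(4, 1, 30, 30)
+                >>> loss_d_fake = criterion(pred_fake, target_is_real=False)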
""" if self.gan_mode in ["lsgan", "vanilla"]: - target_tensor = self.get_target_tensor(prediction, target_is_real) - loss = self.loss(prediction, target_tensor) + target_tensor: torch.Tensor = self.get_target_tensor(prediction, target_is_real) + loss: float = self.loss(prediction, target_tensor) elif self.gan_mode == "wgangp": if target_is_real: loss = -prediction.mean() From 1a7df3c9b8ff3757654867c1517e02acb96d0ad5 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Mon, 8 May 2023 14:51:21 -0400 Subject: [PATCH 11/19] Freezeable Model docs and type checking --- src/raygun/torch/models/BaseModel.py | 4 +- src/raygun/torch/models/FreezableModel.py | 50 +++++++++++++++++++---- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/raygun/torch/models/BaseModel.py b/src/raygun/torch/models/BaseModel.py index c408ccff..9f882dc6 100644 --- a/src/raygun/torch/models/BaseModel.py +++ b/src/raygun/torch/models/BaseModel.py @@ -11,8 +11,8 @@ def __init__(self, **kwargs) -> None: self, "output_arrays" ), "Model object must have list attribute `output_arrays` indicating what arrays are output by the model's forward pass, in order." - def add_log(self, writer, iter): + def add_log(self, writer, iter) -> None: pass - def forward(self): + def forward(self) -> None: return diff --git a/src/raygun/torch/models/FreezableModel.py b/src/raygun/torch/models/FreezableModel.py index 9c511957..5a67524e 100644 --- a/src/raygun/torch/models/FreezableModel.py +++ b/src/raygun/torch/models/FreezableModel.py @@ -1,21 +1,51 @@ import torch +from torch.utils.tensorboard import SummaryWriter from raygun.torch.models import BaseModel from raygun.torch.networks.utils import * from raygun.utils import passing_locals class FreezableModel(BaseModel): - def __init__(self, freeze_norms_at=None, **kwargs): + """A base model class for torch models that can freeze normalization layers during training. + + Args: + freeze_norms_at (``integer``): If set, normalization layers will be frozen + after the given training step. Defaults to None. + + **kwargs: + Additional arguments to pass to the parent class. + + """ + + def __init__(self, freeze_norms_at=None, **kwargs) -> None: super().__init__(**passing_locals(locals())) - def set_norm_modes(self, mode: str = "train"): + def set_norm_modes(self, mode:str = "train") -> None: + """Set the mode for all normalization layers in the model. + + Args: + mode (``string``): + The mode to set normalization layers to. Must be either "train" or "fix_norms". + """ + for net in self.nets: set_norm_mode(net, mode) - def add_log(self, writer, step): - means = [] - vars = [] + def add_log(self, writer:SummaryWriter, step:int) -> None: + """Add histogram of the means and variances of the model's normalization layers to the + given Tensorboard writer. + + Args: + writer (``torch.utils.tensorboard.SummaryWriter``): + The Tensorboard writer. + + step (``integer``): + The current training step. 
+        """
+
+        means: list = []
+        vars: list = []
         for net in self.nets:
             mean, var = get_running_norm_stats(net)
             if mean is not None:
                 means.append(mean)
                 vars.append(var)

         if len(means) > 0:
-            hists = {"means": torch.cat(means), "vars": torch.cat(vars)}
+            hists: dict[str, torch.Tensor] = {"means": torch.cat(means), "vars": torch.cat(vars)}
             for tag, values in hists.items():
                 writer.add_histogram(tag, values, global_step=step)

-    def update_status(self, step):
+    def update_status(self, step: int) -> None:
+        """Update the status of the model to freeze the normalization layers, if applicable.
+
+        Args:
+            step (``integer``):
+                The current training step.
+        """
         if self.freeze_norms_at is not None and step >= self.freeze_norms_at:
             self.set_norm_modes(mode="fix_norms")

From 7cf09161aa516e3eae3d3f5c77e6a54129bd353d Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Mon, 8 May 2023 14:55:03 -0400
Subject: [PATCH 12/19] BaseModel docs and type checking

---
 src/raygun/torch/models/BaseModel.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/raygun/torch/models/BaseModel.py b/src/raygun/torch/models/BaseModel.py
index 9f882dc6..44346d3c 100644
--- a/src/raygun/torch/models/BaseModel.py
+++ b/src/raygun/torch/models/BaseModel.py
@@ -1,7 +1,14 @@
 import torch
-
+from torch.utils.tensorboard import SummaryWriter

 class BaseModel(torch.nn.Module):
+    """Base class for building a PyTorch model.
+
+    Args:
+        **kwargs:
+            Optional keyword arguments.
+    """
+
     def __init__(self, **kwargs) -> None:
         super().__init__()
         for key, value in kwargs.items():
@@ -11,8 +18,10 @@ def __init__(self, **kwargs) -> None:
             self, "output_arrays"
         ), "Model object must have list attribute `output_arrays` indicating what arrays are output by the model's forward pass, in order."

-    def add_log(self, writer, iter) -> None:
+    def add_log(self, writer: SummaryWriter, iter: int) -> None:
+        """Dummy log hook, to be overridden by subclasses."""
         pass

     def forward(self) -> None:
+        """Dummy forward pass, to be overridden by subclasses."""
         return

From 1f13881e0638a64124f8f97b4b5252942104c59b Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Mon, 8 May 2023 15:10:20 -0400
Subject: [PATCH 13/19] Network utils docs and type checking

---
 src/raygun/torch/networks/utils.py | 84 +++++++++++++++++++++++-------
 1 file changed, 64 insertions(+), 20 deletions(-)

diff --git a/src/raygun/torch/networks/utils.py b/src/raygun/torch/networks/utils.py
index 86cd95e2..928ece85 100644
--- a/src/raygun/torch/networks/utils.py
+++ b/src/raygun/torch/networks/utils.py
@@ -4,16 +4,39 @@
 from torch.nn import init


-def get_norm_layers(net):
+def get_norm_layers(net: torch.nn.Module) -> list:
+    """Get a list of all normalization layers in the given module.
+
+    Args:
+        net (``torch.nn.Module``):
+            The module to search for normalization layers.
+
+    Returns:
+        ``list``:
+            A list of all normalization layers found in the module.
+    """
+
     return [n for n in net.modules() if "norm" in type(n).__name__.lower()]


-def get_running_norm_stats(net):
-    means = []
-    vars = []
+def get_running_norm_stats(net: torch.nn.Module) -> tuple:
+    """Get the running means and variances of all normalization layers in the given module.
+
+    Args:
+        net (``torch.nn.Module``):
+            The module to compute running means and variances for.
+
+    Returns:
+        ``tuple``:
+            A tuple containing two tensors, the concatenated running means and running variances, respectively.
+            If no normalization layer is found, returns None for both elements.
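+
+    Example:
+        A minimal sketch::
+
+            >>> net = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.BatchNorm2d(8))
+            >>> means, vars = get_running_norm_stats(net)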
+    """
+
+    means: list = []
+    vars: list = []
     try:
-        norms = get_norm_layers(net)
+        norms: list = get_norm_layers(net)

         for norm in norms:
             means.append(norm.running_mean)
@@ -28,7 +51,17 @@ def get_running_norm_stats(net):
     return means, vars


-def set_norm_mode(net, mode="train"):
+def set_norm_mode(net: torch.nn.Module, mode="train") -> None:
+    """Set the normalization mode of the given module.
+
+    Args:
+        net (``torch.nn.Module``):
+            The module to set the normalization mode for.
+
+        mode (``string``):
+            The normalization mode to set. Can be one of "train", "eval", or "fix_norms".
+    """
+
     if mode == "fix_norms":
         net.train()
         for m in net.modules():
@@ -42,17 +75,24 @@ def set_norm_mode(net, mode="train"):
         net.eval()


-def init_weights(net, init_type="normal", init_gain=0.02, nonlinearity="relu"):
+def init_weights(net, init_type="normal", init_gain=0.02, nonlinearity="relu") -> None:
     """Initialize network weights.
-    Parameters:
-        net (network)      -- network to be initialized
-        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
-        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
-    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+
+    Args:
+        net (``torch.nn.Module``):
+            The network to be initialized.
+
+        init_type (``string``):
+            The name of an initialization method: normal | xavier | kaiming | orthogonal.
+
+        init_gain (``float``):
+            Scaling factor for normal, xavier and orthogonal.
+
+    'Normal' is used in the original pix2pix and CycleGAN paper. Xavier and kaiming might
     work better for some applications. Feel free to try yourself.
     """

-    def init_func(m):  # define the initialization function
+    def init_func(m) -> None:  # define the initialization function
         classname = m.__class__.__name__
         if hasattr(m, "weight") and (
             classname.find("Conv") != -1 or classname.find("Linear") != -1
@@ -84,22 +124,26 @@ def init_func(m):  # define the initialization function


 class NoiseBlock(torch.nn.Module):
-    """Definies a block for producing and appending a feature map of gaussian noise with mean=0 and stdev=1"""
+    """Defines a block for producing and appending a feature map of Gaussian noise with mean=0 and stdev=1."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

-    def forward(self, x):
-        shape = list(x.shape)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Standard NoiseBlock forward pass on a data tensor."""

+        shape: list = list(x.shape)
         shape[1] = 1  # only make one noise feature
-        noise = torch.empty(shape, device=x.device).normal_()
+        noise: torch.Tensor = torch.empty(shape, device=x.device).normal_()
         return torch.cat([x, noise.requires_grad_()], 1)


 class ParameterizedNoiseBlock(torch.nn.Module):
-    """Definies a block for producing and appending a feature map of gaussian noise with mean and stdev defined by the first two feature maps of the incoming tensor"""
+    """Defines a block for producing and appending a feature map of Gaussian noise with mean and stdev defined by the first two feature maps of the incoming tensor."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

-    def forward(self, x):
-        noise = torch.normal(x[:, 0, ...], torch.relu(x[:, 1, ...])).unsqueeze(1)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Standard ParameterizedNoiseBlock forward pass on a data tensor."""

+        noise: torch.Tensor = torch.normal(x[:, 0, ...], torch.relu(x[:, 1, ...])).unsqueeze(1)
         return torch.cat([x, noise.requires_grad_()], 1)

From 7b2fdb4321c07c391b6d5dc8d7735501155fe2dd Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Mon, 8 May 2023 15:21:33 -0400
Subject: [PATCH 14/19] BaseDummyOptimizer docs and type checking

---
 src/raygun/torch/losses/BaseCompetentLoss.py | 6 ++---
 src/raygun/torch/networks/ResidualUNet.py    | 2 +-
 .../torch/optimizers/BaseDummyOptimizer.py   | 24 +++++++++++++++----
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/src/raygun/torch/losses/BaseCompetentLoss.py b/src/raygun/torch/losses/BaseCompetentLoss.py
index a4ead5ff..0a863cef 100644
--- a/src/raygun/torch/losses/BaseCompetentLoss.py
+++ b/src/raygun/torch/losses/BaseCompetentLoss.py
@@ -1,5 +1,6 @@
 from raygun.evaluation.validate_affinities import run_validation
 import torch
+from torch.utils.tensorboard import SummaryWriter

 from raygun.utils import passing_locals
 from raygun.torch.losses import GANLoss
@@ -88,12 +89,11 @@ def add_log(self, writer, step) -> None:
         """Add an additional log to the writer, containing loss values and image examples.

         Args:
-            writer (``TODO``):
+            writer (``SummaryWriter``):
                 The display writer to append the losses & images to.

             step (``int``):
-                TODO.
-
+                The current training step.
         """

         # add loss values
diff --git a/src/raygun/torch/networks/ResidualUNet.py b/src/raygun/torch/networks/ResidualUNet.py
index 66eace83..e12da3e1 100644
--- a/src/raygun/torch/networks/ResidualUNet.py
+++ b/src/raygun/torch/networks/ResidualUNet.py
@@ -379,7 +379,7 @@ def __init__(
         residual=False,
         norm_layer=None,
         add_noise=False,
-    ):
+    ) -> None:
         """Create a U-Net::

             f_in --> f_left --------------------------->> f_right--> f_out
diff --git a/src/raygun/torch/optimizers/BaseDummyOptimizer.py b/src/raygun/torch/optimizers/BaseDummyOptimizer.py
index 47a91e73..732480b2 100644
--- a/src/raygun/torch/optimizers/BaseDummyOptimizer.py
+++ b/src/raygun/torch/optimizers/BaseDummyOptimizer.py
@@ -3,10 +3,23 @@

 class BaseDummyOptimizer(torch.nn.Module):
-    def __init__(self, scheduler=None, scheduler_kwargs={}, **optimizers):
+    """Base Dummy Optimizer for training.
+
+    Args:
+        scheduler (``string`` or callable, optional):
+            The learning rate scheduler to use, given either by name (e.g. "LambdaLR") or as a scheduler class. Default is None.
+
+        scheduler_kwargs (``dict``, optional):
+            A dictionary of keyword arguments to pass to the learning rate scheduler. Default is an empty dictionary.
+
+        **optimizers (optional):
+            Keyword arguments for the optimizer(s) to use.
+    """
+
+    def __init__(self, scheduler=None, scheduler_kwargs={}, **optimizers) -> None:
         super().__init__()

-        self.schedulers = {}
+        self.schedulers: dict = {}
         for name, optimizer in optimizers.items():
             setattr(self, name, optimizer)
@@ -21,8 +34,8 @@ def __init__(self, scheduler=None, scheduler_kwargs={}, **optimizers):
             if isinstance(scheduler, str):
                 if scheduler == "LambdaLR":

-                    def lambda_rule(epoch):
-                        lr_l = 1.0 - max(
+                    def lambda_rule(epoch: int) -> float:
+                        lr_l: float = 1.0 - max(
                             0,
                             epoch
                             + scheduler_kwargs["epoch_count"]
@@ -42,6 +55,7 @@ def lambda_rule(epoch):
             elif scheduler is not None:
                 self.schedulers[name] = scheduler(optimizer, **scheduler_kwargs)

-    def step(self):
+    def step(self) -> None:
+        """Step every registered learning rate scheduler. Note that this only
+        steps the schedulers; the wrapped optimizers themselves are stepped
+        elsewhere during training.
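+
+        Example:
+            A sketch; ``net`` is a hypothetical module, and the optimizer it
+            feeds is stepped by the training loss, not by this wrapper::
+
+                >>> g_optim = torch.optim.Adam(net.parameters(), lr=2e-4)
+                >>> optim = BaseDummyOptimizer(
+                ...     scheduler=torch.optim.lr_scheduler.StepLR,
+                ...     scheduler_kwargs={"step_size": 1000},
+                ...     optimizer_G=g_optim,
+                ... )
+                >>> optim.step()  # advances the StepLR schedule for optimizer_G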
""" for name, scheduler in self.schedulers.items(): scheduler.step() From d644e99726fb2e79e5f2523c0dbf870e3340c7e6 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Mon, 8 May 2023 15:25:43 -0400 Subject: [PATCH 15/19] Optimizer utils docs and type checking --- src/raygun/torch/optimizers/utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/raygun/torch/optimizers/utils.py b/src/raygun/torch/optimizers/utils.py index d40596e1..1d862424 100644 --- a/src/raygun/torch/optimizers/utils.py +++ b/src/raygun/torch/optimizers/utils.py @@ -1,9 +1,20 @@ import torch -def get_base_optimizer(optim): +def get_base_optimizer(optim) -> torch.optim.Optimizer: + """Return the base optimizer object given its name or object. + + Args: + optim (``string``, torch.optim.Optimizer``): + String or optimizer object + + Returns: + ``torch.optim.Optimizer``: + The base optimizer object. + """ + if isinstance(optim, str): - base_optim = getattr(torch.optim, optim) + base_optim: torch.optim.Optimizer = getattr(torch.optim, optim) else: base_optim = optim return base_optim From a186e28aa0b76fd9e1a68f35f9db8c3d7fdddd94 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Mon, 8 May 2023 16:11:32 -0400 Subject: [PATCH 16/19] Base Train initial docs and type checking --- src/raygun/torch/systems/BaseSystem.py | 6 +-- src/raygun/torch/train/BaseTrain.py | 58 ++++++++++++++++++++++---- src/raygun/torch/train/CycleTrain.py | 6 --- 3 files changed, 53 insertions(+), 17 deletions(-) diff --git a/src/raygun/torch/systems/BaseSystem.py b/src/raygun/torch/systems/BaseSystem.py index b8532078..e30cb1cd 100644 --- a/src/raygun/torch/systems/BaseSystem.py +++ b/src/raygun/torch/systems/BaseSystem.py @@ -20,7 +20,7 @@ class BaseSystem: def __init__( self, default_config="../default_configs/blank_conf.json", config=None - ): + ) -> None: # Add default params default_config = default_config.replace("..", parent_dir) for key, value in read_config(default_config).items(): @@ -309,7 +309,7 @@ def setup_trainer(self): self.arrays.update(self.trainer.arrays) - def build_system(self): + def build_system(self) -> None: # define our network model for training self.setup_networks() self.setup_model() @@ -317,7 +317,7 @@ def build_system(self): self.setup_datapipes() self.setup_trainer() - def train(self): + def train(self) -> None: if not hasattr(self, "trainer"): self.build_system() if hasattr(self, "train_kwargs"): diff --git a/src/raygun/torch/train/BaseTrain.py b/src/raygun/torch/train/BaseTrain.py index cd211600..50b3a64a 100644 --- a/src/raygun/torch/train/BaseTrain.py +++ b/src/raygun/torch/train/BaseTrain.py @@ -8,12 +8,54 @@ import logging -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) from raygun.utils import passing_locals, to_json - class BaseTrain(object): + """Base training class for models. + + Args: + datapipes (``dict``): + Dictionary of Gunpowder datapipes. + + batch_request (``gunpowder.BatchRequest``): + Request to use when running Gunpowder. + + model (``torch.nn.Module``): + PyTorch model to use for training. + + loss (``torch.nn.Module``): + PyTorch loss function to use for training. + + optimizer (``torch.nn.Module``): + PyTorch optimizer to use for training. + + tensorboard_path (``string``, optional): + Path to use for Tensorboard logs. Defaults to "./tensorboard/". + + log_every (``integer``, optional): + How often to log loss during training. Defaults to 20. 
+
+        checkpoint_basename (``string``, optional):
+            Basename to use for model checkpoints. Defaults to "./models/model".
+
+        save_every (``integer``, optional):
+            How often to save a model checkpoint. Defaults to 2000.
+
+        spawn_subprocess (``bool``, optional):
+            Whether to spawn a subprocess to run Gunpowder. Defaults to False.
+
+        num_workers (``integer``, optional):
+            Number of workers to use with the Gunpowder PreCache node. Defaults to 11.
+
+        cache_size (``integer``, optional):
+            Cache size to use with the Gunpowder PreCache node. Defaults to 50.
+
+        snapshot_every (``integer``, optional):
+            How often to save a snapshot of the training volumes. Defaults to None.
+    """
+
     def __init__(
         self,
         datapipes: dict,
@@ -30,27 +72,27 @@ def __init__(
         cache_size: int = 50,
         snapshot_every=None,
         **kwargs,
-    ):
-        kwargs = passing_locals(locals())
+    ) -> None:
+        kwargs: dict = passing_locals(locals())

         for key, value in kwargs.items():
             setattr(self, key, value)

-        self.arrays = {}
+        self.arrays: dict = {}
         for datapipe in datapipes.values():
             self.arrays.update(datapipe.arrays)

-        self.input_dict = {}
+        self.input_dict: dict = {}
         for array_name in inspect.signature(model.forward).parameters.keys():
             if array_name != "self":
                 self.input_dict[array_name] = self.arrays[array_name]

-        self.output_dict = {}
+        self.output_dict: dict = {}
         for i, array_name in enumerate(model.output_arrays):
             if array_name not in self.arrays.keys():
                 self.arrays[array_name] = gp.ArrayKey(array_name.upper())
             self.output_dict[i] = self.arrays[array_name]

-        self.loss_input_dict = {}
+        self.loss_input_dict: dict = {}
         for i, array_name in enumerate(
             inspect.signature(loss.forward).parameters.keys()
         ):
diff --git a/src/raygun/torch/train/CycleTrain.py b/src/raygun/torch/train/CycleTrain.py
index 347a2044..c3b4eb8a 100644
--- a/src/raygun/torch/train/CycleTrain.py
+++ b/src/raygun/torch/train/CycleTrain.py
@@ -23,12 +23,6 @@ def __init__(
     ):
         super().__init__(**passing_locals(locals()))

-    # def prenet_pipe(self, mode: str = "train"): #TODO: Remove before 0.3.0
-    #     return (
-    #         tuple([dp.prenet_pipe(mode) for dp in self.datapipes.values()])
-    #         + gp.MergeProvider()
-    #     )  # merge upstream pipelines for multiple sources
-
     def postnet_pipe(self, mode: str = "train"):
         if mode == "test":
             stack = lambda dp: 1

From 300d9942ba866209478c792118b128812818275c Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Tue, 9 May 2023 14:00:29 -0400
Subject: [PATCH 17/19] BaseTrain docs and type checking

---
 src/raygun/torch/train/BaseTrain.py | 66 ++++++++++++++++++++++++++---
 1 file changed, 59 insertions(+), 7 deletions(-)

diff --git a/src/raygun/torch/train/BaseTrain.py b/src/raygun/torch/train/BaseTrain.py
index 50b3a64a..3f4e7cf5 100644
--- a/src/raygun/torch/train/BaseTrain.py
+++ b/src/raygun/torch/train/BaseTrain.py
@@ -119,17 +119,50 @@ def __init__(
             os.makedirs(os.path.dirname(checkpoint_basename), exist_ok=True)

     def prenet_pipe(self, mode: str = "train"):
+        """Creates the pre-processing pipeline by calling ``prenet_pipe()`` on every datapipe and merging the resulting streams with a ``MergeProvider()`` node.
+
+        Args:
+            mode (``string``, optional):
+                The mode in which the data will be processed, defaults to "train".
+
+        Returns:
+            The assembled pre-processing pipeline, with the output of all data pipes merged by the ``MergeProvider()`` node.
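+
+        Example:
+            Illustrative only; ``trainer`` is a constructed ``BaseTrain``, and
+            this composition mirrors what ``training_pipe()`` assembles::
+
+                >>> pipeline = trainer.prenet_pipe("train") + trainer.train_node
+                >>> with gp.build(pipeline):
+                ...     batch = pipeline.request_batch(trainer.batch_request)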
+ """ + return ( tuple([dp.prenet_pipe(mode) for dp in self.datapipes.values()]) + gp.MergeProvider() ) # merge upstream pipelines for multiple sources - def postnet_pipe(self, batch_size=1): + def postnet_pipe(self, batch_size=1) -> list: + """Creates a post-processing pipeline that is responsible for processing the output of the network. + The pipeline is created by calling the postnet_pipe() method on all data pipes and storing the output streams in a list. + + Args: + batch_size (``integer``, optional) + The batch size of the data, with a default value of 1. + + Returns: + ``list``: + A list of post-processed data from all data pipes. + """ + return [ dp.postnet_pipe(batch_size=batch_size) for dp in self.datapipes.values() ] def training_pipe(self, mode: str = "train"): + """Creates a pipeline for training the neural network. The pipeline is created by calling the prenet_pipe() method to create a pre-processing pipeline, adding a PreCache() node if mode is "train", adding the train_node to the pipeline, calling the postnet_pipe() method to create a post-processing pipeline, and adding a Snapshot() node if mode is "train" and snapshot_every is not None. + + Args: + mode (`str``, optional): + The mode in which the data will be processed, with a default value of "train." + + Returns: + The pipeline for training the network. + """ + # assemble pipeline training_pipe = self.prenet_pipe(mode) @@ -149,7 +182,7 @@ def training_pipe(self, mode: str = "train"): training_pipe += section if mode == "train" and self.snapshot_every is not None: - snapshot_names = {} + snapshot_names: dict = {} if hasattr(self, "snapshot_arrays") and self.snapshot_arrays is not None: for array in self.snapshot_arrays: snapshot_names[self.arrays[array]] = array @@ -171,7 +204,9 @@ def training_pipe(self, mode: str = "train"): return training_pipe - def print_profiling_stats(self): + def print_profiling_stats(self) -> None: + """Prints the profiling statistics for the pipeline.""" + stats = "\n" stats += "Profiling Stats\n" stats += "===============\n" @@ -185,7 +220,7 @@ def print_profiling_stats(self): stats += "MEDIAN".ljust(10) stats += "\n" - summaries = list(self.batch.profiling_stats.get_timing_summaries().items()) + summaries: list = list(self.batch.profiling_stats.get_timing_summaries().items()) summaries.sort() for (node_name, method_name), summary in summaries: @@ -206,7 +241,14 @@ def print_profiling_stats(self): print(stats) - def train(self, iter: int): + def train(self, iter: int) -> None: + """Trains the model for the specified number of iterations. + + Args: + iter (``integer``): + The number of iterations to train the model for. + """ + self.model.train() training_pipeline = self.training_pipe() with gp.build(training_pipeline): @@ -219,10 +261,20 @@ def train(self, iter: int): if i + 1 % self.log_every == 0: self.train_node.summary_writer.flush() - def test(self, mode: str = "train"): + def test(self, mode: str = "train") -> gp.Batch: + """Runs the testing mode for the model. + + Args: + mode (str): The mode to run the test in. + + Returns: + ``gp.Batch``: + The test batch. 
+ """ + getattr(self.model, mode)() training_pipeline = self.training_pipe(mode="test") with gp.build(training_pipeline): - self.batch = training_pipeline.request_batch(self.batch_request) + self.batch: gp.Batch = training_pipeline.request_batch(self.batch_request) return self.batch From 8e6d74f542b433378d09c95634c001e4c9a5c1fd Mon Sep 17 00:00:00 2001 From: brianreicher Date: Tue, 9 May 2023 14:07:21 -0400 Subject: [PATCH 18/19] CycleTrain docs and typechecking --- src/raygun/torch/train/CycleTrain.py | 55 +++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/src/raygun/torch/train/CycleTrain.py b/src/raygun/torch/train/CycleTrain.py index c3b4eb8a..cae44f67 100644 --- a/src/raygun/torch/train/CycleTrain.py +++ b/src/raygun/torch/train/CycleTrain.py @@ -5,6 +5,48 @@ class CycleTrain(BaseTrain): + """CycleTrain object, an extension of BaseTrain for CycleGANs and CycleGAN-like models. + + Args: + datapipes (``dict``): + A dictionary containing data pipelines. + + batch_request (``gunpowder.BatchRequest``): + The batch request to make to the pipeline. + + model (``torch.nn.Module``): + The model to use for training or testing. + + loss (``torch.nn.Module``): + The loss function to use. + + optimizer (``torch.nn.Module``): + The optimizer to use for training. + tensorboard_path (``string``): + Path to store tensorboard files. Default is "./tensorboard/". + + log_every (``integer``): + Logging frequency. Default is 20. + + checkpoint_basename (``string``): + Base name of the file to store checkpoints. Default is "./models/model". + + save_every (``integer``): + Frequency of checkpoint saving. Default is 2000. + + spawn_subprocess (``bool``): + Whether or not to spawn a subprocess. Default is False. + + num_workers (``integer``): + Number of workers to use for parallelization. Default is 11. + + cache_size (``integer``): + Size of the cache. Default is 50. + + **kwargs: + Additional keyword arguments. + """ + def __init__( self, datapipes: dict, @@ -23,7 +65,18 @@ def __init__( ): super().__init__(**passing_locals(locals())) - def postnet_pipe(self, mode: str = "train"): + def postnet_pipe(self, mode: str = "train") -> list: + """Returns a list of postnet pipeline objects. + + Args: + mode (``string``): + The mode to use for the pipeline. Default is "train". + + Returns: + ``list``: + A list of postnet pipeline objects. + """ + if mode == "test": stack = lambda dp: 1 else: From 37d704d4846fe4333a06b331184b62c280cfcef3 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Thu, 12 Oct 2023 13:35:15 -0400 Subject: [PATCH 19/19] Remove manual typechecking --- src/raygun/torch/losses/BaseCompetentLoss.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/raygun/torch/losses/BaseCompetentLoss.py b/src/raygun/torch/losses/BaseCompetentLoss.py index 0a863cef..81d627b9 100644 --- a/src/raygun/torch/losses/BaseCompetentLoss.py +++ b/src/raygun/torch/losses/BaseCompetentLoss.py @@ -35,9 +35,6 @@ def set_requires_grad(self, nets:list, requires_grad=False) -> None: requires_grad (``bool``): Whether the networks require gradients or not. """ - - if not isinstance(nets, list): # TODO: remove, this should be redudant with type checking enforced - nets: list = [nets] for net in nets: if net is not None: for param in net.parameters():