From f6a4b9420f3d2368e104d268903a2e8e45aebcce Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Tue, 7 Nov 2023 09:39:31 -0500
Subject: [PATCH] ACLSD model

---
 src/raygun/torch/models/ACLSDModel.py | 57 +++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 src/raygun/torch/models/ACLSDModel.py

diff --git a/src/raygun/torch/models/ACLSDModel.py b/src/raygun/torch/models/ACLSDModel.py
new file mode 100644
index 00000000..90b251b6
--- /dev/null
+++ b/src/raygun/torch/models/ACLSDModel.py
@@ -0,0 +1,57 @@
+from raygun.torch.networks import UNet
+from raygun.torch.networks.UNet import ConvPass
+import torch
+import logging
+
+torch.backends.cudnn.benchmark = True
+logging.basicConfig(level=logging.INFO)
+
+# long range affs - use 20 output features
+# increase number of downsampling layers for more features in the bottleneck
+class ACLSDModel(torch.nn.Module):
+    def __init__(
+        self,
+        unet_kwargs={
+            "input_nc": 1,
+            "ngf": 12,
+            "fmap_inc_factor": 6,
+            "downsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)],
+            "constant_upsample": True,
+        },
+        num_affs=3,
+    ):
+        super().__init__()
+
+        self.unet = UNet(**unet_kwargs)
+
+        self.aff_head = ConvPass(
+            unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid"
+        )
+
+        self.output_arrays = ["pred_affs"]
+        self.data_dict = {}
+
+    def add_log(self, writer, step):
+        # add loss input image examples
+        for name, data in self.data_dict.items():
+            if len(data.shape) > 3:  # pull out batch dimension if necessary
+                img = data[0].squeeze()
+            else:
+                img = data.squeeze()
+
+            if len(img.shape) == 3:
+                mid = img.shape[0] // 2  # for 3D volume
+                img = img[mid]
+
+            if (
+                (img.min() < 0) and (img.min() >= -1.0) and (img.max() <= 1.0)
+            ):  # scale img to [0,1] if necessary
+                img = (img * 0.5) + 0.5
+            writer.add_image(name, img, global_step=step, dataformats="HW")
+
+    def forward(self, raw):
+        self.data_dict.update({"raw": raw.detach()})
+        z = self.unet(raw)
+        affs = self.aff_head(z)
+
+        return affs
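
Note for reviewers: below is a minimal usage sketch of the new model, not part of
the patch. It assumes raygun (with its UNet/ConvPass implementations) is installed
and importable. The input shape is illustrative only; if the configured UNet uses
valid (unpadded) convolutions, the spatial size must be chosen to fit the three
(2, 2, 2) downsampling levels and the output will be smaller than the input.

    import torch
    from raygun.torch.models.ACLSDModel import ACLSDModel

    # Defaults from the patch: 1 input channel, ngf=12, three (2, 2, 2)
    # downsampling levels, and 3 sigmoid-activated affinity channels.
    model = ACLSDModel(num_affs=3)

    # (batch, channels, z, y, x) raw volume; adjust the spatial size to one
    # the configured UNet accepts.
    raw = torch.rand(1, 1, 132, 132, 132)

    with torch.no_grad():
        pred_affs = model(raw)  # predicted affinities, shape (1, 3, z', y', x')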