diff --git a/dynamight/deformations/optimize_deformations.py b/dynamight/deformations/optimize_deformations.py index f34914e..f37baf4 100644 --- a/dynamight/deformations/optimize_deformations.py +++ b/dynamight/deformations/optimize_deformations.py @@ -360,6 +360,7 @@ def optimize_deformations( Ivol = Ivol[0, 0] decoder_half1.vol_box = decoder_half1.vol_box//2 decoder_half2.vol_box = decoder_half2.vol_box//2 + gpu_box = decoder_half1.vol_box decoder_half1.vol_box = decoder_half1.box_size decoder_half2.vol_box = decoder_half2.box_size if mask_file: @@ -1377,8 +1378,12 @@ def optimize_deformations( if epoch % 5 == 0 or (final > finalization_epochs) or (epoch == n_epochs-1): with torch.no_grad(): + decoder_half1.vol_box = gpu_box + decoder_half2.vol_box = gpu_box V_h1 = decoder_half1.generate_consensus_volume().cpu() V_h2 = decoder_half2.generate_consensus_volume().cpu() + decoder_half1.vol_box = decoder_half1.box_size + decoder_half2.vol_box = decoder_half2.box_size gaussian_widths = torch.argmax(torch.nn.functional.softmax( decoder_half1.ampvar, dim=0), dim=0) checkpoint = {'encoder_half1': encoder_half1,