diff --git a/notebooks/camera_ready/corrupt_data/02 - Finetune Missing Data.ipynb b/notebooks/camera_ready/corrupt_data/02 - Finetune Missing Data.ipynb
new file mode 100644
index 0000000..ea98122
--- /dev/null
+++ b/notebooks/camera_ready/corrupt_data/02 - Finetune Missing Data.ipynb
@@ -0,0 +1,328 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Missing Data Reconstruction from Pretrained Embeddings\n",
+    "\n",
+    "For this example we're going to build on `01 - Finetune Virtual EVE.ipynb` and create a simpler finetuning setup.\n",
+    "\n",
+    "![Figure 1: Architectural Diagram](assets/architecture_diags_corrupt.svg)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "WARNING: SunpyUserWarning: Importing sunpy.map without its extra dependencies may result in errors.\n",
+      "The following packages are not installed:\n",
+      "['mpl-animators>=1.0.0', 'reproject>=0.9.0']\n",
+      "To install sunpy with these dependencies use `pip install sunpy[map]` or `pip install sunpy[all]` for all extras. \n",
+      "If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]\n",
+      "WARNING: SunpyUserWarning: Importing sunpy.visualization without its extra dependencies may result in errors.\n",
+      "The following packages are not installed:\n",
+      "['mpl-animators>=1.0.0']\n",
+      "To install sunpy with these dependencies use `pip install sunpy[visualization]` or `pip install sunpy[all]` for all extras. \n",
+      "If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import omegaconf\n",
+    "from sdofm.datasets import SDOMLDataModule\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cfg = omegaconf.OmegaConf.load(\"finetune_corrupt_data.yml\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.\n",
+      "[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.\n",
+      "[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.\n"
+     ]
+    }
+   ],
+   "source": [
+    "data_module = SDOMLDataModule(\n",
+    "    hmi_path=None,\n",
+    "    aia_path=(\n",
+    "        os.path.join(\n",
+    "            cfg.data.sdoml.base_directory,\n",
+    "            cfg.data.sdoml.sub_directory.aia,\n",
+    "        )\n",
+    "        if cfg.data.sdoml.sub_directory.aia\n",
+    "        else None\n",
+    "    ),\n",
+    "    eve_path=None,\n",
+    "    components=cfg.data.sdoml.components,\n",
+    "    wavelengths=cfg.data.sdoml.wavelengths,\n",
+    "    ions=cfg.data.sdoml.ions,\n",
+    "    frequency=cfg.data.sdoml.frequency,\n",
+    "    batch_size=cfg.model.opt.batch_size,\n",
+    "    num_workers=cfg.data.num_workers,\n",
+    "    val_months=cfg.data.month_splits.val,\n",
+    "    test_months=cfg.data.month_splits.test,\n",
+    "    holdout_months=cfg.data.month_splits.holdout,\n",
+    "    cache_dir=os.path.join(\n",
+    "        cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache\n",
+    "    ),\n",
+    "    min_date=cfg.data.min_date,\n",
+    "    max_date=cfg.data.max_date,\n",
+    "    num_frames=cfg.data.num_frames,\n",
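+    "    # num_frames and drop_frame_dim (below) control the temporal axis\n",
+    "    # of each sample; the model later in this notebook indexes batches\n",
+    "    # as [batch, channel, frame, height, width]\n",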
+    "    drop_frame_dim=cfg.data.drop_frame_dim,\n",
+    ")\n",
+    "data_module.setup()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sdofm.models import WrapEncoder, ConvTransformerTokensToEmbeddingNeck\n",
+    "from sdofm.benchmarks import reconstruction as bench_recon\n",
+    "import torch.nn.functional as F\n",
+    "from sdofm.constants import ALL_WAVELENGTHS\n",
+    "from sdofm import BaseModule\n",
+    "\n",
+    "class MissingDataModel(BaseModule):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        # Backbone parameters\n",
+    "        img_size: int = 512,\n",
+    "        patch_size: int = 16,\n",
+    "        embed_dim: int = 128,\n",
+    "        num_frames: int = 1,\n",
+    "        # for finetuning\n",
+    "        backbone: object = None,\n",
+    "        freeze_encoder: bool = True,\n",
+    "        # all else\n",
+    "        *args,\n",
+    "        **kwargs,\n",
+    "    ):\n",
+    "        super().__init__(*args, **kwargs)\n",
+    "\n",
+    "        self.backbone = backbone\n",
+    "\n",
+    "        self.masking_ratio = 0.75\n",
+    "        self.validation_metrics = []\n",
+    "\n",
+    "        if freeze_encoder:\n",
+    "            self.backbone.autoencoder.blocks.eval()\n",
+    "            for param in self.backbone.autoencoder.blocks.parameters():\n",
+    "                param.requires_grad = False\n",
+    "\n",
+    "        # index of the channel we zero out to simulate missing data\n",
+    "        self.simulated_corrupt_wavelength = 5\n",
+    "\n",
+    "        # As this is a reconstruction task, something the MAE was\n",
+    "        # designed to do, we don't require the neck.\n",
+    "\n",
+    "    def forward_corrupt_data_override(self, imgs, mask_ratio=0.75):\n",
+    "        # corrupt our wavelength by setting it all to 0; clone first so\n",
+    "        # the uncorrupted batch survives as the reconstruction target\n",
+    "        corrupted = imgs.clone()\n",
+    "        corrupted[:, self.simulated_corrupt_wavelength, :, :] = 0\n",
+    "        # continue as normal, encoding the corrupted images but scoring\n",
+    "        # the prediction against the originals\n",
+    "        latent, mask, ids_restore = self.backbone.autoencoder.forward_encoder(corrupted, mask_ratio)\n",
+    "        pred = self.backbone.autoencoder.forward_decoder(latent, ids_restore)\n",
+    "        loss = self.backbone.autoencoder.forward_loss(imgs, pred, mask)\n",
+    "        return loss, pred, mask\n",
+    "\n",
+    "    def training_step(self, batch, batch_idx):\n",
+    "        # training_step defines the train loop.\n",
+    "        x = batch\n",
+    "        _, x_hat, mask = self.forward_corrupt_data_override(x, mask_ratio=self.masking_ratio)\n",
+    "        x_hat = self.backbone.autoencoder.unpatchify(x_hat)\n",
+    "        # optimise a full-image MSE rather than the MAE's masked-patch loss\n",
+    "        loss = F.mse_loss(x_hat, x)\n",
+    "        self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True)\n",
+    "        return loss\n",
+    "\n",
+    "    def validation_step(self, batch, batch_idx):\n",
+    "        x = batch\n",
+    "        # validate with the same simulated corruption as training, so\n",
+    "        # val_loss measures recovery of the zeroed channel\n",
+    "        _, x_hat, mask = self.forward_corrupt_data_override(x, mask_ratio=self.masking_ratio)\n",
+    "        x_hat = self.backbone.autoencoder.unpatchify(x_hat)\n",
+    "        loss = F.mse_loss(x_hat, x)\n",
+    "        for i in range(x.shape[0]):\n",
+    "            for frame in range(x.shape[2]):\n",
+    "                self.validation_metrics.append(\n",
+    "                    bench_recon.get_metrics(\n",
+    "                        x[i, :, frame, :, :], x_hat[i, :, frame, :, :], ALL_WAVELENGTHS\n",
+    "                    )\n",
+    "                )\n",
+    "\n",
+    "        self.log(\"val_loss\", loss)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using Data Class\n",
+      "[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.\n",
+      "[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.\n",
+      "[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.\n",
+      "Loading checkpoint...\n",
+      "Done\n"
+     ]
+    }
+   ],
+   "source": [
+    "from pretrain import Pretrainer\n",
+    "MAE = Pretrainer(cfg, logger=None, is_backbone=True)"
+   ]
+  },
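+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before wiring the pretrained MAE into the finetuning module, it's worth checking that the loaded wrapper exposes the attributes `MissingDataModel` relies on (`autoencoder.blocks` for freezing, plus `forward_encoder`, `forward_decoder` and `unpatchify`). The cell below is a minimal sketch of such a check; it assumes nothing beyond the attribute names already used in this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity-check the loaded backbone before finetuning. The attribute\n",
+    "# names below are exactly the ones MissingDataModel uses; if your\n",
+    "# checkpoint wrapper differs, adjust accordingly.\n",
+    "ae = MAE.model.autoencoder\n",
+    "for attr in [\"blocks\", \"forward_encoder\", \"forward_decoder\", \"unpatchify\"]:\n",
+    "    assert hasattr(ae, attr), f\"backbone is missing '{attr}'\"\n",
+    "\n",
+    "n_total = sum(p.numel() for p in MAE.model.parameters())\n",
+    "n_encoder = sum(p.numel() for p in ae.blocks.parameters())\n",
+    "print(f\"backbone params: {n_total / 1e6:.1f}M (encoder blocks: {n_encoder / 1e6:.1f}M)\")"
+   ]
+  },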
"execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "backbone = MAE.model" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "backbone_params = {}\n", + "backbone_params[\"img_size\"] = cfg.model.mae.img_size\n", + "backbone_params[\"patch_size\"] = cfg.model.mae.patch_size\n", + "backbone_params[\"embed_dim\"] = cfg.model.mae.embed_dim\n", + "backbone_params[\"num_frames\"] = cfg.model.mae.num_frames\n", + "\n", + "model = MissingDataModel(\n", + " # backbone\n", + " **backbone_params,\n", + " # backbone\n", + " backbone=backbone,\n", + " hyperparam_ignore=[\"backbone\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", + "\n", + " | Name | Type | Params | Mode \n", + "------------------------------------------\n", + "0 | backbone | MAE | 104 M | train\n", + "------------------------------------------\n", + "27.8 M Trainable params\n", + "76.7 M Non-trainable params\n", + "104 M Total params\n", + "418.215 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "462d1016120945a5b14c8a94c8b60a75", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00