-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add missing data finetuning notebook
- Loading branch information
1 parent
dbf21bc
commit 018b1a1
Showing
3 changed files
with
748 additions
and
0 deletions.
There are no files selected for viewing
328 changes: 328 additions & 0 deletions
328
notebooks/camera_ready/corrupt_data/02 - Finetune Missing Data.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,328 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Missing Data Reconstruction from Pretrained Embeddings\n", | ||
"\n", | ||
"For this example we're going to build on `01 - Finetune Virtual EVE.ipynb` and create a simpler finetuning set up.\n", | ||
"\n", | ||
"![Figure 1: Architectural Diagram](assets/architecture_diags_corrupt.svg)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"WARNING: SunpyUserWarning: Importing sunpy.map without its extra dependencies may result in errors.\n", | ||
"The following packages are not installed:\n", | ||
"['mpl-animators>=1.0.0', 'reproject>=0.9.0']\n", | ||
"To install sunpy with these dependencies use `pip install sunpy[map]` or `pip install sunpy[all]` for all extras. \n", | ||
"If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]\n", | ||
"WARNING: SunpyUserWarning: Importing sunpy.visualization without its extra dependencies may result in errors.\n", | ||
"The following packages are not installed:\n", | ||
"['mpl-animators>=1.0.0']\n", | ||
"To install sunpy with these dependencies use `pip install sunpy[visualization]` or `pip install sunpy[all]` for all extras. \n", | ||
"If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import os\n", | ||
"import omegaconf\n", | ||
"from sdofm.datasets import SDOMLDataModule\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"cfg = omegaconf.OmegaConf.load(\"finetune_corrupt_data.yml\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 29, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.\n", | ||
"[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.\n", | ||
"[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"data_module = SDOMLDataModule(\n", | ||
" hmi_path=None,\n", | ||
" aia_path=(\n", | ||
" os.path.join(\n", | ||
" cfg.data.sdoml.base_directory,\n", | ||
" cfg.data.sdoml.sub_directory.aia,\n", | ||
" )\n", | ||
" if cfg.data.sdoml.sub_directory.aia\n", | ||
" else None\n", | ||
" ),\n", | ||
" eve_path=None,\n", | ||
" components=cfg.data.sdoml.components,\n", | ||
" wavelengths=cfg.data.sdoml.wavelengths,\n", | ||
" ions=cfg.data.sdoml.ions,\n", | ||
" frequency=cfg.data.sdoml.frequency,\n", | ||
" batch_size=cfg.model.opt.batch_size,\n", | ||
" num_workers=cfg.data.num_workers,\n", | ||
" val_months=cfg.data.month_splits.val,\n", | ||
" test_months=cfg.data.month_splits.test,\n", | ||
" holdout_months=cfg.data.month_splits.holdout,\n", | ||
" cache_dir=os.path.join(\n", | ||
" cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache\n", | ||
" ),\n", | ||
" min_date=cfg.data.min_date,\n", | ||
" max_date=cfg.data.max_date,\n", | ||
" num_frames=cfg.data.num_frames,\n", | ||
" drop_frame_dim=cfg.data.drop_frame_dim,\n", | ||
")\n", | ||
"data_module.setup()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 32, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sdofm.models import WrapEncoder, ConvTransformerTokensToEmbeddingNeck\n", | ||
"from sdofm.benchmarks import reconstruction as bench_recon\n", | ||
"import torch.nn.functional as F\n", | ||
"from sdofm.constants import ALL_WAVELENGTHS\n", | ||
"from sdofm import BaseModule\n", | ||
"\n", | ||
"class MissingDataModel(BaseModule):\n", | ||
" def __init__(\n", | ||
" self,\n", | ||
" # Backbone parameters\n", | ||
" img_size: int = 512,\n", | ||
" patch_size: int = 16,\n", | ||
" embed_dim: int = 128,\n", | ||
" num_frames: int = 1,\n", | ||
" # for finetuning\n", | ||
" backbone: object = None,\n", | ||
" freeze_encoder: bool = True,\n", | ||
" # all else\n", | ||
" *args,\n", | ||
" **kwargs,\n", | ||
" ):\n", | ||
" super().__init__(*args, **kwargs)\n", | ||
"\n", | ||
" self.backbone = backbone\n", | ||
"\n", | ||
" self.masking_ratio = 0.75\n", | ||
" self.validation_metrics = []\n", | ||
"\n", | ||
" if freeze_encoder:\n", | ||
" self.backbone.autoencoder.blocks.eval()\n", | ||
" for param in self.backbone.autoencoder.blocks.parameters():\n", | ||
" param.requires_grad = False\n", | ||
"\n", | ||
" self.simulated_corrupt_wavelength = 5\n", | ||
"\n", | ||
" # As this is a reconstruction task, something that the MAE\n", | ||
" # was designed to do, we don't require the neck.\n", | ||
" \n", | ||
" def forward_corrupt_data_override(self, imgs, mask_ratio=0.75):\n", | ||
" # corrupt our wavelength by setting it all to 0\n", | ||
" imgs[:,self.simulated_corrupt_wavelength,:,:] = 0\n", | ||
" # continue as normal\n", | ||
" latent, mask, ids_restore = self.backbone.autoencoder.forward_encoder(imgs, mask_ratio)\n", | ||
" pred = self.backbone.autoencoder.forward_decoder(latent, ids_restore)\n", | ||
" loss = self.backbone.autoencoder.forward_loss(imgs, pred, mask)\n", | ||
" return loss, pred, mask\n", | ||
"\n", | ||
" def training_step(self, batch, batch_idx):\n", | ||
" # training_step defines the train loop.\n", | ||
" x = batch\n", | ||
" loss, x_hat, mask = self.forward_corrupt_data_override(x, mask_ratio=self.masking_ratio)\n", | ||
" x_hat = self.backbone.autoencoder.unpatchify(x_hat)\n", | ||
" loss = F.mse_loss(x_hat, x)\n", | ||
" self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True)\n", | ||
" return loss\n", | ||
" \n", | ||
" def validation_step(self, batch, batch_idx):\n", | ||
" x = batch\n", | ||
" loss, x_hat, mask = self.backbone.autoencoder(x, mask_ratio=self.masking_ratio)\n", | ||
" x_hat = self.backbone.autoencoder.unpatchify(x_hat)\n", | ||
" loss = F.mse_loss(x_hat, x)\n", | ||
" for i in range(x.shape[0]):\n", | ||
" for frame in range(x.shape[2]):\n", | ||
" self.validation_metrics.append(\n", | ||
" bench_recon.get_metrics(\n", | ||
" x[i, :, frame, :, :], x_hat[i, :, frame, :, :], ALL_WAVELENGTHS\n", | ||
" )\n", | ||
" )\n", | ||
"\n", | ||
" self.log(\"val_loss\", loss)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using <class 'sdofm.datasets.SDOML.SDOMLDataModule'> Data Class\n", | ||
"[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.\n", | ||
"[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.\n", | ||
"[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.\n", | ||
"Loading checkpoint...\n", | ||
"Done\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from pretrain import Pretrainer\n", | ||
"MAE = Pretrainer(cfg, logger=None, is_backbone=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 34, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"backbone = MAE.model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 35, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"backbone_params = {}\n", | ||
"backbone_params[\"img_size\"] = cfg.model.mae.img_size\n", | ||
"backbone_params[\"patch_size\"] = cfg.model.mae.patch_size\n", | ||
"backbone_params[\"embed_dim\"] = cfg.model.mae.embed_dim\n", | ||
"backbone_params[\"num_frames\"] = cfg.model.mae.num_frames\n", | ||
"\n", | ||
"model = MissingDataModel(\n", | ||
" # backbone\n", | ||
" **backbone_params,\n", | ||
" # backbone\n", | ||
" backbone=backbone,\n", | ||
" hyperparam_ignore=[\"backbone\"],\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 36, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n", | ||
"GPU available: True (cuda), used: True\n", | ||
"TPU available: False, using: 0 TPU cores\n", | ||
"HPU available: False, using: 0 HPUs\n", | ||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", | ||
"\n", | ||
" | Name | Type | Params | Mode \n", | ||
"------------------------------------------\n", | ||
"0 | backbone | MAE | 104 M | train\n", | ||
"------------------------------------------\n", | ||
"27.8 M Trainable params\n", | ||
"76.7 M Non-trainable params\n", | ||
"104 M Total params\n", | ||
"418.215 Total estimated model params size (MB)\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "462d1016120945a5b14c8a94c8b60a75", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Sanity Checking: | | 0/? [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "a9dfb00f12274d8c9310c6298906040c", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Training: | | 0/? [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from lightning.pytorch import Trainer \n", | ||
"os.environ['PJRT_DEVICE'] = 'GPU'\n", | ||
"trainer = Trainer(max_epochs=2, precision=32)\n", | ||
"trainer.fit(model=model, datamodule=data_module)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.