Commit

Add missing data finetuning notebook

dead-water committed Jul 23, 2024
1 parent dbf21bc commit 018b1a1

Showing 3 changed files with 748 additions and 0 deletions.
328 changes: 328 additions & 0 deletions notebooks/camera_ready/corrupt_data/02 - Finetune Missing Data.ipynb
@@ -0,0 +1,328 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Missing Data Reconstruction from Pretrained Embeddings\n",
"\n",
"For this example we're going to build on `01 - Finetune Virtual EVE.ipynb` and create a simpler finetuning set up.\n",
"\n",
"![Figure 1: Architectural Diagram](assets/architecture_diags_corrupt.svg)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: SunpyUserWarning: Importing sunpy.map without its extra dependencies may result in errors.\n",
"The following packages are not installed:\n",
"['mpl-animators>=1.0.0', 'reproject>=0.9.0']\n",
"To install sunpy with these dependencies use `pip install sunpy[map]` or `pip install sunpy[all]` for all extras. \n",
"If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]\n",
"WARNING: SunpyUserWarning: Importing sunpy.visualization without its extra dependencies may result in errors.\n",
"The following packages are not installed:\n",
"['mpl-animators>=1.0.0']\n",
"To install sunpy with these dependencies use `pip install sunpy[visualization]` or `pip install sunpy[all]` for all extras. \n",
"If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]\n"
]
}
],
"source": [
"import os\n",
"import omegaconf\n",
"from sdofm.datasets import SDOMLDataModule\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"cfg = omegaconf.OmegaConf.load(\"finetune_corrupt_data.yml\")"
]
},
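{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can print the resolved MAE backbone settings from the loaded config (a minimal sketch; it assumes the `model.mae` keys used later in this notebook are present in `finetune_corrupt_data.yml`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the backbone-related config keys before building anything.\n",
"print(omegaconf.OmegaConf.to_yaml(cfg.model.mae))"
]
},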
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.\n",
"[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.\n",
"[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.\n"
]
}
],
"source": [
"data_module = SDOMLDataModule(\n",
" hmi_path=None,\n",
" aia_path=(\n",
" os.path.join(\n",
" cfg.data.sdoml.base_directory,\n",
" cfg.data.sdoml.sub_directory.aia,\n",
" )\n",
" if cfg.data.sdoml.sub_directory.aia\n",
" else None\n",
" ),\n",
" eve_path=None,\n",
" components=cfg.data.sdoml.components,\n",
" wavelengths=cfg.data.sdoml.wavelengths,\n",
" ions=cfg.data.sdoml.ions,\n",
" frequency=cfg.data.sdoml.frequency,\n",
" batch_size=cfg.model.opt.batch_size,\n",
" num_workers=cfg.data.num_workers,\n",
" val_months=cfg.data.month_splits.val,\n",
" test_months=cfg.data.month_splits.test,\n",
" holdout_months=cfg.data.month_splits.holdout,\n",
" cache_dir=os.path.join(\n",
" cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache\n",
" ),\n",
" min_date=cfg.data.min_date,\n",
" max_date=cfg.data.max_date,\n",
" num_frames=cfg.data.num_frames,\n",
" drop_frame_dim=cfg.data.drop_frame_dim,\n",
")\n",
"data_module.setup()"
]
},
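{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before defining the model, it's worth pulling a single batch to confirm the tensor layout (a minimal sketch, assuming the standard `LightningDataModule.train_dataloader()` interface; a `(batch, wavelength, frame, height, width)` shape is assumed based on how the validation step below indexes it)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Grab one training batch and check its shape.\n",
"batch = next(iter(data_module.train_dataloader()))\n",
"print(batch.shape)"
]
},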
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from sdofm.models import WrapEncoder, ConvTransformerTokensToEmbeddingNeck\n",
"from sdofm.benchmarks import reconstruction as bench_recon\n",
"import torch.nn.functional as F\n",
"from sdofm.constants import ALL_WAVELENGTHS\n",
"from sdofm import BaseModule\n",
"\n",
"class MissingDataModel(BaseModule):\n",
" def __init__(\n",
" self,\n",
" # Backbone parameters\n",
" img_size: int = 512,\n",
" patch_size: int = 16,\n",
" embed_dim: int = 128,\n",
" num_frames: int = 1,\n",
" # for finetuning\n",
" backbone: object = None,\n",
" freeze_encoder: bool = True,\n",
" # all else\n",
" *args,\n",
" **kwargs,\n",
" ):\n",
" super().__init__(*args, **kwargs)\n",
"\n",
" self.backbone = backbone\n",
"\n",
" self.masking_ratio = 0.75\n",
" self.validation_metrics = []\n",
"\n",
" if freeze_encoder:\n",
" self.backbone.autoencoder.blocks.eval()\n",
" for param in self.backbone.autoencoder.blocks.parameters():\n",
" param.requires_grad = False\n",
"\n",
" self.simulated_corrupt_wavelength = 5\n",
"\n",
" # As this is a reconstruction task, something that the MAE\n",
" # was designed to do, we don't require the neck.\n",
" \n",
" def forward_corrupt_data_override(self, imgs, mask_ratio=0.75):\n",
" # corrupt our wavelength by setting it all to 0\n",
" imgs[:,self.simulated_corrupt_wavelength,:,:] = 0\n",
" # continue as normal\n",
" latent, mask, ids_restore = self.backbone.autoencoder.forward_encoder(imgs, mask_ratio)\n",
" pred = self.backbone.autoencoder.forward_decoder(latent, ids_restore)\n",
" loss = self.backbone.autoencoder.forward_loss(imgs, pred, mask)\n",
" return loss, pred, mask\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" # training_step defines the train loop.\n",
" x = batch\n",
" loss, x_hat, mask = self.forward_corrupt_data_override(x, mask_ratio=self.masking_ratio)\n",
" x_hat = self.backbone.autoencoder.unpatchify(x_hat)\n",
" loss = F.mse_loss(x_hat, x)\n",
" self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True)\n",
" return loss\n",
" \n",
" def validation_step(self, batch, batch_idx):\n",
" x = batch\n",
" loss, x_hat, mask = self.backbone.autoencoder(x, mask_ratio=self.masking_ratio)\n",
" x_hat = self.backbone.autoencoder.unpatchify(x_hat)\n",
" loss = F.mse_loss(x_hat, x)\n",
" for i in range(x.shape[0]):\n",
" for frame in range(x.shape[2]):\n",
" self.validation_metrics.append(\n",
" bench_recon.get_metrics(\n",
" x[i, :, frame, :, :], x_hat[i, :, frame, :, :], ALL_WAVELENGTHS\n",
" )\n",
" )\n",
"\n",
" self.log(\"val_loss\", loss)"
]
},
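{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustration of the corruption step on a dummy tensor (shapes here are illustrative only): the simulated-missing wavelength is zeroed on a copy, so the original batch stays intact as the reconstruction target."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"dummy = torch.rand(2, 9, 64, 64)  # (batch, wavelength, H, W) -- illustrative shape\n",
"corrupted = dummy.clone()\n",
"corrupted[:, 5, :, :] = 0  # simulate a missing wavelength\n",
"assert dummy[:, 5].abs().sum() > 0  # the clean target is untouched\n",
"assert corrupted[:, 5].abs().sum() == 0  # the corrupted copy is zeroed"
]
},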
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using <class 'sdofm.datasets.SDOML.SDOMLDataModule'> Data Class\n",
"[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.\n",
"[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.\n",
"[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.\n",
"Loading checkpoint...\n",
"Done\n"
]
}
],
"source": [
"from pretrain import Pretrainer\n",
"MAE = Pretrainer(cfg, logger=None, is_backbone=True)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"backbone = MAE.model"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"backbone_params = {}\n",
"backbone_params[\"img_size\"] = cfg.model.mae.img_size\n",
"backbone_params[\"patch_size\"] = cfg.model.mae.patch_size\n",
"backbone_params[\"embed_dim\"] = cfg.model.mae.embed_dim\n",
"backbone_params[\"num_frames\"] = cfg.model.mae.num_frames\n",
"\n",
"model = MissingDataModel(\n",
" # backbone\n",
" **backbone_params,\n",
" # backbone\n",
" backbone=backbone,\n",
" hyperparam_ignore=[\"backbone\"],\n",
")"
]
},
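{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm that `freeze_encoder` took effect, count trainable versus frozen parameters; the trainable count should match what the Lightning `Trainer` reports below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)\n",
"print(f\"trainable: {trainable:,} | frozen: {frozen:,}\")"
]
},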
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
"\n",
" | Name | Type | Params | Mode \n",
"------------------------------------------\n",
"0 | backbone | MAE | 104 M | train\n",
"------------------------------------------\n",
"27.8 M Trainable params\n",
"76.7 M Non-trainable params\n",
"104 M Total params\n",
"418.215 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "462d1016120945a5b14c8a94c8b60a75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a9dfb00f12274d8c9310c6298906040c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n"
]
}
],
"source": [
"from lightning.pytorch import Trainer \n",
"os.environ['PJRT_DEVICE'] = 'GPU'\n",
"trainer = Trainer(max_epochs=2, precision=32)\n",
"trainer.fit(model=model, datamodule=data_module)"
]
},
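{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once training finishes (or is interrupted), the finetuned weights can be persisted with the standard Lightning API; the filename here is just a suggestion."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the finetuned model for later evaluation.\n",
"trainer.save_checkpoint(\"missing_data_finetune.ckpt\")"
]
},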
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}