Moving MNIST forecasting

A little experiment using Convolutional RNNs to forecast moving MNIST digits.

from fastai.vision.all import *
from moving_mnist.models.conv_rnn import *
from moving_mnist.data import *
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print(torch.cuda.get_device_name())
Quadro RTX 8000

Install

The only dependency is fastai (version 2). See the installation instructions at https://github.com/fastai/fastai2

Example:

We will use the following settings:

  • n_in: 5 input images
  • n_out: 5 images to predict
  • n_obj: up to 3 moving digits per sequence (n_obj=[1,2,3])
ds = MovingMNIST(DATA_PATH, n_in=5, n_out=5, n_obj=[1,2,3])
train_tl = TfmdLists(range(500), ImageTupleTransform(ds))
valid_tl = TfmdLists(range(100), ImageTupleTransform(ds))
dls = DataLoaders.from_dsets(train_tl, valid_tl, bs=8,
                             after_batch=[Normalize.from_stats(*mnist_stats)]).cuda()
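How MovingMNIST builds its sequences is not shown in this README. The sketch below is a common way of generating Moving MNIST-style data, under stated assumptions: each digit is pasted on a 64x64 canvas, moves with a constant random velocity and bounces off the borders, and the first n_in frames become the input while the next n_out frames become the target. The helpers `bouncing_sequence` and `split_in_out` are hypothetical, and random 28x28 noise patches stand in for real MNIST digits.

```python
import numpy as np

def bouncing_sequence(digits, n_frames, size=64, rng=None):
    """Hypothetical sketch: move each square uint8 `digit` across a size x size
    canvas with constant velocity, reflecting it off the borders."""
    rng = np.random.default_rng() if rng is None else rng
    frames = np.zeros((n_frames, size, size), dtype=np.uint8)
    for digit in digits:
        d = digit.shape[0]
        lim = size - d                     # largest valid top-left coordinate
        pos = rng.uniform(0, lim, size=2)  # (row, col) start position
        vel = rng.uniform(-3, 3, size=2)   # pixels per frame
        for t in range(n_frames):
            r, c = pos.astype(int)
            # overlapping digits are combined with a pixel-wise max
            frames[t, r:r + d, c:c + d] = np.maximum(frames[t, r:r + d, c:c + d], digit)
            pos += vel
            for k in range(2):             # bounce off the walls
                if pos[k] < 0 or pos[k] > lim:
                    vel[k] = -vel[k]
                    pos[k] = np.clip(pos[k], 0, lim)
    return frames

def split_in_out(frames, n_in=5, n_out=5):
    "First n_in frames are the input, the following n_out frames are the target."
    return frames[:n_in], frames[n_in:n_in + n_out]

rng = np.random.default_rng(0)
# 1, 2 or 3 noise "digits", mirroring n_obj=[1,2,3] (assumed to be sampled)
digits = [rng.integers(0, 256, size=(28, 28)).astype(np.uint8)
          for _ in range(rng.choice([1, 2, 3]))]
seq = bouncing_sequence(digits, n_frames=10, rng=rng)
x, y = split_in_out(seq)
print(x.shape, y.shape)  # (5, 64, 64) (5, 64, 64)
```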

Left: Input, Right: Target

dls.show_batch()


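StackUnstack takes care of stacking the list of input images into one fat tensor before the model runs, and of unstacking the output back into a list at the end. Because the model therefore returns a list of tensors, we will need to modify our loss function to take a list of tensors as input and target.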

model = StackUnstack(SimpleModel())
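The stack/unstack idea can be sketched with NumPy. This is not the repo's StackUnstack: the helper names are hypothetical, and stacking along axis 1 (a time axis after the batch axis) is an assumption, since the README does not say which axis the real module uses.

```python
import numpy as np

def stack_frames(frames):
    "List of n arrays shaped (bs, c, h, w) -> one array shaped (bs, n, c, h, w)."
    return np.stack(frames, axis=1)

def unstack_frames(x):
    "Inverse of stack_frames: (bs, n, c, h, w) -> list of n arrays shaped (bs, c, h, w)."
    return [x[:, i] for i in range(x.shape[1])]

def stack_unstack(model_fn):
    "Wrap a function of one stacked array so it takes and returns lists of frames."
    def wrapped(frames):
        return unstack_frames(model_fn(stack_frames(frames)))
    return wrapped

rng = np.random.default_rng(0)
frames = [rng.random((8, 1, 64, 64)) for _ in range(5)]  # bs=8, 5 frames
fat = stack_frames(frames)
print(fat.shape)  # (8, 5, 1, 64, 64)
out = stack_unstack(lambda x: 2 * x)(frames)  # stand-in for a real model
print(len(out), out[0].shape)  # 5 (8, 1, 64, 64)
```

Wrapping the model this way lets the network itself work on a single 5D tensor while the data pipeline and the plotting code keep dealing with lists of frames.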

Since an ImageSeq is a tuple of images, we need to stack predictions and targets before computing the loss; StackLoss wraps MSELossFlat for this.
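Outside fastai, the idea behind such a loss wrapper can be sketched as below. `stack_loss` and `mse` are hypothetical stand-ins for StackLoss and MSELossFlat, and stacking along axis 1 is an assumption. With equally sized frames, the MSE of the stacked tensors equals the average of the per-frame MSEs.

```python
import numpy as np

def mse(pred, targ):
    "Mean squared error over all elements (like a flattened MSE)."
    return float(np.mean((pred - targ) ** 2))

def stack_loss(loss_fn, axis=1):
    "Wrap loss_fn so both prediction and target can be lists of per-frame arrays."
    def wrapped(preds, targs):
        return loss_fn(np.stack(preds, axis=axis), np.stack(targs, axis=axis))
    return wrapped

rng = np.random.default_rng(0)
preds = [rng.random((8, 1, 64, 64)) for _ in range(5)]
targs = [np.zeros((8, 1, 64, 64)) for _ in range(5)]
loss = stack_loss(mse)(preds, targs)
# Same frame sizes, so this matches the mean of the per-frame losses
per_frame = [mse(p, t) for p, t in zip(preds, targs)]
print(loss, sum(per_frame) / len(per_frame))
```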

loss_func = StackLoss(MSELossFlat())
learn = Learner(dls, model, loss_func=loss_func, cbs=[])
learn.lr_find()
SuggestedLRs(lr_min=0.005754399299621582, lr_steep=3.0199516913853586e-05)
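The two suggestions come from the learning-rate range test. As far as we know, the fastai 2 version that produced this output computed them roughly as follows: lr_min is one tenth of the learning rate at the lowest recorded loss, and lr_steep is the learning rate where the loss falls fastest with respect to log(lr). The sketch below is a simplified version under that assumption; the real implementation also trims and smooths the recorded losses. `suggest_lrs` is a hypothetical name and the loss curve is made up.

```python
import math

def suggest_lrs(lrs, losses):
    "Return (lr_min, lr_steep) from an LR range test (simplified sketch)."
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10
    # slope of the loss curve w.r.t. log(lr) between consecutive points
    grads = [(losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
             for i in range(len(lrs) - 1)]
    i_steep = min(range(len(grads)), key=grads.__getitem__)
    return lr_min, lrs[i_steep]

lrs = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [1.0, 0.95, 0.7, 0.5, 0.45, 3.0]  # made-up curve
lr_min, lr_steep = suggest_lrs(lrs, losses)
print(lr_min, lr_steep)
```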


learn.fit_one_cycle(4, 1e-4)
epoch train_loss valid_loss time
0 0.915238 0.619522 00:12
1 0.669368 0.608123 00:12
2 0.570026 0.559723 00:12
3 0.528593 0.532774 00:12
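fit_one_cycle follows the one-cycle policy: the learning rate warms up to the given maximum (here 1e-4) and then anneals to a much smaller value. The sketch below reproduces that schedule, assuming the fastai 2 defaults as we understand them (div=25, div_final=1e5, pct_start=0.25, cosine curves in both phases). It leaves out the momentum schedule that fastai runs in the opposite direction, and `one_cycle_lr` is a hypothetical helper, not fastai's API.

```python
import math

def sched_cos(start, end, pos):
    "Cosine interpolation from start (pos=0) to end (pos=1)."
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1]: cosine warm-up from
    lr_max/div to lr_max, then cosine annealing down to lr_max/div_final."""
    if pos < pct_start:
        return sched_cos(lr_max / div, lr_max, pos / pct_start)
    return sched_cos(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

for pos in (0.0, 0.125, 0.25, 0.5, 0.75, 1.0):
    print(f"progress {pos:.3f}: lr = {one_cycle_lr(pos, 1e-4):.2e}")
```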
p,t = learn.get_preds()

As you can see, the result is a list of 5 tensors (one per predicted frame), each holding all 100 validation samples.

len(p), p[0].shape
(5, torch.Size([100, 1, 64, 64]))
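Because each list entry corresponds to one forecast step, the error can also be broken down by how far ahead the frame is. The sketch below computes one MSE per predicted frame; `per_frame_mse` is a hypothetical helper, and it uses synthetic arrays. With the real outputs, one could pass `[x.numpy() for x in p]` and `[x.numpy() for x in t]`, assuming the tensors returned by get_preds are on the CPU.

```python
import numpy as np

def per_frame_mse(preds, targs):
    "preds, targs: lists of arrays shaped (N, 1, 64, 64), one per forecast step."
    return [float(np.mean((p - t) ** 2)) for p, t in zip(preds, targs)]

# Synthetic stand-ins whose error grows with the forecast step
targs = [np.zeros((100, 1, 64, 64)) for _ in range(5)]
preds = [np.full((100, 1, 64, 64), 0.1 * (i + 1)) for i in range(5)]
errs = per_frame_mse(preds, targs)
for step, e in enumerate(errs, start=1):
    print(f"frame t+{step}: mse={e:.4f}")
```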
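import random

def show_res(t, idx):
    "Show sample `idx` of a list of per-frame tensors as one ImageSeq"
    im_seq = ImageSeq.create([t[i][idx] for i in range(len(t))])
    im_seq.show(figsize=(8,4))

k = random.randrange(len(t[0]))  # valid sample index in [0, 100)
show_res(t,k)  # target frames
show_res(p,k)  # predicted frames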


Training Example: