From a63a20915244febf231f015bffb8965065d5a5a7 Mon Sep 17 00:00:00 2001 From: cclauss Date: Sun, 21 Jan 2018 23:36:31 +0100 Subject: [PATCH] In __init__() save self.num_data for use in __len__() --- trainval_net.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/trainval_net.py b/trainval_net.py index ec18e05c..8b83a9f2 100644 --- a/trainval_net.py +++ b/trainval_net.py @@ -123,14 +123,15 @@ def parse_args(): class sampler(Sampler): def __init__(self, train_size, batch_size): - num_data = train_size - self.num_per_batch = int(num_data / batch_size) + self.num_data = train_size + self.num_per_batch = int(train_size / batch_size) self.batch_size = batch_size self.range = torch.arange(0,batch_size).view(1, batch_size).long() self.leftover_flag = False - if num_data % batch_size: - self.leftover = torch.arange(self.num_per_batch*batch_size, num_data).long() + if train_size % batch_size: + self.leftover = torch.arange(self.num_per_batch*batch_size, train_size).long() self.leftover_flag = True + def __iter__(self): rand_num = torch.randperm(self.num_per_batch).view(-1,1) * self.batch_size self.rand_num = rand_num.expand(self.num_per_batch, self.batch_size) + self.range @@ -143,7 +144,7 @@ def __iter__(self): return iter(self.rand_num_view) def __len__(self): - return num_data + return self.num_data if __name__ == '__main__':