Skip to content

Commit

Permalink
Add cpu_cache parameter
Browse files Browse the repository at this point in the history
Fix #8.
  • Loading branch information
HolyWu committed Oct 3, 2021
1 parent 3090fc0 commit 8667169
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 59 deletions.
24 changes: 15 additions & 9 deletions vsbasicvsrpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def BasicVSRPP(clip: vs.VideoNode, model: int = 1, interval: int = 30, tile_x: int = 0, tile_y: int = 0, tile_pad: int = 16,
device_type: str = 'cuda', device_index: int = 0, fp16: bool = False) -> vs.VideoNode:
device_type: str = 'cuda', device_index: int = 0, fp16: bool = False, cpu_cache: bool = False) -> vs.VideoNode:
'''
BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment
Expand Down Expand Up @@ -46,6 +46,8 @@ def BasicVSRPP(clip: vs.VideoNode, model: int = 1, interval: int = 30, tile_x: i
device_index: Device ordinal for the device type.
fp16: fp16 mode for faster and more lightweight inference on cards with Tensor Cores.
cpu_cache: Whether to send the intermediate features to CPU. This saves GPU memory, but slows down the inference speed.
'''
if not isinstance(clip, vs.VideoNode):
raise vs.Error('BasicVSR++: this is not a clip')
Expand Down Expand Up @@ -89,16 +91,20 @@ def BasicVSRPP(clip: vs.VideoNode, model: int = 1, interval: int = 30, tile_x: i
model_name = 'basicvsr_plusplus_ntire_decompress_track3.pth'
model_path = os.path.join(os.path.dirname(__file__), model_name)

if model < 3:
config_name = 'config012.py'
scale = 4
else:
config_name = 'config345.py'
scale = 1
spynet_path = os.path.join(os.path.dirname(__file__), 'spynet.pth')

cfg = mmcv.Config(dict(type='BasicVSR',
generator=dict(type='BasicVSRPlusPlus',
device=device,
mid_channels=64 if model < 3 else 128,
num_blocks=7 if model < 3 else 25,
is_low_res_input=True if model < 3 else False,
spynet_pretrained=spynet_path,
cpu_cache=cpu_cache if device_type == 'cuda' else False)))

config = mmcv.Config.fromfile(os.path.join(os.path.dirname(__file__), config_name))
scale = 4 if model < 3 else 1

model = build_model(config.model)
model = build_model(cfg._cfg_dict)
mmcv.runner.load_checkpoint(model, model_path, strict=True)
model.to(device)
model.eval()
Expand Down
7 changes: 2 additions & 5 deletions vsbasicvsrpp/basicvsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ def forward(self, lq):
lq (Tensor): LQ Tensor with shape (n, t, c, h, w).
Returns:
dict: Output results.
Tensor: Output result.
"""
with torch.no_grad():
output = self.generator(lq)

return output
return self.generator(lq)
32 changes: 15 additions & 17 deletions vsbasicvsrpp/basicvsr_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BasicVSRPlusPlus(nn.Module):
and Alignment
Args:
device (torch.device): The destination GPU device.
mid_channels (int, optional): Channel number of the intermediate
features. Default: 64.
num_blocks (int, optional): The number of residual blocks in each
Expand All @@ -37,25 +38,25 @@ class BasicVSRPlusPlus(nn.Module):
resolution. Default: True.
spynet_pretrained (str, optional): Pre-trained model path of SPyNet.
Default: None.
cpu_cache_length (int, optional): When the length of sequence is larger
than this value, the intermediate features are sent to CPU. This
saves GPU memory, but slows down the inference speed. You can
increase this number if you have a GPU with large memory.
Default: 100.
cpu_cache (bool, optional): Whether to send the intermediate features
to CPU. This saves GPU memory, but slows down the inference speed.
Default: False.
"""

def __init__(self,
device,
mid_channels=64,
num_blocks=7,
max_residue_magnitude=10,
is_low_res_input=True,
spynet_pretrained=None,
cpu_cache_length=100):
cpu_cache=False):

super().__init__()
self.device = device
self.mid_channels = mid_channels
self.is_low_res_input = is_low_res_input
self.cpu_cache_length = cpu_cache_length
self.cpu_cache = cpu_cache

# optical flow
self.spynet = SPyNet(pretrained=spynet_pretrained)
Expand Down Expand Up @@ -172,13 +173,13 @@ def propagate(self, feats, flows, module_name):
for i, idx in enumerate(frame_idx):
feat_current = feats['spatial'][mapping_idx[idx]]
if self.cpu_cache:
feat_current = feat_current.cuda()
feat_prop = feat_prop.cuda()
feat_current = feat_current.cuda(self.device)
feat_prop = feat_prop.cuda(self.device)
# second-order deformable alignment
if i > 0 and self.is_with_alignment:
flow_n1 = flows[:, flow_idx[i], :, :, :]
if self.cpu_cache:
flow_n1 = flow_n1.cuda()
flow_n1 = flow_n1.cuda(self.device)

cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))

Expand All @@ -190,11 +191,11 @@ def propagate(self, feats, flows, module_name):
if i > 1: # second-order features
feat_n2 = feats[module_name][-2]
if self.cpu_cache:
feat_n2 = feat_n2.cuda()
feat_n2 = feat_n2.cuda(self.device)

flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
if self.cpu_cache:
flow_n2 = flow_n2.cuda()
flow_n2 = flow_n2.cuda(self.device)

flow_n2 = flow_n1 + flow_warp(flow_n2,
flow_n1.permute(0, 2, 3, 1))
Expand All @@ -212,7 +213,7 @@ def propagate(self, feats, flows, module_name):
for k in feats if k not in ['spatial', module_name]
] + [feat_prop]
if self.cpu_cache:
feat = [f.cuda() for f in feat]
feat = [f.cuda(self.device) for f in feat]

feat = torch.cat(feat, dim=1)
feat_prop = feat_prop + self.backbone[module_name](feat)
Expand Down Expand Up @@ -251,7 +252,7 @@ def upsample(self, lqs, feats):
hr.insert(0, feats['spatial'][mapping_idx[i]])
hr = torch.cat(hr, dim=1)
if self.cpu_cache:
hr = hr.cuda()
hr = hr.cuda(self.device)

hr = self.reconstruction(hr)
hr = self.lrelu(self.upsample1(hr))
Expand Down Expand Up @@ -284,9 +285,6 @@ def forward(self, lqs):

n, t, c, h, w = lqs.size()

# whether to cache the features in CPU
self.cpu_cache = False # True if t > self.cpu_cache_length else False

if self.is_low_res_input:
lqs_downsample = lqs.clone()
else:
Expand Down
14 changes: 0 additions & 14 deletions vsbasicvsrpp/config012.py

This file was deleted.

14 changes: 0 additions & 14 deletions vsbasicvsrpp/config345.py

This file was deleted.

0 comments on commit 8667169

Please sign in to comment.