Skip to content

Commit

Permalink
Update mypy and fix lint. (#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
rom1504 authored Oct 13, 2023
1 parent 351decc commit 0bf00c5
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 10 deletions.
4 changes: 3 additions & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
black==22.3.0
mypy==0.942
mypy==1.6.0
types-requests==2.31.0.8
types-PyYAML==6.0.12.12
pylint==2.13.4
pytest-cov==3.0.0
pytest-xdist==2.5.0
Expand Down
4 changes: 2 additions & 2 deletions video2dataset/dataloader/custom_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,10 @@ class TorchDataWebdataset(DataPipeline, FluidInterfaceWithChangedDecode):
def __init__(
self,
urls: Union[List[str], str],
repeat: int = None,
repeat: Optional[int] = None,
shardshuffle: int = 10000,
sample_shuffle: int = 0,
buffer_size: int = None,
buffer_size: Optional[int] = None,
resample_prefixes: bool = False,
prefix_probs: Optional[List[float]] = None,
drop_last: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion video2dataset/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def _make_sbatch(self):
def _make_launch_cpu(
self,
):
"""Create cpu launcher"""

venv = os.environ.get("VIRTUAL_ENV")
if venv:
Expand All @@ -236,7 +237,7 @@ def _make_launch_cpu(
venv_activate = f"conda activate {conda_env}"
else:
raise ValueError("You need to specify either a virtual environment or a conda environment.")

cdir = os.path.abspath(os.path.dirname(__file__))
script = os.path.join(cdir, "slurm_executor.py")
project_root = os.path.abspath(os.path.join(cdir, ".."))
Expand Down
2 changes: 1 addition & 1 deletion video2dataset/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def video2dataset(
output_folder: str = "dataset",
output_format: str = "files",
input_format: str = "csv",
encode_formats: dict = None,
encode_formats: Optional[dict] = None,
stage: str = "download",
url_col: str = "url",
caption_col: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion video2dataset/subsamplers/audio_rate_subsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, sample_rate, encode_formats, n_audio_channels=None):

def __call__(self, streams, metadata=None):
audio_bytes = streams.pop("audio")
subsampled_bytes, subsampled_metas = [], []
subsampled_bytes = []
for aud_bytes in audio_bytes:
with tempfile.TemporaryDirectory() as tmpdir:
with open(os.path.join(tmpdir, "input.m4a"), "wb") as f:
Expand Down
10 changes: 6 additions & 4 deletions video2dataset/subsamplers/caption_subsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from torch import nn
from einops import rearrange
from typing import Optional

from transformers import (
AutoModelForCausalLM,
Expand Down Expand Up @@ -41,11 +42,12 @@ class VideoBlipVisionModel(Blip2VisionModel):

def forward(
self,
pixel_values: torch.FloatTensor = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
return_dict: bool = None,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""Forward method for vision blip model"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
batch, _, time, _, _ = pixel_values.size()
Expand Down

0 comments on commit 0bf00c5

Please sign in to comment.