Skip to content

Commit

Permalink
Save configs and align results in pickle file
Browse files Browse the repository at this point in the history
  • Loading branch information
FlameSky-S committed Feb 11, 2023
1 parent 1548bda commit f8fbd2d
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ install_requires =
scipy >= 1.7.3
python_speech_features >= 0.6
ctc_segmentation >= 1.7.3
sentencepiece >= 0.1.96
soundfile >= 0.11.0
espnet >= 202211
espnet_model_zoo >= 0.1.7

[options.packages.find]
where = src
9 changes: 8 additions & 1 deletion src/MSA_FET/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def extract_one(row : pd.Series) -> dict:
'regression_labels_A': label_A,
'regression_labels_V': label_V,
'regression_labels_T': label_T,
'mode': mode
'mode': mode,
}
video_path = Path(dataset_dir) / 'Raw' / video_id / (clip_id + '.mp4') # TODO: file extension should be configurable
assert video_path.exists(), f"Video file {video_path} does not exist"
Expand Down Expand Up @@ -124,6 +124,7 @@ def extract_one(row : pd.Series) -> dict:
assert feature_V.shape[0] == feature_T.shape[0]
res['vision'] = feature_V
res['audio'] = feature_A
res['align'] = align_result
return res
except Exception as e:
logger.error(f'An error occurred while extracting features for video {video_id} clip {clip_id}')
Expand Down Expand Up @@ -388,6 +389,7 @@ def run_dataset(
'regression_labels_A': [],
'regression_labels_V': [],
'regression_labels_T': [],
'align': [],
"mode": []
}

Expand Down Expand Up @@ -438,6 +440,8 @@ def run_dataset(
if config.get('align'):
data.pop("vision_lengths")
data.pop("audio_lengths")
else:
data.pop("align")
# padding features
for item in ['audio', 'vision', 'text', 'text_bert']:
if item in data:
Expand All @@ -464,11 +468,14 @@ def run_dataset(
else:
final_data[mode][item] = data[item][indexes]
data = final_data
data['config'] = config
# convert labels to numpy array

# convert to pytorch tensors
if return_type == 'pt':
for mode in data.keys():
if mode == 'config':
continue
for key in ['audio', 'vision', 'text', 'text_bert']:
if key in data[mode]:
data[mode][key] = torch.from_numpy(data[mode][key])
Expand Down
2 changes: 1 addition & 1 deletion src/MSA_FET/example_configs/aligned.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@
"language": "en-us",
"device": "cpu",
"has_transcript": true,
"model_download_dir": "/home/sharing/mhs/espnet_models"
"model_download_dir": "default"
}
}
5 changes: 5 additions & 0 deletions src/MSA_FET/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ def run_single(
pass
else:
raise ValueError(f"Invalid return type '{return_type}'.")
# save configs
final_result['config'] = self.config
# save align results
if 'align' in self.config:
final_result['align'] = align_result
# save result
if out_file:
self._save_result(final_result, out_file)
Expand Down

0 comments on commit f8fbd2d

Please sign in to comment.