diff --git a/setup.cfg b/setup.cfg index ea3314d..9e379e9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,10 @@ install_requires = scipy >= 1.7.3 python_speech_features >= 0.6 ctc_segmentation >= 1.7.3 + sentencepiece >= 0.1.96 + soundfile >= 0.11.0 + espnet >= 202211 + espnet_model_zoo >= 0.1.7 [options.packages.find] where = src \ No newline at end of file diff --git a/src/MSA_FET/dataset.py b/src/MSA_FET/dataset.py index d933faa..1b56ba3 100644 --- a/src/MSA_FET/dataset.py +++ b/src/MSA_FET/dataset.py @@ -89,7 +89,7 @@ def extract_one(row : pd.Series) -> dict: 'regression_labels_A': label_A, 'regression_labels_V': label_V, 'regression_labels_T': label_T, - 'mode': mode + 'mode': mode, } video_path = Path(dataset_dir) / 'Raw' / video_id / (clip_id + '.mp4') # TODO: file extension should be configurable assert video_path.exists(), f"Video file {video_path} does not exist" @@ -124,6 +124,7 @@ def extract_one(row : pd.Series) -> dict: assert feature_V.shape[0] == feature_T.shape[0] res['vision'] = feature_V res['audio'] = feature_A + res['align'] = align_result return res except Exception as e: logger.error(f'An error occurred while extracting features for video {video_id} clip {clip_id}') @@ -388,6 +389,7 @@ def run_dataset( 'regression_labels_A': [], 'regression_labels_V': [], 'regression_labels_T': [], + 'align': [], "mode": [] } @@ -438,6 +440,8 @@ def run_dataset( if config.get('align'): data.pop("vision_lengths") data.pop("audio_lengths") + else: + data.pop("align") # padding features for item in ['audio', 'vision', 'text', 'text_bert']: if item in data: @@ -464,11 +468,14 @@ def run_dataset( else: final_data[mode][item] = data[item][indexes] data = final_data + data['config'] = config # convert labels to numpy array # convert to pytorch tensors if return_type == 'pt': for mode in data.keys(): + if mode == 'config': + continue for key in ['audio', 'vision', 'text', 'text_bert']: if key in data[mode]: data[mode][key] = torch.from_numpy(data[mode][key]) diff --git a/src/MSA_FET/example_configs/aligned.json b/src/MSA_FET/example_configs/aligned.json index 35e69cc..dc812d9 100644 --- a/src/MSA_FET/example_configs/aligned.json +++ b/src/MSA_FET/example_configs/aligned.json @@ -37,6 +37,6 @@ "language": "en-us", "device": "cpu", "has_transcript": true, - "model_download_dir": "/home/sharing/mhs/espnet_models" + "model_download_dir": "default" } } diff --git a/src/MSA_FET/single.py b/src/MSA_FET/single.py index 8c7f5df..ec2b496 100644 --- a/src/MSA_FET/single.py +++ b/src/MSA_FET/single.py @@ -315,6 +315,11 @@ def run_single( pass else: raise ValueError(f"Invalid return type '{return_type}'.") + # save configs + final_result['config'] = self.config + # save align results + if 'align' in self.config: + final_result['align'] = align_result # save result if out_file: self._save_result(final_result, out_file)