Update speech_recognition.py #182

Open · wants to merge 1 commit into base: main
74 changes: 54 additions & 20 deletions templates/dataset_scripts/csv/speech_recognition.py
@@ -1,9 +1,10 @@
import csv
import os
from typing import Dict, List, Tuple, Iterator, Optional

import datasets
from tqdm import tqdm

import logging

_DESCRIPTION = """\
Persian portion of the common voice 13 dataset, gathered and maintained by Hezar AI.
@@ -20,15 +21,12 @@
"""

_HOMEPAGE = "https://commonvoice.mozilla.org/en/datasets"

_LICENSE = "https://creativecommons.org/publicdomain/zero/1.0/"

_BASE_URL = "https://huggingface.co/datasets/hezarai/common-voice-13-fa/resolve/main/"

_AUDIO_URL = _BASE_URL + "audio/{split}.zip"

_TRANSCRIPT_URL = _BASE_URL + "transcripts/{split}.tsv"

logger = logging.getLogger(__name__)

class CommonVoiceFaConfig(datasets.BuilderConfig):
"""BuilderConfig for CommonVoice."""
@@ -38,6 +36,8 @@ def __init__(self, **kwargs):


class CommonVoice(datasets.GeneratorBasedBuilder):
"""Dataset loader for the Persian Common Voice dataset."""

DEFAULT_WRITER_BATCH_SIZE = 1000

BUILDER_CONFIGS = [
@@ -48,7 +48,8 @@ class CommonVoice(datasets.GeneratorBasedBuilder):
)
]

def _info(self):
def _info(self) -> datasets.DatasetInfo:
"""Returns the dataset metadata."""
features = datasets.Features(
{
"client_id": datasets.Value("string"),
@@ -76,7 +77,8 @@ def _info(self):
version=self.config.version,
)

def _split_generators(self, dl_manager):
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
"""Returns SplitGenerators."""
splits = ("train", "dev", "test")
audio_urls = {split: _AUDIO_URL.format(split=split) for split in splits}

@@ -106,12 +108,35 @@ def _split_generators(self, dl_manager):

return split_generators

def _generate_examples(self, local_extracted_archive_paths, archives, transcript_path):
def _generate_examples(
self,
local_extracted_archive_paths: Optional[List[str]],
archives: List[Iterator[Tuple[str, bytes]]],
transcript_path: str,
) -> Iterator[Tuple[str, Dict]]:
"""Yields examples."""
data_fields = list(self._info().features.keys())
metadata = self._load_metadata(transcript_path, data_fields)

for i, audio_archive in enumerate(archives):
for path, file in tqdm(audio_archive, desc=f"Processing audio files (archive {i+1})"):
_, filename = os.path.split(path)
if filename in metadata:
result = dict(metadata[filename])
# set the audio feature and the path to the extracted file
path = os.path.join(local_extracted_archive_paths[i],
path) if local_extracted_archive_paths else path
result["audio"] = {"path": path, "bytes": file.read()}
result["path"] = path
yield path, result

@staticmethod
def _load_metadata(transcript_path: str, data_fields: List[str]) -> Dict[str, Dict]:
"""Loads and validates metadata from the transcript file."""
metadata = {}
with open(transcript_path, encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
for row in tqdm(reader, desc="Reading metadata..."):
for row in tqdm(reader, desc="Reading metadata"):
if not row["path"].endswith(".mp3"):
row["path"] += ".mp3"
# accent -> accents in CV 8.0
@@ -122,16 +147,25 @@ def _generate_examples(self, local_extracted_archive_paths, archives, transcript
for field in data_fields:
if field not in row:
row[field] = ""

# Validate numeric fields
try:
row["up_votes"] = int(row["up_votes"])
row["down_votes"] = int(row["down_votes"])
except ValueError:
logger.warning(f"Invalid vote count for {row['path']}, skipping")
continue

metadata[row["path"]] = row

for i, audio_archive in enumerate(archives):
for path, file in audio_archive:
_, filename = os.path.split(path)
if filename in metadata:
result = dict(metadata[filename])
# set the audio feature and the path to the extracted file
path = os.path.join(local_extracted_archive_paths[i],
path) if local_extracted_archive_paths else path
result["audio"] = {"path": path, "bytes": file.read()}
result["path"] = path
yield path, result
return metadata

@staticmethod
def _get_audio_format(file_path: str) -> str:
"""Determines the audio format based on the file extension."""
_, ext = os.path.splitext(file_path)
return ext.lower()[1:] # Remove the leading dot

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
datasets.load_dataset(__file__)
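
A quick way to sanity-check the updated loader is to point datasets.load_dataset at the script and stream a few examples. This is only a usage sketch, not part of the PR: it assumes the script is saved locally as speech_recognition.py, that it runs under a datasets version that still supports script-based loaders, that the hezarai/common-voice-13-fa files on the Hub are reachable, and that the features include a "sentence" column as in other Common Voice loaders.

import datasets

# Load via the local script; streaming avoids extracting the full audio
# archives up front (assumed local filename: speech_recognition.py).
ds = datasets.load_dataset("speech_recognition.py", split="train", streaming=True)

# Inspect the first example to confirm transcript and audio path line up.
first = next(iter(ds))
print(first["path"])      # path of the mp3 inside the archive
print(first["sentence"])  # transcript text ("sentence" assumed from the Common Voice schema)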