API Reference

speech_to_text_finetune.config

Config

Bases: BaseModel

Store configuration used for finetuning

Attributes:

model_id (str): HF model id of a Whisper model used for finetuning
dataset_id (str): HF dataset id of a Common Voice dataset version, ideally from the mozilla-foundation repo
language (str): registered language string that is supported by the Common Voice dataset
repo_name (str): used both for the local dir and HF; "default" will create a name based on the model and language id
n_train_samples (int): explicitly set how many samples to train+validate on. If -1, use all train+val data available
n_test_samples (int): explicitly set how many samples to evaluate on. If -1, use all eval data available
training_hp (TrainingConfig): store selected hyperparameter values from Seq2SeqTrainingArguments

Source code in src/speech_to_text_finetune/config.py
class Config(BaseModel):
    """
    Store configuration used for finetuning

    Attributes:
        model_id: HF model id of a Whisper model used for finetuning
        dataset_id: HF dataset id of a Common Voice dataset version, ideally from the mozilla-foundation repo
        language: registered language string that is supported by the Common Voice dataset
        repo_name: used both for local dir and HF, "default" will create a name based on the model and language id
        n_train_samples: explicitly set how many samples to train+validate on. If -1, use all train+val data available
        n_test_samples: explicitly set how many samples to evaluate on. If -1, use all eval data available
        training_hp: store selected hyperparameter values from Seq2SeqTrainingArguments
    """

    model_id: str
    dataset_id: str
    language: str
    repo_name: str
    n_train_samples: int
    n_test_samples: int
    training_hp: TrainingConfig

TrainingConfig

Bases: BaseModel

More info at https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments

Source code in src/speech_to_text_finetune/config.py
class TrainingConfig(BaseModel):
    """
    More info at https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
    """

    push_to_hub: bool
    hub_private_repo: bool
    max_steps: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    warmup_steps: int
    gradient_checkpointing: bool
    fp16: bool
    eval_strategy: str
    per_device_eval_batch_size: int
    predict_with_generate: bool
    generation_max_length: int
    save_steps: int
    logging_steps: int
    load_best_model_at_end: bool
    save_total_limit: int
    metric_for_best_model: str
    greater_is_better: bool
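
For illustration, a minimal sketch of building these two configs programmatically; all values below are hypothetical, and in practice they are typically loaded from a config.yaml via load_config:

from speech_to_text_finetune.config import Config, TrainingConfig

# Hypothetical example values; adjust model, dataset and language to your use case
training_hp = TrainingConfig(
    push_to_hub=False,
    hub_private_repo=True,
    max_steps=500,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=50,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    logging_steps=25,
    load_best_model_at_end=True,
    save_total_limit=1,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
)

config = Config(
    model_id="openai/whisper-tiny",
    dataset_id="mozilla-foundation/common_voice_17_0",
    language="Swedish",
    repo_name="default",  # derives a name from the model and language id
    n_train_samples=-1,  # use all available train+val data
    n_test_samples=-1,  # use all available eval data
    training_hp=training_hp,
)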

speech_to_text_finetune.data_process

DataCollatorSpeechSeq2SeqWithPadding dataclass

Data collator class in the format expected by Seq2SeqTrainer, used for processing input data and labels in batches while finetuning. More info here:

Source code in src/speech_to_text_finetune/data_process.py
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data Collator class in the format expected by Seq2SeqTrainer used for processing
    input data and labels in batches while finetuning. More info here:
    """

    processor: WhisperProcessor

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt"
        )

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
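
A minimal usage sketch of the collator, assuming a WhisperProcessor; the feature shapes and sentences are illustrative:

import torch
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Two dummy features: log-mel inputs of shape (80, 3000) and tokenized labels
features = [
    {"input_features": torch.zeros(80, 3000), "labels": processor.tokenizer("hello").input_ids},
    {"input_features": torch.zeros(80, 3000), "labels": processor.tokenizer("hello there").input_ids},
]
batch = collator(features)
# batch["input_features"] has shape (2, 80, 3000); batch["labels"] is padded,
# with padding positions set to -100 so they are ignored by the loss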

load_and_proc_hf_fleurs(language_id, n_test_samples, processor, eval_batch_size)

Load only the test split of Fleurs for a specific language and process it for Whisper.

Parameters:

language_id (str): a registered language identifier from Fleurs (see https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py)
n_test_samples (int): number of samples to use from the test split
processor (WhisperProcessor): Whisper processor used to process the dataset
eval_batch_size (int): batch size to use for processing the dataset

Returns:

Dataset: the processed test split as a HF Dataset

Source code in src/speech_to_text_finetune/data_process.py
def load_and_proc_hf_fleurs(
    language_id: str,
    n_test_samples: int,
    processor: WhisperProcessor,
    eval_batch_size: int,
) -> Dataset:
    """
    Load only the test split of Fleurs for a specific language and process it for Whisper.
    Args:
        language_id (str): a registered language identifier from Fleurs
            (see https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py)
        n_test_samples (int): number of samples to use from the test split
        processor (WhisperProcessor): Processor from Whisper to process the dataset
        eval_batch_size (int): batch size to use for processing the dataset

    Returns:
        Dataset: the processed test split as a HF Dataset
    """
    fleurs_dataset_id = "google/fleurs"
    if proc_dataset := try_find_processed_version(fleurs_dataset_id, language_id):
        return proc_dataset

    dataset = load_dataset(
        fleurs_dataset_id, language_id, trust_remote_code=True, split="test"
    )
    dataset = load_subset_of_dataset(dataset, n_test_samples)

    dataset = dataset.rename_column(
        original_column_name="raw_transcription", new_column_name="sentence"
    )
    dataset = dataset.select_columns(["audio", "sentence"])

    save_proc_dataset_path = _get_hf_proc_dataset_path(fleurs_dataset_id, language_id)
    logger.info("Processing dataset...")
    dataset = process_dataset(
        dataset=dataset,
        processor=processor,
        batch_size=eval_batch_size,
        proc_dataset_path=save_proc_dataset_path,
    )
    logger.info(
        f"Processed dataset saved at {save_proc_dataset_path}. Future runs of {fleurs_dataset_id} will "
        f"automatically use this processed version."
    )
    return dataset
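
A usage sketch; the Fleurs identifier "sv_se" is an example, see the fleurs.py link above for valid ids:

from transformers import WhisperProcessor
from speech_to_text_finetune.data_process import load_and_proc_hf_fleurs

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-tiny", language="swedish", task="transcribe"
)
test_dataset = load_and_proc_hf_fleurs(
    language_id="sv_se",  # example Fleurs config name
    n_test_samples=100,
    processor=processor,
    eval_batch_size=8,
)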

load_dataset_from_dataset_id(dataset_id, language_id=None)

This function loads a dataset, based on the dataset_id and the content of its directory (if it is a local path). Possible cases:

  1. The dataset_id is a path to a local, Common Voice dataset directory.

  2. The dataset_id is a path to a local, custom dataset directory.

  3. The dataset_id is a HuggingFace dataset ID.

Parameters:

dataset_id (str, required): Path to a processed dataset directory, a local dataset directory, or a HuggingFace dataset ID.
language_id (str | None, default None): Language identifier for the dataset (e.g., 'en' for English). Only used for the HF dataset case.

Returns:

DatasetDict: A processed dataset ready for training with train/test splits
str: Path to save the processed dataset to

Raises:

ValueError: If the dataset cannot be found locally or on HuggingFace

Source code in src/speech_to_text_finetune/data_process.py
def load_dataset_from_dataset_id(
    dataset_id: str,
    language_id: str | None = None,
) -> Tuple[DatasetDict, str]:
    """
    This function loads a dataset, based on the dataset_id and the content of its directory (if it is a local path).
    Possible cases:
    1. The dataset_id is a path to a local, Common Voice dataset directory.

    2. The dataset_id is a path to a local, custom dataset directory.

    3. The dataset_id is a HuggingFace dataset ID.

    Args:
        dataset_id: Path to a processed dataset directory or local dataset directory or HuggingFace dataset ID.
        language_id (Only used for the HF dataset case): Language identifier for the dataset (e.g., 'en' for English)

    Returns:
        DatasetDict: A processed dataset ready for training with train/test splits
        str: Path to save the processed directory

    Raises:
        ValueError: If the dataset cannot be found locally or on HuggingFace
    """
    try:
        dataset = _load_local_common_voice(dataset_id)
        return dataset, _get_local_proc_dataset_path(dataset_id)
    except FileNotFoundError:
        pass

    try:
        dataset = _load_custom_dataset(dataset_id)
        return dataset, _get_local_proc_dataset_path(dataset_id)
    except FileNotFoundError:
        pass

    try:
        dataset = _load_hf_common_voice(dataset_id, language_id)
        return dataset, _get_hf_proc_dataset_path(dataset_id, language_id)
    except HFValidationError:
        pass
    except FileNotFoundError:
        pass

    raise ValueError(
        f"Could not find dataset {dataset_id}, neither locally nor at HuggingFace. "
        f"If its a private repo, make sure you are logged in locally."
    )
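
A sketch of the supported cases; the paths and ids below are examples:

from speech_to_text_finetune.data_process import load_dataset_from_dataset_id

# Cases 1 and 2: a local Common Voice or custom dataset directory
dataset, proc_path = load_dataset_from_dataset_id(dataset_id="./my_dataset")

# Case 3: a HuggingFace dataset id plus a Common Voice language identifier
dataset, proc_path = load_dataset_from_dataset_id(
    dataset_id="mozilla-foundation/common_voice_17_0", language_id="sv-SE"
)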

process_dataset(dataset, processor, batch_size, proc_dataset_path)

Process dataset to the expected format by a Whisper model and then save it locally for future use.

Source code in src/speech_to_text_finetune/data_process.py
def process_dataset(
    dataset: DatasetDict | Dataset,
    processor: WhisperProcessor,
    batch_size: int,
    proc_dataset_path: str,
) -> DatasetDict | Dataset:
    """
    Process dataset to the expected format by a Whisper model and then save it locally for future use.
    """
    # Create a new column that consists of the resampled audio samples in the right sample rate for whisper
    dataset = dataset.cast_column(
        "audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)
    )

    dataset = dataset.map(
        _process_inputs_and_labels_for_whisper,
        fn_kwargs={"processor": processor},
        remove_columns=dataset.column_names["train"]
        if "train" in dataset.column_names
        else None,
        batched=True,
        batch_size=batch_size,
        num_proc=1,
    )

    dataset = dataset.filter(
        _is_audio_in_length_range,
        input_columns=["input_length"],
        fn_kwargs={"max_input_length": 30.0},
        num_proc=1,
    )
    dataset = dataset.filter(
        _are_labels_in_length_range,
        input_columns=["labels"],
        fn_kwargs={"max_label_length": 448},
        num_proc=1,
    )

    proc_dataset_path = Path(proc_dataset_path)
    Path.mkdir(proc_dataset_path, parents=True, exist_ok=True)
    dataset.save_to_disk(proc_dataset_path)
    return dataset
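
For example, processing a freshly loaded dataset and caching it to disk; values are illustrative:

from transformers import WhisperProcessor
from speech_to_text_finetune.data_process import (
    load_dataset_from_dataset_id,
    process_dataset,
)

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-tiny", language="swedish", task="transcribe"
)
dataset, proc_path = load_dataset_from_dataset_id("./my_dataset")
dataset = process_dataset(
    dataset=dataset,
    processor=processor,
    batch_size=16,
    proc_dataset_path=proc_path,  # processed version is saved here for reuse
)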

try_find_processed_version(dataset_id, language_id=None)

Try to load a processed version of the dataset if it exists locally. Check if:

  1. The dataset_id is a local path to an already processed dataset directory, or
  2. The dataset_id is a path to a local dataset, but a processed version already exists locally, or
  3. The dataset_id is a HuggingFace dataset ID, but a processed version already exists locally.

Source code in src/speech_to_text_finetune/data_process.py
def try_find_processed_version(
    dataset_id: str, language_id: str | None = None
) -> DatasetDict | Dataset | None:
    """
    Try to load a processed version of the dataset if it exists locally. Check if:
    1. The dataset_id is a local path to an already processed dataset directory.
    or
    2. The dataset_id is a path to a local dataset, but a processed version already exists locally.
    or
    3. The dataset_id is a HuggingFace dataset ID, but a processed version already exists locally.
    """
    if Path(dataset_id).name == PROC_DATASET_DIR and Path(dataset_id).is_dir():
        if (
            Path(dataset_id + "/train").is_dir()
            and Path(dataset_id + "/test").is_dir()
            and Path(dataset_id + "/dataset_dict.json").is_file()
        ):
            return load_from_disk(dataset_id)
        else:
            raise FileNotFoundError("Processed dataset is incomplete.")

    proc_dataset_path = _get_local_proc_dataset_path(dataset_id)
    if Path(proc_dataset_path).is_dir():
        return load_from_disk(proc_dataset_path)

    hf_proc_dataset_path = _get_hf_proc_dataset_path(dataset_id, language_id)
    if Path(hf_proc_dataset_path).is_dir():
        logger.info(
            f"Found processed dataset version at {hf_proc_dataset_path} of HF dataset {dataset_id}. "
            f"Loading it directly and skipping processing again the original version."
        )
        return load_from_disk(hf_proc_dataset_path)

    return None
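
A sketch of the caching check this function enables; the path is illustrative:

from speech_to_text_finetune.data_process import try_find_processed_version

if proc_dataset := try_find_processed_version("./my_dataset"):
    dataset = proc_dataset  # reuse the locally cached, processed version
else:
    pass  # load and process the original dataset instead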

speech_to_text_finetune.finetune_whisper

run_finetuning(config_path='config.yaml')

Complete pipeline for preprocessing the Common Voice dataset and then finetuning a Whisper model on it.

Parameters:

config_path (str, default 'config.yaml'): yaml filepath that follows the format defined in config.py

Returns:

Tuple[Dict, Dict]: evaluation metrics from the baseline and the finetuned models

Source code in src/speech_to_text_finetune/finetune_whisper.py
def run_finetuning(
    config_path: str = "config.yaml",
) -> Tuple[Dict, Dict]:
    """
    Complete pipeline for preprocessing the Common Voice dataset and then finetuning a Whisper model on it.

    Args:
        config_path (str): yaml filepath that follows the format defined in config.py

    Returns:
        Tuple[Dict, Dict]: evaluation metrics from the baseline and the finetuned models
    """
    cfg = load_config(config_path)

    language_id = TO_LANGUAGE_CODE.get(cfg.language.lower())
    if not language_id:
        raise ValueError(
            f"\nThis language is not inherently supported by this Whisper model. However you can still “teach” Whisper "
            f"the language of your choice!\nVisit https://glottolog.org/, find which language is most closely "
            f"related to {cfg.language} from the list of supported languages below, and update your config file with "
            f"that language.\n{json.dumps(TO_LANGUAGE_CODE, indent=4, sort_keys=True)}."
        )

    if cfg.repo_name == "default":
        cfg.repo_name = f"{cfg.model_id.split('/')[1]}-{language_id}"
    local_output_dir = f"./artifacts/{cfg.repo_name}"

    logger.info(f"Finetuning starts soon, results saved locally at {local_output_dir}")
    hf_repo_name = ""
    if cfg.training_hp.push_to_hub:
        hf_username = get_hf_username()
        hf_repo_name = f"{hf_username}/{cfg.repo_name}"
        logger.info(
            f"Results will also be uploaded in HF at {hf_repo_name}. "
            f"Private repo is set to {cfg.training_hp.hub_private_repo}."
        )

    device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
    logger.info(
        f"Loading {cfg.model_id} on {device} and configuring it for {cfg.language}."
    )
    processor = WhisperProcessor.from_pretrained(
        cfg.model_id, language=cfg.language, task="transcribe"
    )
    model = WhisperForConditionalGeneration.from_pretrained(cfg.model_id)

    # disable cache during training since it's incompatible with gradient checkpointing
    model.config.use_cache = False
    # set language and task for generation during inference and re-enable cache
    model.generate = partial(
        model.generate, language=cfg.language.lower(), task="transcribe", use_cache=True
    )

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    training_args = Seq2SeqTrainingArguments(
        output_dir=local_output_dir,
        hub_model_id=hf_repo_name,
        report_to=["tensorboard"],
        **cfg.training_hp.model_dump(),
    )

    if proc_dataset := try_find_processed_version(
        dataset_id=cfg.dataset_id, language_id=language_id
    ):
        logger.info(
            f"Loading processed dataset version of {cfg.dataset_id} and skipping processing."
        )
        dataset = proc_dataset
        dataset["train"] = load_subset_of_dataset(dataset["train"], cfg.n_train_samples)
        dataset["test"] = load_subset_of_dataset(dataset["test"], cfg.n_test_samples)
    else:
        logger.info(f"Loading {cfg.dataset_id}. Language selected {cfg.language}")
        dataset, save_proc_dataset_dir = load_dataset_from_dataset_id(
            dataset_id=cfg.dataset_id,
            language_id=language_id,
        )
        dataset["train"] = load_subset_of_dataset(dataset["train"], cfg.n_train_samples)
        dataset["test"] = load_subset_of_dataset(dataset["test"], cfg.n_test_samples)
        logger.info("Processing dataset...")
        dataset = process_dataset(
            dataset=dataset,
            processor=processor,
            batch_size=cfg.training_hp.per_device_train_batch_size,
            proc_dataset_path=save_proc_dataset_dir,
        )
        logger.info(
            f"Processed dataset saved at {save_proc_dataset_dir}. Future runs of {cfg.dataset_id} will "
            f"automatically use this processed version."
        )

    wer = evaluate.load("wer")
    cer = evaluate.load("cer")

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
        compute_metrics=partial(
            compute_wer_cer_metrics,
            processor=processor,
            wer=wer,
            cer=cer,
            normalizer=BasicTextNormalizer(),
        ),
        processing_class=processor.feature_extractor,
    )

    processor.save_pretrained(training_args.output_dir)

    logger.info(
        f"Before finetuning, run evaluation on the baseline model {cfg.model_id} to easily compare performance"
        f" before and after finetuning"
    )
    baseline_eval_results = trainer.evaluate()
    logger.info(f"Baseline evaluation complete. Results:\n\t {baseline_eval_results}")

    logger.info(
        f"Start finetuning job on {dataset['train'].num_rows} audio samples. Monitor training metrics in real time in "
        f"a local tensorboard server by running in a new terminal: tensorboard --logdir {training_args.output_dir}/runs"
    )
    try:
        trainer.train()
    except KeyboardInterrupt:
        logger.info("Stopping the finetuning job prematurely...")
    else:
        logger.info("Finetuning job complete.")

    logger.info(f"Start evaluation on {dataset['test'].num_rows} audio samples.")
    eval_results = trainer.evaluate()
    logger.info(f"Evaluation complete. Results:\n\t {eval_results}")

    if cfg.training_hp.push_to_hub:
        logger.info(f"Uploading model and eval results to HuggingFace: {hf_repo_name}")
        trainer.push_to_hub()
        upload_custom_hf_model_card(
            hf_repo_name=hf_repo_name,
            model_id=cfg.model_id,
            dataset_id=cfg.dataset_id,
            language_id=language_id,
            language=cfg.language,
            n_train_samples=dataset["train"].num_rows,
            n_eval_samples=dataset["test"].num_rows,
            baseline_eval_results=baseline_eval_results,
            ft_eval_results=eval_results,
        )

    logger.info(f"Find your final, best performing model at {local_output_dir}")
    return baseline_eval_results, eval_results
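
Typical usage, assuming a config.yaml in the working directory:

from speech_to_text_finetune.finetune_whisper import run_finetuning

baseline_results, ft_results = run_finetuning(config_path="config.yaml")
# Both dicts contain keys such as "eval_wer", "eval_cer" and "eval_loss"
print(baseline_results["eval_wer"], ft_results["eval_wer"])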

speech_to_text_finetune.utils

compute_wer_cer_metrics(pred, processor, wer, cer, normalizer)

Word Error Rate (wer) is a metric that measures the ratio of errors the ASR model makes given a transcript to the total words spoken. Lower is better. Character Error Rate (cer) is similar to wer, but operates on characters instead of words. This metric is better suited for languages with no concept of "word", like Chinese or Japanese. Lower is better.

More info: https://huggingface.co/learn/audio-course/en/chapter5/fine-tuning#evaluation-metrics

Note 1: WER/CER can be larger than 1.0, if the number of insertions I is larger than the number of correct words C. Note 2: WER/CER doesn't tell the whole story and is not fully representative of the quality of the ASR model.

Parameters:

pred (EvalPrediction): Transformers object that holds predicted tokens and ground truth labels
processor (WhisperProcessor): Whisper processor used to decode tokens to strings
wer (EvaluationModule): module that calls the computing function for WER
cer (EvaluationModule): module that calls the computing function for CER
normalizer (BasicTextNormalizer): Normalizer from Whisper

Returns:

Dict: computed WER and CER metrics, both orthographic and normalized

Source code in src/speech_to_text_finetune/utils.py
def compute_wer_cer_metrics(
    pred: EvalPrediction,
    processor: WhisperProcessor,
    wer: EvaluationModule,
    cer: EvaluationModule,
    normalizer: BasicTextNormalizer,
) -> Dict:
    """
    Word Error Rate (wer) is a metric that measures the ratio of errors the ASR model makes given a transcript to the
    total words spoken. Lower is better.
    Character Error Rate (cer) is similar to wer, but operates on characters instead of words. This metric is better
    suited for languages with no concept of "word", like Chinese or Japanese. Lower is better.

    More info: https://huggingface.co/learn/audio-course/en/chapter5/fine-tuning#evaluation-metrics

    Note 1: WER/CER can be larger than 1.0, if the number of insertions I is larger than the number of correct words C.
    Note 2: WER/CER doesn't tell the whole story and is not fully representative of the quality of the ASR model.

    Args:
        pred (EvalPrediction): Transformers object that holds predicted tokens and ground truth labels
        processor (WhisperProcessor): Whisper processor used to decode tokens to strings
        wer (EvaluationModule): module that calls the computing function for WER
        cer (EvaluationModule): module that calls the computing function for CER
        normalizer (BasicTextNormalizer): Normalizer from Whisper
    Returns:
        Dict: computed WER and CER metrics, both orthographic and normalized
    """

    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # compute orthographic wer
    wer_ortho = 100 * wer.compute(predictions=pred_str, references=label_str)
    cer_ortho = 100 * cer.compute(predictions=pred_str, references=label_str)

    # compute normalised WER
    pred_str_norm = [normalizer(pred) for pred in pred_str]
    label_str_norm = [normalizer(label) for label in label_str]
    # filtering step to only evaluate the samples that correspond to non-zero references:
    pred_str_norm = [
        pred_str_norm[i]
        for i in range(len(pred_str_norm))
        if len(label_str_norm[i]) > 0
    ]
    label_str_norm = [
        label_str_norm[i]
        for i in range(len(label_str_norm))
        if len(label_str_norm[i]) > 0
    ]

    wer = 100 * wer.compute(predictions=pred_str_norm, references=label_str_norm)
    cer = 100 * cer.compute(predictions=pred_str_norm, references=label_str_norm)

    return {"wer_ortho": wer_ortho, "wer": wer, "cer_ortho": cer_ortho, "cer": cer}

upload_custom_hf_model_card(hf_repo_name, model_id, dataset_id, language_id, language, n_train_samples, n_eval_samples, baseline_eval_results, ft_eval_results)

Create and upload a custom Model Card (https://huggingface.co/docs/hub/model-cards) to the Hugging Face repo of the finetuned model, highlighting the evaluation results before and after finetuning.

Source code in src/speech_to_text_finetune/utils.py
def upload_custom_hf_model_card(
    hf_repo_name: str,
    model_id: str,
    dataset_id: str,
    language_id: str,
    language: str,
    n_train_samples: int,
    n_eval_samples: int,
    baseline_eval_results: Dict,
    ft_eval_results: Dict,
) -> None:
    """
    Create and upload a custom Model Card (https://huggingface.co/docs/hub/model-cards) to the Hugging Face repo
    of the finetuned model that highlights the evaluation results before and after finetuning.
    """
    card_metadata = ModelCardData(
        model_name=f"Finetuned {model_id} on {language}",
        base_model=model_id,
        datasets=[dataset_id.split("/")[-1]],
        language=language_id,
        license="apache-2.0",
        library_name="transformers",
        eval_results=[
            EvalResult(
                task_type="automatic-speech-recognition",
                task_name="Speech-to-Text",
                dataset_type="common_voice",
                dataset_name=f"Common Voice ({language})",
                metric_type="wer",
                metric_value=round(ft_eval_results["eval_wer"], 3),
            )
        ],
    )
    content = f"""
---
{card_metadata.to_yaml()}
---

# Finetuned {model_id} on {n_train_samples} {language} training audio samples from {dataset_id}.

This model was created from the Mozilla.ai Blueprint:
[speech-to-text-finetune](https://github.com/mozilla-ai/speech-to-text-finetune).

## Evaluation results on {n_eval_samples} audio samples of {language}:

### Baseline model (before finetuning) on {language}
- Word Error Rate (Normalized): {round(baseline_eval_results["eval_wer"], 3)}
- Word Error Rate (Orthographic): {round(baseline_eval_results["eval_wer_ortho"], 3)}
- Character Error Rate (Normalized): {round(baseline_eval_results["eval_cer"], 3)}
- Character Error Rate (Orthographic): {round(baseline_eval_results["eval_cer_ortho"], 3)}
- Loss: {round(baseline_eval_results["eval_loss"], 3)}

### Finetuned model (after finetuning) on {language}
- Word Error Rate (Normalized): {round(ft_eval_results["eval_wer"], 3)}
- Word Error Rate (Orthographic): {round(ft_eval_results["eval_wer_ortho"], 3)}
- Character Error Rate (Normalized): {round(ft_eval_results["eval_cer"], 3)}
- Character Error Rate (Orthographic): {round(ft_eval_results["eval_cer_ortho"], 3)}
- Loss: {round(ft_eval_results["eval_loss"], 3)}
"""

    card = ModelCard(content)
    card.push_to_hub(hf_repo_name)
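
This helper is normally called at the end of run_finetuning; a sketch with placeholder values (it pushes to the Hub, so it assumes you are logged in and the repo exists):

from speech_to_text_finetune.utils import upload_custom_hf_model_card

# Placeholder metrics, matching the keys the model card template reads
keys = ["eval_wer", "eval_wer_ortho", "eval_cer", "eval_cer_ortho", "eval_loss"]
baseline = {k: 0.0 for k in keys}
finetuned = {k: 0.0 for k in keys}

upload_custom_hf_model_card(
    hf_repo_name="my-username/whisper-tiny-sv",  # hypothetical repo
    model_id="openai/whisper-tiny",
    dataset_id="mozilla-foundation/common_voice_17_0",
    language_id="sv",
    language="Swedish",
    n_train_samples=1000,
    n_eval_samples=200,
    baseline_eval_results=baseline,
    ft_eval_results=finetuned,
)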

speech_to_text_finetune.make_custom_dataset_app

save_text_audio_to_file(audio_input, sentence, dataset_dir, is_train_sample)

Save the audio recording to a .wav file, using the index of the text sentence in the filename, and save the associated text sentence to a .csv file using the same index.

Parameters:

audio_input (gr.Audio): Gradio audio object to be converted to audio data and then saved to a .wav file
sentence (str): The text sentence that will be associated with the audio
dataset_dir (str): The dataset directory path to store the indexed sentences and the associated audio files
is_train_sample (bool): Whether to save the text-recording pair to the train or test dataset

Returns:

str: Status text for Gradio app
None: Returning None here will reset the audio module to record again from scratch

Source code in src/speech_to_text_finetune/make_custom_dataset_app.py
def save_text_audio_to_file(
    audio_input: gr.Audio,
    sentence: str,
    dataset_dir: str,
    is_train_sample: bool,
) -> Tuple[str, None]:
    """
    Save the audio recording in a .wav file using the index of the text sentence in the filename.
    And save the associated text sentence in a .csv file using the same index.

    Args:
        audio_input (gr.Audio): Gradio audio object to be converted to audio data and then saved to a .wav file
        sentence (str): The text sentence that will be associated with the audio
        dataset_dir (str): The dataset directory path to store the indexed sentences and the associated audio files
        is_train_sample (bool): Whether to save the text-recording pair to the train or test dataset

    Returns:
        str: Status text for Gradio app
        None: Returning None here will reset the audio module to record again from scratch
    """
    Path(f"{dataset_dir}/train").mkdir(parents=True, exist_ok=True)
    Path(f"{dataset_dir}/train/clips").mkdir(parents=True, exist_ok=True)
    Path(f"{dataset_dir}/test").mkdir(parents=True, exist_ok=True)
    Path(f"{dataset_dir}/test/clips").mkdir(parents=True, exist_ok=True)

    data_path = (
        Path(f"{dataset_dir}/train/")
        if is_train_sample
        else Path(f"{dataset_dir}/test/")
    )
    text_path = Path(f"{data_path}/text.csv")
    if text_path.is_file():
        df = pd.read_csv(text_path)
    else:
        df = pd.DataFrame(columns=["index", "sentence"])

    index = len(df)
    text_df = pd.concat(
        [df, pd.DataFrame([{"index": index, "sentence": sentence}])],
        ignore_index=True,
    )
    text_df = text_df.sort_values(by="index")
    text_df.to_csv(text_path, index=False)

    audio_filepath = f"{data_path}/clips/rec_{index}.wav"

    sr, data = audio_input
    sf.write(file=audio_filepath, data=data, samplerate=sr)

    return (
        f"""✅ Updated {text_path} \n✅ Saved recording to {audio_filepath}""",
        None,
    )
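
Outside the Gradio app, the function can be exercised directly with a (sample_rate, numpy array) tuple, which is the value format gr.Audio produces; the values below are illustrative:

import numpy as np
from speech_to_text_finetune.make_custom_dataset_app import save_text_audio_to_file

sr = 16_000
audio = np.zeros(sr, dtype=np.float32)  # one second of silence as dummy audio
status, _ = save_text_audio_to_file(
    audio_input=(sr, audio),
    sentence="An example sentence.",
    dataset_dir="./my_dataset",
    is_train_sample=True,
)
print(status)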