diff --git a/egs/librispeech/ASR/whisper/train.py b/egs/librispeech/ASR/whisper/train.py index 6ccb8d363..bd6b27d99 100755 --- a/egs/librispeech/ASR/whisper/train.py +++ b/egs/librispeech/ASR/whisper/train.py @@ -23,7 +23,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ - --manifest-dir data/fbank_whisper \ + --full-libri True \ + --manifest-dir data/fbank_whisper_80D \ --deepspeed \ --deepspeed_config ./whisper/ds_config_zero1.json @@ -31,7 +32,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \ torchrun --nproc_per_node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_medium \ - --manifest-dir data/fbank_whisper \ + --full-libri True \ + --manifest-dir data/fbank_whisper_80D \ --base-lr 1e-5 \ --model-name medium """ @@ -53,7 +55,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn import whisper -from asr_datamodule import AishellAsrDataModule +from asr_datamodule import LibriSpeechAsrDataModule from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from label_smoothing import LabelSmoothingLoss from lhotse import CutSet, load_manifest @@ -147,7 +149,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "tiny"], + choices=["large-v2", "large-v3", "medium", "medium.en", "small", "small.en", "tiny", "tiny.en"], help="""The model name to use. """, ) @@ -450,8 +452,7 @@ def compute_loss( batch_idx_train = params.batch_idx_train texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [text.replace(" ", "") for text in texts] + texts = [t[0] + t[1:].lower() for t in texts] text_tokens_list = [ list(tokenizer.sot_sequence_including_notimestamps) @@ -744,7 +745,7 @@ def run(rank, world_size, args): tokenizer = whisper.tokenizer.get_tokenizer( model.is_multilingual, num_languages=model.num_languages, - language="zh", + language="en", task="transcribe", ) @@ -800,7 +801,19 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - aishell = AishellAsrDataModule(args) + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + if c.duration < 1.0 or c.duration > 20.0: + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -809,8 +822,16 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = aishell.train_dataloaders(aishell.train_cuts()) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + + # do this to prevent Whisper throwing the length mismatch error + valid_cuts = valid_cuts.filter(remove_short_and_long_utt) + valid_dl = librispeech.valid_dataloaders(valid_cuts) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -911,7 +932,7 @@ def display_and_save_batch( def main(): parser = get_parser() - AishellAsrDataModule.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir)