From 2b2da212087fa40e2cbbd01bb42c238a0bd7414b Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Tue, 27 Feb 2024 18:02:33 +0800
Subject: [PATCH] update the training script

---
 egs/mls/ASR/zipformer/train.py | 101 +++++++++++++++------------------
 1 file changed, 46 insertions(+), 55 deletions(-)

diff --git a/egs/mls/ASR/zipformer/train.py b/egs/mls/ASR/zipformer/train.py
index 3ccf7d2f1..0f22ff9cb 100755
--- a/egs/mls/ASR/zipformer/train.py
+++ b/egs/mls/ASR/zipformer/train.py
@@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp \
-  --full-libri 1 \
   --max-duration 1000
 
 # For streaming model training:
@@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir zipformer/exp \
   --causal 1 \
-  --full-libri 1 \
   --max-duration 1000
 
 It supports training with:
@@ -57,7 +55,7 @@ import logging
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List
 
 import k2
 import optim
@@ -65,7 +63,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import MLSAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -324,7 +322,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default="data/lang_bpe_1000/bpe.model",
         help="Path to the BPE model",
     )
 
@@ -881,7 +879,8 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    valid_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+    valid_sets: List[str],
     scaler: GradScaler,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
@@ -1053,22 +1052,26 @@ def train_one_epoch(
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            for valid_set, valid_dl in zip(valid_sets, valid_dls):
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, f"train/valid_{valid_set}", params.batch_idx_train
+                    )
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+            model.train()
+
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -1172,35 +1175,16 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
-
-    if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-
-        # previously we used the following code to load all training cuts,
-        # strictly speaking, shuffled training cuts should be used instead,
-        # but we leave the code here to demonstrate that there is an option
-        # like this to combine multiple cutsets
-
-        # train_cuts = librispeech.train_clean_100_cuts()
-        # train_cuts += librispeech.train_clean_360_cuts()
-        # train_cuts += librispeech.train_other_500_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+    mls = MLSAsrDataModule(args)
+
+    train_cuts = mls.train_mls_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
             return False
 
         # In pruned RNN-T, we require that T >= S
@@ -1234,22 +1218,28 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = librispeech.train_dataloaders(
+    train_dl = mls.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_dls = []
+    valid_sets = []
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    valid_languages = params.language.split(",")
+    for language in valid_languages:
+        valid_cuts = mls.mls_dev_cuts(language)
+        valid_dl = mls.valid_dataloaders(valid_cuts)
+        valid_dls.append(valid_dl)
+        valid_sets.append(f"ASR_{language}")
+
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1274,7 +1264,8 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            valid_dl=valid_dl,
+            valid_dls=valid_dls,
+            valid_sets=valid_sets,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -1379,7 +1370,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    MLSAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 