update the training script

marcoyang 2024-02-27 18:02:33 +08:00
parent a9edd7cc3d
commit 2b2da21208


@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 1 \ --start-epoch 1 \
--use-fp16 1 \ --use-fp16 1 \
--exp-dir zipformer/exp \ --exp-dir zipformer/exp \
--full-libri 1 \
--max-duration 1000 --max-duration 1000
# For streaming model training: # For streaming model training:
@@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir zipformer/exp \
   --causal 1 \
-  --full-libri 1 \
   --max-duration 1000

 It supports training with:
@@ -57,7 +55,7 @@ import logging
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List

 import k2
 import optim
@@ -65,7 +63,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import MLSAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -324,7 +322,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default="data/lang_bpe_1000/bpe.model",
         help="Path to the BPE model",
     )
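The default BPE model now points at a 1000-unit vocabulary. A quick sanity check along these lines (a sketch, not part of the commit) can confirm that the model on disk matches the lang directory before launching training:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_1000/bpe.model")  # the new default above

# data/lang_bpe_1000 implies 1000 subword units; a mismatch usually means the
# wrong --bpe-model was passed for the prepared lang dir.
assert sp.get_piece_size() == 1000, sp.get_piece_size()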
@@ -881,7 +879,8 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    valid_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+    valid_sets: List[str],
     scaler: GradScaler,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
@@ -1053,22 +1052,26 @@ def train_one_epoch(
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            for valid_set, valid_dl in zip(valid_sets, valid_dls):
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, f"train/valid_{valid_set}", params.batch_idx_train
+                    )
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+            model.train()

         loss_value = tot_loss["loss"] / tot_loss["frames"]
         params.train_loss = loss_value
@@ -1172,35 +1175,16 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    librispeech = LibriSpeechAsrDataModule(args)
+    mls = MLSAsrDataModule(args)

-    if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-
-        # previously we used the following code to load all training cuts,
-        # strictly speaking, shuffled training cuts should be used instead,
-        # but we leave the code here to demonstrate that there is an option
-        # like this to combine multiple cutsets
-
-        # train_cuts = librispeech.train_clean_100_cuts()
-        # train_cuts += librispeech.train_clean_360_cuts()
-        # train_cuts += librispeech.train_other_500_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+    train_cuts = mls.train_mls_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
             return False

         # In pruned RNN-T, we require that T >= S
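The duration filter keeps its original thresholds; only the explanatory comments are trimmed. For context, a minimal sketch (assumed, not shown in this hunk) of how such a predicate is applied to the lhotse cuts before the sampler is built:

from lhotse import CutSet

def remove_short_and_long_utt(c) -> bool:
    # Keep only utterances between 1 and 20 seconds, mirroring the script.
    return 1.0 <= c.duration <= 20.0

# The manifest path is a placeholder; in the script the cuts come from
# mls.train_mls_cuts(). Cuts outside the 1-20 s range are dropped.
train_cuts = CutSet.from_file("data/fbank/mls_cuts_train.jsonl.gz")
train_cuts = train_cuts.filter(remove_short_and_long_utt)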
@@ -1234,22 +1218,28 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = librispeech.train_dataloaders(
+    train_dl = mls.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
-
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    valid_dls = []
+    valid_sets = []
+
+    valid_languages = params.language.split(",")
+    for language in valid_languages:
+        valid_cuts = mls.mls_dev_cuts(language)
+        valid_dl = mls.valid_dataloaders(valid_cuts)
+        valid_dls.append(valid_dl)
+        valid_sets.append(f"ASR_{language}")
+
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
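Taken together, the calls above imply the data-module surface this script now relies on. The following is a sketch of that assumed interface, inferred only from this diff (the real MLSAsrDataModule is defined in asr_datamodule.py and is not shown here):

import argparse
from typing import Optional

import torch
from lhotse import CutSet


class MLSAsrDataModuleSketch:
    """Assumed shape of the MLS data module, inferred from this commit alone."""

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        ...  # registers dataset-related CLI options

    def train_mls_cuts(self) -> CutSet:
        ...  # training cuts for the requested MLS languages

    def mls_dev_cuts(self, language: str) -> CutSet:
        ...  # dev cuts for a single language, e.g. "en"

    def train_dataloaders(
        self, cuts: CutSet, sampler_state_dict: Optional[dict] = None
    ) -> torch.utils.data.DataLoader:
        ...

    def valid_dataloaders(self, cuts: CutSet) -> torch.utils.data.DataLoader:
        ...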
@@ -1274,7 +1264,8 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            valid_dl=valid_dl,
+            valid_dls=valid_dls,
+            valid_sets=valid_sets,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -1379,7 +1370,7 @@ def scan_pessimistic_batches_for_oom(
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    MLSAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
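For reference, a small sketch of how the per-language validation names used in run() derive from a comma-separated language value. The --language flag itself is an assumption here; this diff only shows params.language being split on commas:

import argparse

parser = argparse.ArgumentParser()
# Assumed flag name; only params.language.split(",") is visible in the diff.
parser.add_argument("--language", type=str, default="en")

args = parser.parse_args(["--language", "en,de,fr"])
valid_sets = [f"ASR_{lang}" for lang in args.language.split(",")]
print(valid_sets)  # ['ASR_en', 'ASR_de', 'ASR_fr']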