update the training script

marcoyang 2024-02-27 18:02:33 +08:00
parent a9edd7cc3d
commit 2b2da21208


@@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--full-libri 1 \
--max-duration 1000
# For streaming model training:
@@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--full-libri 1 \
--max-duration 1000
It supports training with:
@@ -57,7 +55,7 @@ import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union, List
import k2
import optim
@@ -65,7 +63,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import MLSAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -324,7 +322,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
default="data/lang_bpe_1000/bpe.model",
help="Path to the BPE model",
)
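The default BPE model moves from the 500-piece vocabulary to a 1000-piece one. A quick sanity check, as a sketch assuming data/lang_bpe_1000/bpe.model has already been prepared, is to load it with sentencepiece and inspect its size, which is how train.py derives vocab_size and the blank ID elsewhere in the file:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_1000/bpe.model")  # new --bpe-model default
print(sp.get_piece_size())               # expected: 1000 (used as vocab_size)
print(sp.piece_to_id("<blk>"))           # blank ID (<blk> must exist in the model)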
@@ -881,7 +879,8 @@ def train_one_epoch(
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
valid_dls: List[torch.utils.data.DataLoader],
valid_sets: List[str],
scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
@@ -1053,22 +1052,26 @@ def train_one_epoch(
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
for valid_set, valid_dl in zip(valid_sets, valid_dls):
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, f"train/valid_{valid_set}", params.batch_idx_train
)
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
model.train()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
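The net effect of this hunk: the single valid_dl is replaced by the parallel lists valid_dls and valid_sets, so each validation set gets its own loss computation and its own TensorBoard tag (f"train/valid_{valid_set}") instead of the old shared "train/valid_" tag, and model.train() is only restored after all sets have been evaluated. A stripped-down sketch of the pattern, with compute_validation_loss stubbed out rather than using the repo's real helper:

from typing import Callable, Dict, Iterable, List

def run_validation(
    valid_sets: List[str],
    valid_dls: List[Iterable],
    compute_loss: Callable[[Iterable], Dict[str, float]],
) -> Dict[str, Dict[str, float]]:
    # One loss dict per named validation set; train.py would also call
    # valid_info.write_summary(tb_writer, f"train/valid_{valid_set}", ...) here.
    results = {}
    for valid_set, valid_dl in zip(valid_sets, valid_dls):
        results[valid_set] = compute_loss(valid_dl)
    return results

# Toy usage with list "dataloaders" and a dummy loss:
print(run_validation(["ASR_de", "ASR_en"], [[1, 2], [3]], lambda dl: {"loss": float(len(dl))}))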
@@ -1172,35 +1175,16 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
librispeech = LibriSpeechAsrDataModule(args)
if params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
# previously we used the following code to load all training cuts,
# strictly speaking, shuffled training cuts should be used instead,
# but we leave the code here to demonstrate that there is an option
# like this to combine multiple cutsets
# train_cuts = librispeech.train_clean_100_cuts()
# train_cuts += librispeech.train_clean_360_cuts()
# train_cuts += librispeech.train_other_500_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
mls = MLSAsrDataModule(args)
train_cuts = mls.train_mls_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
# In pruned RNN-T, we require that T >= S
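For context (unchanged by this commit), remove_short_and_long_utt, whose body continues beyond this hunk with the T >= S check, is applied lazily to the training CutSet before the dataloader is built, along the lines of the usual icefall/Lhotse pattern:

# CutSet.filter() is lazy: cuts are only dropped as the sampler iterates,
# so no manifest is materialized up front.
train_cuts = train_cuts.filter(remove_short_and_long_utt)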
@@ -1234,22 +1218,28 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
train_dl = mls.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
valid_dls = []
valid_sets = []
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
valid_languages = params.language.split(",")
for language in valid_languages:
valid_cuts = mls.mls_dev_cuts(language)
valid_dl = mls.valid_dataloaders(valid_cuts)
valid_dls.append(valid_dl)
valid_sets.append(f"ASR_{language}")
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
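The dev loaders are derived from params.language, which, judging from the split(",") above, is expected to hold a comma-separated list of MLS language codes; one dataloader and one name of the form ASR_<lang> is built per code. The codes below are illustrative only, not taken from this diff:

# Hypothetical value; the real codes depend on which MLS subsets were prepared.
language = "de,en,fr"
valid_languages = language.split(",")                     # ["de", "en", "fr"]
valid_sets = [f"ASR_{lang}" for lang in valid_languages]  # ["ASR_de", "ASR_en", "ASR_fr"]
# Each name pairs positionally with the dataloader built from
# mls.mls_dev_cuts(lang) / mls.valid_dataloaders(...) in the loop above.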
@@ -1274,7 +1264,8 @@ def run(rank, world_size, args):
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
valid_dls=valid_dls,
valid_sets=valid_sets,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
@@ -1379,7 +1370,7 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
MLSAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)