update the training script
This commit is contained in:
parent a9edd7cc3d
commit 2b2da21208
@@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp \
-  --full-libri 1 \
   --max-duration 1000
 
 # For streaming model training:
@@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir zipformer/exp \
   --causal 1 \
-  --full-libri 1 \
   --max-duration 1000
 
 It supports training with:
@@ -57,7 +55,7 @@ import logging
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List
 
 import k2
 import optim
@@ -65,7 +63,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import MLSAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -324,7 +322,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default="data/lang_bpe_1000/bpe.model",
         help="Path to the BPE model",
     )
 
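Note: the default BPE model moves from a 500-piece to a 1000-piece vocabulary. A minimal sketch (assuming sentencepiece is installed and the new path exists locally) for checking that a BPE model has the vocabulary size the recipe expects; icefall recipes typically derive vocab_size and blank_id from this model at startup:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_1000/bpe.model")  # the new default path from this commit

print("vocab size:", sp.get_piece_size())   # expected: 1000
print("blank id:", sp.piece_to_id("<blk>"))  # recipes usually read blank_id this way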
@@ -881,7 +879,8 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    valid_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+    valid_sets: List[str],
     scaler: GradScaler,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
@@ -1053,22 +1052,26 @@ def train_one_epoch(
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            for valid_set, valid_dl in zip(valid_sets, valid_dls):
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, f"train/valid_{valid_set}", params.batch_idx_train
+                    )
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+            model.train()
+
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
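The validation step now loops over named validation loaders and writes one TensorBoard summary per set. A self-contained sketch of the same pattern with dummy data (the model, loaders, and the mean-squared-error "loss" below are placeholders, not the recipe's compute_validation_loss):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

# Two toy "validation sets", standing in for per-language dev loaders.
valid_sets = ["ASR_en", "ASR_de"]
valid_dls = [
    DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
    for _ in valid_sets
]

model = torch.nn.Linear(4, 1)
tb_writer = SummaryWriter("exp/tb_sketch")  # hypothetical log dir
step = 0

model.eval()
with torch.no_grad():
    for valid_set, valid_dl in zip(valid_sets, valid_dls):
        # Placeholder "validation loss": mean squared error over the loader.
        losses = [torch.nn.functional.mse_loss(model(x), y) for x, y in valid_dl]
        valid_loss = torch.stack(losses).mean().item()
        print(f"validation ({valid_set}): {valid_loss:.4f}")
        # One summary tag per validation set, mirroring train/valid_{valid_set}.
        tb_writer.add_scalar(f"train/valid_{valid_set}/loss", valid_loss, step)
model.train()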
@@ -1172,35 +1175,16 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
-
-    if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-
-        # previously we used the following code to load all training cuts,
-        # strictly speaking, shuffled training cuts should be used instead,
-        # but we leave the code here to demonstrate that there is an option
-        # like this to combine multiple cutsets
-
-        # train_cuts = librispeech.train_clean_100_cuts()
-        # train_cuts += librispeech.train_clean_360_cuts()
-        # train_cuts += librispeech.train_other_500_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+    mls = MLSAsrDataModule(args)
+
+    train_cuts = mls.train_mls_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             # logging.warning(
             #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             # )
             return False
 
     # In pruned RNN-T, we require that T >= S
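The short/long-utterance filter is kept for the MLS cuts; only the LibriSpeech-specific comments about display_manifest_statistics.py go away. A standalone sketch of the same predicate applied to a plain list (the FakeCut class is a stand-in; in the recipe the predicate is passed to the cut set's filter method):

from dataclasses import dataclass

# Stand-in for a lhotse Cut: only the duration field matters for this filter.
@dataclass
class FakeCut:
    id: str
    duration: float

def remove_short_and_long_utt(c) -> bool:
    # Keep only utterances with duration between 1 second and 20 seconds.
    return 1.0 <= c.duration <= 20.0

cuts = [FakeCut("a", 0.4), FakeCut("b", 7.3), FakeCut("c", 25.0)]
kept = [c for c in cuts if remove_short_and_long_utt(c)]
print([c.id for c in kept])  # -> ['b']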
@@ -1234,22 +1218,28 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = librispeech.train_dataloaders(
+    train_dl = mls.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_dls = []
+    valid_sets = []
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    valid_languages = params.language.split(",")
+    for language in valid_languages:
+        valid_cuts = mls.mls_dev_cuts(language)
+        valid_dl = mls.valid_dataloaders(valid_cuts)
+        valid_dls.append(valid_dl)
+        valid_sets.append(f"ASR_{language}")
+
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
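Validation loaders are now built per language from a comma-separated params.language value. A sketch of how the names and loaders line up (build_dev_loader is a placeholder for the mls.mls_dev_cuts(...) plus mls.valid_dataloaders(...) calls in the diff above; the "en,de,fr" value is only an example):

# Sketch: how a comma-separated language list becomes named validation sets.
def build_dev_loader(language: str):
    return f"<dev loader for {language}>"  # placeholder object

language_arg = "en,de,fr"
valid_dls = []
valid_sets = []
for language in language_arg.split(","):
    valid_dls.append(build_dev_loader(language))
    valid_sets.append(f"ASR_{language}")

print(valid_sets)  # -> ['ASR_en', 'ASR_de', 'ASR_fr']
print(valid_dls)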
@@ -1274,7 +1264,8 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            valid_dl=valid_dl,
+            valid_dls=valid_dls,
+            valid_sets=valid_sets,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -1379,7 +1370,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    MLSAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
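main() now registers the MLS datamodule's CLI arguments instead of the LibriSpeech ones. A minimal sketch of the add_arguments composition pattern with argparse; the group title, --language option, and defaults below are illustrative assumptions, not the module's confirmed interface:

import argparse

class SketchDataModule:
    # Datamodules in these recipes typically contribute their own CLI options
    # through a classmethod, so the training script stays dataset-agnostic.
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(title="Data related options")
        group.add_argument("--max-duration", type=float, default=1000.0)
        group.add_argument("--language", type=str, default="en",
                           help="Comma-separated languages, e.g. 'en,de' (assumed)")

def get_parser():
    parser = argparse.ArgumentParser(description="Training sketch")
    parser.add_argument("--exp-dir", type=str, default="zipformer/exp")
    return parser

parser = get_parser()
SketchDataModule.add_arguments(parser)
args = parser.parse_args(["--language", "en,de"])
print(args.exp_dir, args.max_duration, args.language)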