mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 10:32:17 +00:00)

update the training script

This commit is contained in:
  parent a9edd7cc3d
  commit 2b2da21208
@@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp \
-  --full-libri 1 \
   --max-duration 1000

 # For streaming model training:
@@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir zipformer/exp \
   --causal 1 \
-  --full-libri 1 \
   --max-duration 1000

 It supports training with:
@@ -57,7 +55,7 @@ import logging
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List

 import k2
 import optim
@@ -65,7 +63,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import MLSAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -324,7 +322,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default="data/lang_bpe_1000/bpe.model",
         help="Path to the BPE model",
     )

@@ -881,7 +879,8 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    valid_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+    valid_sets: List[str],
     scaler: GradScaler,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
@@ -1053,22 +1052,26 @@ def train_one_epoch(

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            for valid_set, valid_dl in zip(valid_sets, valid_dls):
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, f"train/valid_{valid_set}", params.batch_idx_train
+                    )
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+            model.train()
+

         loss_value = tot_loss["loss"] / tot_loss["frames"]
         params.train_loss = loss_value
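With this change, validation runs once per entry in valid_sets/valid_dls, each result is logged under its own TensorBoard tag (train/valid_<set name>), and the model is switched back to training mode only once, after the loop. A minimal sketch of the resulting pattern, reusing only names that appear in the diff and assuming, as in other icefall recipes, that compute_validation_loss puts the model in eval mode:

    # One validation pass per named dev set; each gets its own TensorBoard tag.
    for valid_set, valid_dl in zip(valid_sets, valid_dls):
        valid_info = compute_validation_loss(
            params=params, model=model, sp=sp,
            valid_dl=valid_dl, world_size=world_size,
        )
        if tb_writer is not None:
            # e.g. "train/valid_ASR_english" when valid_set == "ASR_english"
            valid_info.write_summary(
                tb_writer, f"train/valid_{valid_set}", params.batch_idx_train
            )
    model.train()  # restore training mode once, after all dev sets are scored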
@@ -1172,35 +1175,16 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    librispeech = LibriSpeechAsrDataModule(args)
+    mls = MLSAsrDataModule(args)

-    if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-
-        # previously we used the following code to load all training cuts,
-        # strictly speaking, shuffled training cuts should be used instead,
-        # but we leave the code here to demonstrate that there is an option
-        # like this to combine multiple cutsets
-
-        # train_cuts = librispeech.train_clean_100_cuts()
-        # train_cuts += librispeech.train_clean_360_cuts()
-        # train_cuts += librispeech.train_other_500_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+    train_cuts = mls.train_mls_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
             return False

         # In pruned RNN-T, we require that T >= S
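The duration predicate defined in this hunk is applied to the training cuts a few lines below the hunk (not shown in the diff); in icefall's zipformer recipes this is typically done with lhotse's lazy CutSet.filter. A hedged sketch of that call:

    # Drop cuts shorter than 1 s or longer than 20 s lazily, during iteration.
    train_cuts = train_cuts.filter(remove_short_and_long_utt)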
@@ -1234,22 +1218,28 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = librispeech.train_dataloaders(
+    train_dl = mls.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_dls = []
+    valid_sets = []

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    valid_languages = params.language.split(",")
+    for language in valid_languages:
+        valid_cuts = mls.mls_dev_cuts(language)
+        valid_dl = mls.valid_dataloaders(valid_cuts)
+        valid_dls.append(valid_dl)
+        valid_sets.append(f"ASR_{language}")
+
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
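The loop added in this hunk assumes a --language option whose value is a comma-separated list of MLS languages (hence params.language.split(",")), plus an MLSAsrDataModule providing train_mls_cuts(), mls_dev_cuts(language), train_dataloaders(), and valid_dataloaders(); only the method and attribute names visible in the diff are known here. A hedged, illustrative sketch of how such a flag could be registered (not taken from the actual asr_datamodule.py):

    # Hypothetical argument definition; the real flag likely lives in
    # MLSAsrDataModule.add_arguments() or get_parser().
    parser.add_argument(
        "--language",
        type=str,
        default="english",
        help="Comma-separated list of MLS languages, e.g. 'english,german,french'",
    )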
@@ -1274,7 +1264,8 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            valid_dl=valid_dl,
+            valid_dls=valid_dls,
+            valid_sets=valid_sets,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -1379,7 +1370,7 @@ def scan_pessimistic_batches_for_oom(

 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    MLSAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
