modify train.py

Author: marcoyang
Date:   2023-02-10 16:49:57 +08:00
Commit: 7556811d64 (parent: d1a0668f68)


@@ -1,8 +1,9 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,)
# Zengwei Yao)
# Mingshuang Luo
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -52,17 +53,18 @@ from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import TAL_CSASRAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from lstm import RNN
from local.text_normalize import text_normalize
from local.tokenize_with_bpe_model import tokenize_by_bpe_model
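# text_normalize and tokenize_by_bpe_model come from this recipe's local/ directory;
# they clean each transcript and (as in other icefall code-switching recipes) apply a
# BPE model to the non-Chinese part of the TAL_CSASR text. The exact behaviour of these
# helpers is an assumption based on their names and on how they are used below.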
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
@@ -71,6 +73,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@@ -79,6 +82,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
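# Lexicon reads the char lang dir and CharCtcTrainingGraphCompiler maps the
# (presumably code-switched) transcripts to token ids; they take over the role the
# SentencePiece model played for token ids and vocabulary size in this script.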
from icefall.utils import (
AttributeDict,
MetricsTracker,
@@ -188,7 +192,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless/exp",
default="lstm_transducer_stateless3/exp",
help="""The experiment dir.
It specifies the directory where all training-related
files, e.g., checkpoints, logs, etc., are saved
@@ -196,10 +200,13 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
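# data/lang_char is assumed to follow the usual icefall char-lexicon layout
# (lexicon.txt plus token/word tables); only lexicon.txt is named explicitly here.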
parser.add_argument(
@@ -579,7 +586,7 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
warmup: float = 1.0,
@@ -612,9 +619,11 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
y = graph_compiler.texts_to_ids_with_bpe(texts)
if type(y) == list:
y = k2.RaggedTensor(y).to(device)
else:
y = y.to(device)
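# The branch above normalizes the output of texts_to_ids_with_bpe: it apparently
# returns either a plain list of token-id lists or an already-built k2.RaggedTensor,
# and both cases end up as a RaggedTensor on the training device.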
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,
@@ -690,7 +699,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@@ -703,7 +712,7 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
@@ -726,7 +735,7 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@@ -770,7 +779,12 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
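# When resuming from a checkpoint that recorded cur_batch_idx, skip the batches
# that were already consumed in the interrupted epoch so the dataloader position
# and params.batch_idx_train stay consistent.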
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@@ -779,7 +793,7 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step),
@@ -790,7 +804,6 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
@@ -817,6 +830,7 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@@ -829,6 +843,7 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
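# cur_batch_idx is stashed into params just above only so that it is written into
# the checkpoint; deleting it right after saving keeps it out of normal training state.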
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@@ -863,7 +878,7 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
@@ -915,12 +930,14 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
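# With a char lexicon the blank id and vocabulary size come from the token table
# rather than from a SentencePiece model; "max(lexicon.tokens) + 1" assumes token
# ids start at 0, which matches other icefall char-based recipes.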
logging.info(params)
@@ -967,14 +984,13 @@ def run(rank, world_size, args):
# print(scheduler.base_lrs)
if params.print_diagnostics:
diagnostic = diagnostics.attach_diagnostics(model)
opts = diagnostics.TensorDiagnosticOptions(
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
tal_csasr = TAL_CSASRAsrDataModule(args)
train_cuts = tal_csasr.train_cuts()
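# TAL_CSASRAsrDataModule apparently exposes a single train_cuts() split, so the
# LibriSpeech-specific --full-libri subset handling is no longer needed here.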
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@@ -1012,8 +1028,18 @@ def run(rank, world_size, args):
return False
return True
def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
text = text_normalize(text)
text = tokenize_by_bpe_model(sp, text)
c.supervisions[0].text = text
return c
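# Note that tokenize_by_bpe_model still takes a SentencePiece processor (sp); sp is
# assumed to be loaded elsewhere in run(), outside the hunks shown in this diff.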
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_cuts = train_cuts.map(text_normalize_for_cut)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when resuming from a checkpoint
@@ -1022,20 +1048,20 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
train_dl = tal_csasr.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
valid_cuts = tal_csasr.valid_cuts()
valid_cuts = valid_cuts.map(text_normalize_for_cut)
valid_dl = tal_csasr.valid_dataloaders(valid_cuts)
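# Validation transcripts are passed through the same normalization/tokenization as
# the training cuts so that token ids are consistent between the two dataloaders.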
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
graph_compiler=graph_compiler,
params=params,
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
@@ -1061,7 +1087,7 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@@ -1096,7 +1122,7 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
warmup: float,
):
@@ -1113,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
warmup=warmup,
@@ -1135,7 +1161,7 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
TAL_CSASRAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
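# A typical launch after these changes might look like the following; every flag
# except --exp-dir and --lang-dir is an assumption based on other icefall recipes:
#
#   ./lstm_transducer_stateless3/train.py \
#     --world-size 4 \
#     --num-epochs 30 \
#     --exp-dir lstm_transducer_stateless3/exp \
#     --lang-dir data/lang_char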