modify train.py

Author: marcoyang, 2023-02-10 16:49:57 +08:00
Parent: d1a0668f68
Commit: 7556811d64


@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                        Wei Kang,
-#                                        Mingshuang Luo,)
-#                                        Zengwei Yao)
+#                                        Mingshuang Luo,
+#                                        Zengwei Yao,
+#                                        Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -52,17 +53,18 @@ from typing import Any, Dict, Optional, Tuple, Union
 import k2
 import optim
 import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import TAL_CSASRAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from lstm import RNN
+from local.text_normalize import text_normalize
+from local.tokenize_with_bpe_model import tokenize_by_bpe_model
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -71,6 +73,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -79,6 +82,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
@@ -188,7 +192,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="lstm_transducer_stateless/exp",
+        default="lstm_transducer_stateless3/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -196,10 +200,13 @@ def get_parser():
     )
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
     )
     parser.add_argument(
@@ -579,7 +586,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
     warmup: float = 1.0,
@@ -612,9 +619,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    y = graph_compiler.texts_to_ids_with_bpe(texts)
+    if type(y) == list:
+        y = k2.RaggedTensor(y).to(device)
+    else:
+        y = y.to(device)
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
@@ -690,7 +699,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -703,7 +712,7 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            sp=sp,
+            graph_compiler=graph_compiler,
             batch=batch,
             is_training=False,
         )
@@ -726,7 +735,7 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -770,7 +779,12 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
+    cur_batch_idx = params.get("cur_batch_idx", 0)
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -779,7 +793,7 @@ def train_one_epoch(
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 batch=batch,
                 is_training=True,
                 warmup=(params.batch_idx_train / params.model_warm_step),
@@ -790,7 +804,6 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
             scaler.step(optimizer)
             scaler.update()
@@ -817,6 +830,7 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
+            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -829,6 +843,7 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
+            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -863,7 +878,7 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -915,12 +930,14 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
     logging.info(params)
@@ -967,14 +984,13 @@ def run(rank, world_size, args):
     # print(scheduler.base_lrs)
     if params.print_diagnostics:
-        diagnostic = diagnostics.attach_diagnostics(model)
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
-    librispeech = LibriSpeechAsrDataModule(args)
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    tal_csasr = TAL_CSASRAsrDataModule(args)
+    train_cuts = tal_csasr.train_cuts()
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1012,8 +1028,18 @@ def run(rank, world_size, args):
             return False
         return True
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        text = text_normalize(text)
+        text = tokenize_by_bpe_model(sp, text)
+        c.supervisions[0].text = text
+        return c
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.map(text_normalize_for_cut)
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1022,20 +1048,20 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
-    train_dl = librispeech.train_dataloaders(
+    train_dl = tal_csasr.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_cuts = tal_csasr.valid_cuts()
+    valid_cuts = valid_cuts.map(text_normalize_for_cut)
+    valid_dl = tal_csasr.valid_dataloaders(valid_cuts)
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            sp=sp,
+            graph_compiler=graph_compiler,
             params=params,
             warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
@@ -1061,7 +1087,7 @@ def run(rank, world_size, args):
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
-           sp=sp,
+           graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
@@ -1096,7 +1122,7 @@ def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     params: AttributeDict,
     warmup: float,
 ):
@@ -1113,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
            loss, _ = compute_loss(
                params=params,
                model=model,
-               sp=sp,
+               graph_compiler=graph_compiler,
                batch=batch,
                is_training=True,
                warmup=warmup,
@@ -1135,7 +1161,7 @@ def scan_pessimistic_batches_for_oom(
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    TAL_CSASRAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
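
For readers following the compute_loss change above: graph_compiler.texts_to_ids_with_bpe() may hand back either a plain Python list of token-id lists or an already-built k2.RaggedTensor, which is why the new branch checks the type before moving the labels to the training device. The stand-alone sketch below mirrors that branch only; the helper name to_ragged_on_device and the example token ids are illustrative and not part of the commit.

    # Illustrative sketch of the ragged-label handling added to compute_loss.
    # Assumes k2 and torch are installed; the token ids below are made up.
    import k2
    import torch


    def to_ragged_on_device(y, device: torch.device) -> k2.RaggedTensor:
        """Normalize token labels onto `device`, mirroring the new branch."""
        if isinstance(y, list):
            # e.g. one list of token ids per utterance in the batch
            return k2.RaggedTensor(y).to(device)
        # already a k2.RaggedTensor
        return y.to(device)


    if __name__ == "__main__":
        device = torch.device("cpu")
        fake_ids = [[23, 7, 512], [9, 41, 3]]  # stand-in for graph_compiler output
        print(to_ragged_on_device(fake_ids, device))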