diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py
index f56b4fd83..e44b54b9d 100755
--- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 # Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang,
 #                                                  Wei Kang,
-#                                                  Mingshuang Luo,)
-#                                                  Zengwei Yao)
+#                                                  Mingshuang Luo
+#                                                  Zengwei Yao,
+#                                                  Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -52,17 +53,18 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import TAL_CSASRAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from lstm import RNN
+from local.text_normalize import text_normalize
+from local.tokenize_with_bpe_model import tokenize_by_bpe_model
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -71,6 +73,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
 from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -79,6 +82,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
@@ -188,7 +192,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="lstm_transducer_stateless/exp",
+        default="lstm_transducer_stateless3/exp",
         help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -196,10 +200,13 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
     )
 
     parser.add_argument(
@@ -579,7 +586,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
     warmup: float = 1.0,
@@ -612,9 +619,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
-
+    y = graph_compiler.texts_to_ids_with_bpe(texts)
+    if type(y) == list:
+        y = k2.RaggedTensor(y).to(device)
+    else:
+        y = y.to(device)
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
@@ -690,7 +699,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -703,7 +712,7 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            sp=sp,
+            graph_compiler=graph_compiler,
             batch=batch,
             is_training=False,
         )
@@ -726,7 +735,7 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -770,7 +779,12 @@ def train_one_epoch(
 
     tot_loss = MetricsTracker()
 
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -779,7 +793,7 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    sp=sp,
+                    graph_compiler=graph_compiler,
                     batch=batch,
                     is_training=True,
                     warmup=(params.batch_idx_train / params.model_warm_step),
@@ -790,7 +804,6 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            scheduler.step_batch(params.batch_idx_train)
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
@@ -817,6 +830,7 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
+            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -829,6 +843,7 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
+            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -863,7 +878,7 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -915,12 +930,14 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
 
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info(params)
 
@@ -967,14 +984,13 @@ def run(rank, world_size, args):
     # print(scheduler.base_lrs)
 
     if params.print_diagnostics:
-        diagnostic = diagnostics.attach_diagnostics(model)
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
 
-    librispeech = LibriSpeechAsrDataModule(args)
-
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    tal_csasr = TAL_CSASRAsrDataModule(args)
+    train_cuts = tal_csasr.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1012,8 +1028,18 @@ def run(rank, world_size, args):
             return False
 
         return True
-
+
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        text = text_normalize(text)
+        text = tokenize_by_bpe_model(sp, text)
+        c.supervisions[0].text = text
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.map(text_normalize_for_cut)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1022,20 +1048,20 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = librispeech.train_dataloaders(
+    train_dl = tal_csasr.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_cuts = tal_csasr.valid_cuts()
+    valid_cuts = valid_cuts.map(text_normalize_for_cut)
+    valid_dl = tal_csasr.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            sp=sp,
+            graph_compiler=graph_compiler,
             params=params,
             warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
@@ -1061,7 +1087,7 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            sp=sp,
+            graph_compiler=graph_compiler,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -1096,7 +1122,7 @@ def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     params: AttributeDict,
     warmup: float,
 ):
@@ -1113,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
             loss, _ = compute_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 batch=batch,
                 is_training=True,
                 warmup=warmup,
@@ -1135,7 +1161,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    TAL_CSASRAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
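
Note on the new text/label pipeline: compute_loss now expects transcripts that have already been normalized and had their English words split into BPE pieces (that is what text_normalize_for_cut does on each cut), and it maps them to token IDs with CharCtcTrainingGraphCompiler.texts_to_ids_with_bpe instead of sp.encode. The sketch below is a hypothetical walk-through of that pipeline, not code from the patch: it assumes data/lang_char has been prepared, that a SentencePiece model sits at data/lang_char/bpe.model (the hunks above call tokenize_by_bpe_model(sp, text) but do not show where sp is loaded, so the path is an assumption), and that it runs from egs/tal_csasr/ASR/lstm_transducer_stateless3 so the local imports resolve.

# Hypothetical sketch of the patch's text pipeline:
#   raw transcript -> text_normalize -> tokenize_by_bpe_model -> token IDs.
import k2
import sentencepiece as spm
import torch
from local.text_normalize import text_normalize
from local.tokenize_with_bpe_model import tokenize_by_bpe_model

from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

lang_dir = "data/lang_char"  # matches the new --lang-dir default
sp = spm.SentencePieceProcessor()
sp.load(f"{lang_dir}/bpe.model")  # assumed location of the English BPE model

lexicon = Lexicon(lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
    lexicon=lexicon,
    device=torch.device("cpu"),
)

# Mirrors how run() now derives vocabulary info from the char lexicon
# instead of from a SentencePiece model.
blank_id = lexicon.token_table["<blk>"]
vocab_size = max(lexicon.tokens) + 1

raw = "今天我们学习 machine learning\n"  # made-up code-switched sample
text = text_normalize(raw.strip("\n").strip("\t"))
text = tokenize_by_bpe_model(sp, text)  # English words become BPE pieces

y = graph_compiler.texts_to_ids_with_bpe([text])
if type(y) == list:  # same check as compute_loss: list vs. k2.RaggedTensor
    y = k2.RaggedTensor(y)
print(blank_id, vocab_size, y)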
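The cur_batch_idx bookkeeping added to train_one_epoch lets a run that was checkpointed mid-epoch (via --save-every-n) resume from the batch it stopped at instead of replaying the whole epoch; combined with restoring the sampler's state dict, the skipped batches are exactly the ones the interrupted run already consumed. A toy, self-contained illustration of the skip logic (the fake dataloader and the stored index below are made up):

# Toy illustration of the mid-epoch resume logic around params.cur_batch_idx.
from icefall.utils import AttributeDict

params = AttributeDict({"cur_batch_idx": 3})  # pretend a checkpoint stored 3
train_dl = [f"batch-{i}" for i in range(6)]  # stand-in for the real dataloader

cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
    if batch_idx < cur_batch_idx:
        continue  # fast-forward over batches consumed before the checkpoint
    cur_batch_idx = batch_idx
    print("training on", batch)  # prints batch-3, batch-4, batch-5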