diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py index e44b54b9d..fa1b9d8e1 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py @@ -108,6 +108,27 @@ def add_model_arguments(parser: argparse.ArgumentParser): default=512, help="Encoder output dimesion.", ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Decoder output dimension.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="Joiner output dimension.", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Dimension of feed forward.", + ) parser.add_argument( "--rnn-hidden-size", @@ -402,11 +423,6 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), @@ -426,6 +442,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder @@ -619,6 +636,7 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) texts = batch["supervisions"]["text"] + #import pdb; pdb.set_trace() y = graph_compiler.texts_to_ids_with_bpe(texts) if type(y) == list: y = k2.RaggedTensor(y).to(device) @@ -910,8 +928,6 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) - if params.full_libri is False: - params.valid_interval = 800 fix_random_seed(params.seed) if world_size > 1: @@ -930,6 +946,12 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") + bpe_model = params.lang_dir + "/bpe.model" + import sentencepiece as spm + + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + lexicon = Lexicon(params.lang_dir) graph_compiler = CharCtcTrainingGraphCompiler( lexicon=lexicon, @@ -1001,33 +1023,7 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./lstm.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 3) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True + return 1.0 <= c.duration <= 20.0 def text_normalize_for_cut(c: Cut): # Text normalize for each sample @@ -1056,15 +1052,15 @@ def run(rank, world_size, args): valid_cuts = valid_cuts.map(text_normalize_for_cut) valid_dl = tal_csasr.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # graph_compiler=graph_compiler, + # params=params, + # warmup=0.0 if params.start_epoch == 1 else 1.0, + # ) scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: