diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 86ec6172f..ad76411c0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -42,6 +42,17 @@ Usage:
        --max-duration 100 \
        --decoding-method modified_beam_search \
        --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless2/decode.py \
+        --epoch 28 \
+        --avg 15 \
+        --exp-dir ./pruned_transducer_stateless2/exp \
+        --max-duration 1500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
 """
 
 
@@ -49,16 +60,26 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    fast_beam_search,
+    greedy_search,
+    modified_beam_search,
+)
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -88,6 +109,17 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -110,6 +142,7 @@ def get_parser():
         - greedy_search
         - beam_search
         - modified_beam_search
+        - fast_beam_search
         """,
     )
 
@@ -117,8 +150,35 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
         help="""Used only when --decoding-method is
-        beam_search or modified_beam_search""",
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
     )
 
     parser.add_argument(
@@ -144,6 +204,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -166,6 +227,9 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+        Used only when --decoding-method is fast_beam_search.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -184,36 +248,62 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
     hyps = []
-    batch_size = encoder_out.size(0)
-
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            elif params.decoding_method == "modified_beam_search":
+                hyp = modified_beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     else:
-        return {f"beam_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": hyps}
 
 
 def decode_dataset(
@@ -221,6 +311,7 @@
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -233,6 +324,9 @@
         The neural model.
       sp:
         The BPE model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+        Used only when --decoding-method is fast_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -260,6 +354,7 @@ def decode_dataset(
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
             batch=batch,
         )
 
@@ -340,12 +435,17 @@ def main():
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
+        "fast_beam_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
@@ -372,7 +472,12 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
@@ -388,6 +493,11 @@
     model.eval()
     model.device = device
 
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -408,6 +518,7 @@
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
         )
 
         save_results(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
index 47a519dc9..13e45e03b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
@@ -64,6 +64,7 @@ class Decoder(nn.Module):
         assert context_size >= 1, context_size
         self.context_size = context_size
+        self.vocab_size = vocab_size
         if context_size > 1:
             self.conv = ScaledConv1d(
                 in_channels=embedding_dim,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 851822aae..d28a8a060 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -36,7 +36,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import k2
 import sentencepiece as spm
@@ -48,6 +48,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -55,8 +56,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
 
-from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall import diagnostics
@@ -112,6 +114,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -192,6 +203,30 @@
         help="Accumulate stats on activations, print them and exit.",
     )
 
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save a checkpoint periodically, after processing this number
+        of batches. We save a checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
     return parser
 
 
@@ -320,15 +355,16 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
-) -> None:
+) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.
 
-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`.
 
-    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
-    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    Apart from loading state dict for `model` and `optimizer`, it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
     and `best_valid_loss` in `params`.
 
     Args:
@@ -338,20 +374,22 @@
         The training model.
       optimizer:
         The optimizer that we are using.
-      scheduler:
-        The learning rate scheduler we are using.
     Returns:
-      Return None.
+      Return a dict containing previously saved training info.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
         filename,
         model=model,
         optimizer=optimizer,
-        scheduler=scheduler,
     )
 
     keys = [
@@ -360,10 +398,13 @@
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]
 
+    params["start_epoch"] = saved_params["cur_epoch"]
+
     return saved_params
 
 
@@ -371,7 +412,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -381,6 +422,10 @@
         It is returned by :func:`get_params`.
       model:
         The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
     """
     if rank != 0:
         return
@@ -390,7 +435,7 @@
         model=model,
         params=params,
         optimizer=optimizer,
-        scheduler=scheduler,
+        sampler=sampler,
         rank=rank,
     )
 
@@ -509,6 +554,7 @@
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
+    rank: int = 0,
 ) -> None:
     """Train the model for one epoch.
 
@@ -531,12 +577,21 @@
         Writer to write log messages to tensorboard.
       world_size:
         Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
     """
     model.train()
 
     tot_loss = MetricsTracker()
 
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -560,6 +615,27 @@
         if params.print_diagnostics and batch_idx == 5:
             return
 
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                sampler=train_dl.sampler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
@@ -688,7 +764,14 @@ def run(rank, world_size, args):
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    train_dl = librispeech.train_dataloaders(train_cuts)
+    if checkpoints and "sampler" in checkpoints:
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
 
     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
@@ -728,6 +811,7 @@
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
+            rank=rank,
         )
 
         if params.print_diagnostics:
@@ -738,6 +822,7 @@
         params=params,
         model=model,
         optimizer=optimizer,
+        sampler=train_dl.sampler,
         rank=rank,
     )
 
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index fa9b98fa0..06eacd736 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
             return ""
         count = sum(counts)
         stats = stats / count
-        stats, _ = torch.symeig(stats)
-        stats = stats.abs().sqrt()
+        try:
+            eigs, _ = torch.symeig(stats)
+            stats = eigs.abs().sqrt()
+        except RuntimeError:
+            print("Error getting eigenvalues, trying another method")
+            eigs, _ = torch.eig(stats)  # eigenvalues come as (n, 2) real/imag pairs
+            stats = eigs[:, 0].abs().sqrt()  # keep only the real parts
         # sqrt so it reflects data magnitude, like stddev- not variance
     elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
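
Note on how the decode-side pieces above fit together: the following is a
minimal sketch, not code from this patch. It assumes `model`, `sp`, `feature`
and `feature_lens` are prepared as in decode.py, that `fast_beam_search`
(imported from beam_search.py above) has the signature implied by its call
site in decode_one_batch, and that `model.decoder.vocab_size` is the
attribute recorded by the decoder.py change above.

    import k2
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A trivial graph over token ids 0..vocab_size-1 accepts any token
    # sequence, so the search is constrained only by --beam, --max-contexts
    # and --max-states; an HLG could be passed instead to constrain the
    # search with a lexicon and language model.
    decoding_graph = k2.trivial_graph(model.decoder.vocab_size - 1, device=device)

    with torch.no_grad():
        # feature: (N, T, C), on the same device as the model.
        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
        hyp_tokens = fast_beam_search(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=4.0,        # cutoff = max_score - beam, as in Kaldi
            max_contexts=4,  # cap on decoder contexts expanded per frame
            max_states=8,    # cap on graph states kept per frame
        )
        # The whole batch is decoded in one call; sp.decode() maps the
        # token ids of each hypothesis back to text.
        hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]

On the training side, the new flags combine as follows. This usage note
mirrors the docstring examples in the scripts; the numbers are illustrative
only:

    # Save exp-dir/checkpoint-{batch_idx_train}.pt every 8000 batches,
    # keeping only the 20 most recent such files:
    ./pruned_transducer_stateless2/train.py \
            --save-every-n 8000 \
            --keep-last-k 20

    # Resume from the checkpoint written at global batch 16000; --start-epoch
    # is ignored, and the sampler state stored in the checkpoint lets the
    # dataloader skip the cuts already seen in the interrupted epoch:
    ./pruned_transducer_stateless2/train.py --start-batch 16000

    # Average the last 5 batch-level checkpoints at decode time; --epoch and
    # --avg are ignored:
    ./pruned_transducer_stateless2/decode.py \
            --avg-last-n 5 \
            --decoding-method fast_beam_search \
            --exp-dir ./pruned_transducer_stateless2/exp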