diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index c71a8bfdf..63c26647f 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -19,7 +19,8 @@ The following table lists the differences among them.
 | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 | `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
-| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + Random combiner|
+| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | Same as `pruned_transducer_stateless2` + save averaged models periodically during training |
+| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + more layers + Random combiner |
 
 The decoder in `transducer_stateless` is modified from the paper
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 3143fa077..801375642 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,9 +1,142 @@
 ## Results
 
-### LibriSpeech BPE training results (Pruned Transducer 3)
+### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
+
+[pruned_transducer_stateless5](./pruned_transducer_stateless5)
+
+Same as `Pruned Stateless Transducer 2` but with more layers.
+
+See
+
+Note that models in `pruned_transducer_stateless` and `pruned_transducer_stateless2`
+have about 80 M parameters.
+
+The notations `large` and `medium` below are from the [Conformer](https://arxiv.org/pdf/2005.08100.pdf)
+paper, where the large model has about 118 M parameters and the medium model
+has 30.8 M parameters.
+
+#### Large
+
+Number of model parameters: 118129516 (i.e., 118.13 M).
+
+|                                     | test-clean | test-other | comment                               |
+|-------------------------------------|------------|------------|---------------------------------------|
+| greedy search (max sym per frame 1) | 2.39       | 5.57       | --epoch 39 --avg 7 --max-duration 600 |
+| modified beam search                | 2.35       | 5.50       | --epoch 39 --avg 7 --max-duration 600 |
+| fast beam search                    | 2.38       | 5.50       | --epoch 39 --avg 7 --max-duration 600 |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 8 \
+  --num-epochs 40 \
+  --start-epoch 0 \
+  --full-libri 1 \
+  --exp-dir pruned_transducer_stateless5/exp-L \
+  --max-duration 300 \
+  --use-fp16 0 \
+  --num-encoder-layers 18 \
+  --dim-feedforward 2048 \
+  --nhead 8 \
+  --encoder-dim 512 \
+  --decoder-dim 512 \
+  --joiner-dim 512
+```
+
+The tensorboard log can be found at
+
+The decoding commands are:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless5/decode.py \
+    --epoch 39 \
+    --avg 7 \
+    --exp-dir ./pruned_transducer_stateless5/exp-L \
+    --max-duration 600 \
+    --decoding-method $method \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 18 \
+    --dim-feedforward 2048 \
+    --nhead 8 \
+    --encoder-dim 512 \
+    --decoder-dim 512 \
+    --joiner-dim 512
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+#### Medium
+
+Number of model parameters: 30896748 (i.e., 30.9 M).
+
+|                                     | test-clean | test-other | comment                                |
+|-------------------------------------|------------|------------|----------------------------------------|
+| greedy search (max sym per frame 1) | 2.88       | 6.69       | --epoch 39 --avg 17 --max-duration 600 |
+| modified beam search                | 2.83       | 6.59       | --epoch 39 --avg 17 --max-duration 600 |
+| fast beam search                    | 2.83       | 6.61       | --epoch 39 --avg 17 --max-duration 600 |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 8 \
+  --num-epochs 40 \
+  --start-epoch 0 \
+  --full-libri 1 \
+  --exp-dir pruned_transducer_stateless5/exp-M \
+  --max-duration 300 \
+  --use-fp16 0 \
+  --num-encoder-layers 18 \
+  --dim-feedforward 1024 \
+  --nhead 4 \
+  --encoder-dim 256 \
+  --decoder-dim 512 \
+  --joiner-dim 512
+```
+
+The tensorboard log can be found at
+
+The decoding commands are:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless5/decode.py \
+    --epoch 39 \
+    --avg 17 \
+    --exp-dir ./pruned_transducer_stateless5/exp-M \
+    --max-duration 600 \
+    --decoding-method $method \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 18 \
+    --dim-feedforward 1024 \
+    --nhead 4 \
+    --encoder-dim 256 \
+    --decoder-dim 512 \
+    --joiner-dim 512
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+### LibriSpeech BPE training results (Pruned Stateless Transducer 3)
 
 [pruned_transducer_stateless3](./pruned_transducer_stateless3)
 
-Same as `Pruned Transducer 2` but using the XL subset from
+Same as `Pruned Stateless Transducer 2` but using the XL subset from
 [GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
 
 During training, it selects either a batch from GigaSpeech with prob `giga_prob`
@@ -104,6 +237,7 @@ done
 The following table shows the [Nbest oracle WER](http://kaldi-asr.org/doc/lattices.html#lattices_operations_oracle)
 for fast beam search.
+
 | epoch | avg | num_paths | nbest_scale | test-clean | test-other |
 |-------|-----|-----------|-------------|------------|------------|
 | 27    | 10  | 50        | 0.5         | 0.91       | 2.74       |
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 05a4cdca5..d7d6b1202 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -19,40 +19,40 @@
 Usage:
 (1) greedy search
 ./pruned_transducer_stateless2/decode.py \
-        --epoch 28 \
-        --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
-        --max-duration 600 \
-        --decoding-method greedy_search
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
 
 (2) beam search (not recommended)
 ./pruned_transducer_stateless2/decode.py \
-        --epoch 28 \
-        --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
-        --max-duration 600 \
-        --decoding-method beam_search \
-        --beam-size 4
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
 
 (3) modified beam search
 ./pruned_transducer_stateless2/decode.py \
-        --epoch 28 \
-        --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
-        --max-duration 600 \
-        --decoding-method modified_beam_search \
-        --beam-size 4
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
 
 (4) fast beam search
 ./pruned_transducer_stateless2/decode.py \
-        --epoch 28 \
-        --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
-        --max-duration 600 \
-        --decoding-method fast_beam_search \
-        --beam 4 \
-        --max-contexts 4 \
-        --max-states 8
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
 """
@@ -485,7 +485,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index e706083ff..c063b85e1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -18,41 +18,41 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless4/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 100 \
-    --decoding-method greedy_search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
 
-(2) beam search
-./pruned_transducer_stateless4/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 100 \
-    --decoding-method beam_search \
-    --beam-size 4
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
 
 (3) modified beam search
-./pruned_transducer_stateless4/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 100 \
-    --decoding-method modified_beam_search \
-    --beam-size 4
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
 
 (4) fast beam search
-./pruned_transducer_stateless4/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
-    --max-duration 1500 \
-    --decoding-method fast_beam_search \
-    --beam 4 \
-    --max-contexts 4 \
-    --max-states 8
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
 """
@@ -69,7 +69,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -98,33 +98,34 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )
 
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
-    )
-
-    parser.add_argument(
-        "--avg-last-n",
-        type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless4/exp",
+        default="pruned_transducer_stateless5/exp",
         help="The experiment dir",
     )
 
@@ -151,7 +152,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )
@@ -253,7 +254,7 @@ def decode_one_batch(
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -271,6 +272,7 @@
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -278,6 +280,7 @@
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
@@ -357,9 +360,9 @@ def decode_dataset(
         num_batches = "?"
 
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -455,13 +458,19 @@ def main():
     )
 
     params.res_dir = params.exp_dir / params.decoding_method
-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -478,8 +487,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -487,8 +497,20 @@
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
index b5757ee8c..480c47430 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
@@ -20,23 +20,23 @@
 # to a single one using model averaging.
""" Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ +./pruned_transducer_stateless5/export.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 20 \ --avg 10 It will generate a file exp_dir/pretrained.pt -To use the generated file with `pruned_transducer_stateless2/decode.py`, +To use the generated file with `pruned_transducer_stateless5/decode.py`, you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ + ./pruned_transducer_stateless5/decode.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ --epoch 9999 \ --avg 1 \ --max-duration 100 \ @@ -49,7 +49,7 @@ from pathlib import Path import sentencepiece as spm import torch -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -80,7 +80,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned_transducer_stateless5/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, @@ -109,6 +109,8 @@ def get_parser(): "2 means tri-gram", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index b759e77ab..0c6a693e7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -21,22 +21,22 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless4/train.py \ +./pruned_transducer_stateless5/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless4/exp \ + --exp-dir pruned_transducer_stateless5/exp \ --full-libri 1 \ --max-duration 300 # For mix precision training: -./pruned_transducer_stateless4/train.py \ +./pruned_transducer_stateless5/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --use_fp16 1 \ - --exp-dir pruned_transducer_stateless4/exp \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless5/exp \ --full-libri 1 \ --max-duration 550 @@ -185,7 +185,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless4/exp", + default="pruned_transducer_stateless5/exp", help="""The experiment dir. 
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -712,25 +712,29 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
-        # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
 
         if params.print_diagnostics and batch_idx == 5:
             return
@@ -975,6 +979,38 @@ def run(rank, world_size, args):
     cleanup_dist()
 
 
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
@@ -1006,7 +1042,7 @@ def scan_pessimistic_batches_for_oom(
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
-        except RuntimeError as e:
+        except Exception as e:
             if "CUDA out of memory" in str(e):
                 logging.error(
                     "Your GPU ran out of memory with the current "
@@ -1015,6 +1051,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
+            display_and_save_batch(batch, params=params, sp=sp)
             raise
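
A note on the new `display_and_save_batch` helper above: the bare `except:  # noqa` re-raises after saving, so a failure still aborts training but leaves a reproducible batch on disk via `torch.save`. The sketch below is illustrative only and not part of the patch; the file name is hypothetical (use the path printed in the training log), and the batch layout follows `lhotse.dataset.K2SpeechRecognitionDataset` as described in the docstring.

```python
# Offline inspection of a batch saved by display_and_save_batch().
# The filename is illustrative; substitute the path from the training log.
import sentencepiece as spm
import torch

batch = torch.load("pruned_transducer_stateless5/exp/batch-<uuid>.pt")

features = batch["inputs"]            # (N, T, C) filterbank features
supervisions = batch["supervisions"]  # contains "text", among other fields

print("features shape:", features.shape)

# Re-tokenize the transcripts with the same BPE model used in training
# to reproduce the "num tokens" statistic logged by the helper.
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
y = sp.encode(supervisions["text"], out_type=int)
print("num tokens:", sum(len(i) for i in y))
```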