diff --git a/egs/aishell2/ASR/README.md b/egs/aishell2/ASR/README.md new file mode 100644 index 000000000..9f18b8b51 --- /dev/null +++ b/egs/aishell2/ASR/README.md @@ -0,0 +1,19 @@ + +# Introduction + +This recipe includes some different ASR models trained with Aishell2. + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +There are various folders containing the name `transducer` in this folder. +The following table lists the differences among them. + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|-----------------------------| +| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner | + +The decoder in `transducer_stateless` is modified from the paper +[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md new file mode 100644 index 000000000..7ae3b3082 --- /dev/null +++ b/egs/aishell2/ASR/RESULTS.md @@ -0,0 +1,72 @@ +## Results + +### Aishell2 char-based training results (Pruned Transducer 5) + +#### 2022-07-11 + +Using the codes from this commit https://github.com/k2-fsa/icefall/pull/461. + +When training with context size equals to 1, the WERs are + +| | dev-ios | test-ios | comment | +|------------------------------------|-------|----------|----------------------------------| +| greedy search | | | --epoch 10, --avg 2, --max-duration 100 | +| modified beam search (beam size 4) | | | --epoch 10, --avg 2, --max-duration 100 | +| fast beam search (set as default) | | | --epoch 10, --avg 2, --max-duration 1500 | + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless5/train.py \ + --world-size 4 \ + --lang-dir data/lang_char \ + --num-epochs 40 \ + --start-epoch 1 \ + --exp-dir /result \ + --max-duration 300 \ + --use-fp16 0 \ + --num-encoder-layers 24 \ + --dim-feedforward 1536 \ + --nhead 8 \ + --encoder-dim 384 \ + --decoder-dim 512 \ + --joiner-dim 512 +``` + +The decoding command is: +``` +for method in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless5/decode.py \ + --epoch 25 \ + --avg 5 \ + --exp-dir /result \ + --max-duration 600 \ + --decoding-method $method \ + --max-sym-per-frame 1 \ + --num-encoder-layers 24 \ + --dim-feedforward 1536 \ + --nhead 8 \ + --encoder-dim 384 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --context-size 1 \ + --use-averaged-model True +done +``` +The tensorboard training log can be found at +https: + +A pre-trained model and decoding logs can be found at + +When training with context size equals to 2, the WERs are + +| | dev-ios | test-ios | comment | +|------------------------------------|-------|----------|----------------------------------| +| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 | +| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 | +| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 | + +The tensorboard training log can be found at +https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index f87d23cc9..eac169fff 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -20,77 +20,32 @@ Usage: (1) greedy search ./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) + --epoch 25 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --lang-dir data/lang_char \ + --max-duration 100 \ + --decoding-method greedy_search +(2) modified beam search ./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search + --epoch 25 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --lang-dir data/lang_char \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 +(3) fast beam search ./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 + --epoch 25 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --lang-dir data/lang_char \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -101,15 +56,11 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 -import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AiShell2AsrDataModule from beam_search import ( beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -184,17 +135,10 @@ def get_parser(): help="The experiment dir", ) - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - parser.add_argument( "--lang-dir", type=Path, - default="data/lang_bpe_500", + default="data/lang_char", help="The lang dir containing word table and LG graph", ) @@ -268,7 +212,7 @@ def get_parser(): parser.add_argument( "--context-size", type=int, - default=2, + default=1, help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) @@ -306,9 +250,8 @@ def get_parser(): def decode_one_batch( params: AttributeDict, model: nn.Module, - sp: spm.SentencePieceProcessor, + lexicon: Lexicon, batch: dict, - word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -326,8 +269,6 @@ def decode_one_batch( It's the return value of :func:`get_params`. model: The neural model. - sp: - The BPE model. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation @@ -367,51 +308,8 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -421,8 +319,8 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -430,8 +328,8 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) else: batch_size = encoder_out.size(0) @@ -455,7 +353,7 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + hyps.append([lexicon.token_table[idx] for idx in hyp]) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -478,8 +376,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, + lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -491,10 +388,6 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. - sp: - The BPE model. - word_table: - The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, @@ -525,9 +418,8 @@ def decode_dataset( hyps_dict = decode_one_batch( params=params, model=model, - sp=sp, + lexicon=lexicon, decoding_graph=decoding_graph, - word_table=word_table, batch=batch, ) @@ -535,8 +427,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((ref_text, hyp_words)) results[name].extend(this_batch) @@ -598,7 +489,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AiShell2AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -607,11 +498,7 @@ def main(): assert params.decoding_method in ( "greedy_search", - "beam_search", "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -650,13 +537,10 @@ def main(): logging.info(f"Device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.unk_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 logging.info(params) @@ -744,45 +628,31 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None - word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriSpeechAsrDataModule(args) + aishell2 = AiShell2AsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + valid_cuts = aishell2.valid_cuts() + test_cuts = aishell2.test_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + # use ios sets for dev and test + dev_dl = aishell2.valid_dataloaders(valid_cuts) + test_dl = aishell2.test_dataloaders(test_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( dl=test_dl, params=params, model=model, - sp=sp, - word_table=word_table, + lexicon=lexicon, decoding_graph=decoding_graph, ) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index 936508900..5274d06c7 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -22,9 +22,9 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 + --lang-dir data/lang_char + --epoch 25 \ + --avg 5 It will generate a file exp_dir/pretrained.pt @@ -34,21 +34,20 @@ you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt - cd /path/to/egs/librispeech/ASR + cd /path/to/egs/aishell2/ASR ./pruned_transducer_stateless5/decode.py \ --exp-dir ./pruned_transducer_stateless5/exp \ --epoch 9999 \ --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --lang-dir data/lang_char """ import argparse import logging from pathlib import Path -import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model @@ -58,6 +57,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import str2bool @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_char", + help="The lang dir", ) parser.add_argument( @@ -132,7 +132,7 @@ def get_parser(): parser.add_argument( "--context-size", type=int, - default=2, + default=1, help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) @@ -155,12 +155,10 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.unk_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 logging.info(params) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..32c47abb7 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -20,33 +20,24 @@ Usage: (1) greedy search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --lang-dir ./data/lang_char \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav -(2) beam search +(2) modified beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless5/pretrained.py \ - --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --lang-dir ./data/lang_char \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ /path/to/bar.wav -(4) fast beam search +(3) fast beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --lang-dir ./data/lang_char \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +57,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +69,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.lexicon import Lexicon + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +87,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - help="""Path to bpe.model.""", + help="""Path to lang. + """, ) parser.add_argument( @@ -165,7 +158,7 @@ def get_parser(): parser.add_argument( "--context-size", type=int, - default=2, + default=1, help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) @@ -216,13 +209,10 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.unk_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 logging.info(f"{params}") @@ -292,8 +282,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -302,16 +292,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) else: for i in range(num_waves): # fmt: off @@ -332,11 +322,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append([lexicon.token_table[idx] for idx in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) + words = "".join(hyp) s += f"{filename}:\n{words}\n\n" logging.info(s)