From 1c4987f6e657457a7ac95093253e34d5ed5be4ff Mon Sep 17 00:00:00 2001 From: luomingshuang <739314837@qq.com> Date: Tue, 14 Jun 2022 11:44:03 +0800 Subject: [PATCH] update codes for merging --- README.md | 18 +++ egs/aishell4/ASR/README.md | 19 +++ egs/aishell4/ASR/RESULTS.md | 117 ++++++++++++++++++ .../asr_datamodule.py | 4 +- .../pruned_transducer_stateless5/decode.py | 45 +++---- .../pruned_transducer_stateless5/export.py | 26 ++-- .../pretrained.py | 90 +++++++------- .../test_model.py | 4 +- .../ASR/pruned_transducer_stateless5/train.py | 4 +- 9 files changed, 241 insertions(+), 86 deletions(-) create mode 100644 egs/aishell4/ASR/README.md create mode 100644 egs/aishell4/ASR/RESULTS.md diff --git a/README.md b/README.md index 9f8db554c..635cc3c7a 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ We provide the following recipes: - [Aidatatang_200zh][aidatatang_200zh] - [WenetSpeech][wenetspeech] - [Alimeeting][alimeeting] + - [Aishell4][aishell4] ### yesno @@ -262,6 +263,21 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing) +### Aishell4 + +We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5]. 
+ +#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets) + +The best CER(%) results: +| | test | +|----------------------|--------| +| greedy search | 29.89 | +| fast beam search | 29.08 | +| modified beam search | 28.91 | + +We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) + ## Deployment with C++ Once you have trained a model in icefall, you may want to deploy it with C++, @@ -290,6 +306,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2 [WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2 [Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2 +[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5 [yesno]: egs/yesno/ASR [librispeech]: egs/librispeech/ASR [aishell]: egs/aishell/ASR @@ -299,5 +316,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [aidatatang_200zh]: egs/aidatatang_200zh/ASR [wenetspeech]: egs/wenetspeech/ASR [alimeeting]: egs/alimeeting/ASR +[aishell4]: egs/aishell4/ASR [k2]: https://github.com/k2-fsa/k2 ) diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md new file mode 100644 index 000000000..3744032f8 --- /dev/null +++ b/egs/aishell4/ASR/README.md @@ -0,0 +1,19 @@ + +# Introduction + +This recipe includes ASR models trained on Aishell4 (using all three subsets: S, M and L). + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +There are various folders containing the name `transducer` in this folder. 
+The following table lists the differences among them. + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|-----------------------------| +| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | + +The decoder in `transducer_stateless` is modified from the paper +[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/aishell4/ASR/RESULTS.md b/egs/aishell4/ASR/RESULTS.md new file mode 100644 index 000000000..9bd062f1d --- /dev/null +++ b/egs/aishell4/ASR/RESULTS.md @@ -0,0 +1,117 @@ +## Results + +### Aishell4 Char training results (Pruned Transducer Stateless5) + +#### 2022-06-13 + +Using the code from this PR https://github.com/k2-fsa/icefall/pull/399. + +When use-averaged-model=False, the CERs are +| | test | comment | +|------------------------------------|------------|------------------------------------------| +| greedy search | 30.05 | --epoch 30, --avg 25, --max-duration 800 | +| modified beam search (beam size 4) | 29.16 | --epoch 30, --avg 25, --max-duration 800 | +| fast beam search (set as default) | 29.20 | --epoch 30, --avg 25, --max-duration 1500| + +When use-averaged-model=True, the CERs are +| | test | comment | +|------------------------------------|------------|----------------------------------------------------------------------| +| greedy search | 29.89 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True | +| modified beam search (beam size 4) | 28.91 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True | +| fast beam search (set as default) | 29.08 | --iter 36000, --avg 8, --max-duration 1500 --use-averaged-model=True | + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + 
+./pruned_transducer_stateless5/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless5/exp \ + --lang-dir data/lang_char \ + --max-duration 220 \ + --save-every-n 4000 + +``` + +The tensorboard training log can be found at +https://tensorboard.dev/experiment/tjaVRKERS8C10SzhpBcxSQ/#scalars + +When use-averaged-model=False, the decoding command is: +``` +epoch=30 +avg=25 + +## greedy search +./pruned_transducer_stateless5/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless5/exp \ + --lang-dir ./data/lang_char \ + --max-duration 800 + +## modified beam search +./pruned_transducer_stateless5/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless5/exp \ + --lang-dir ./data/lang_char \ + --max-duration 800 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +## fast beam search +./pruned_transducer_stateless5/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --lang-dir ./data/lang_char \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +``` + +When use-averaged-model=True, the decoding command is: +``` +iter=36000 +avg=8 + +## greedy search +./pruned_transducer_stateless5/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir pruned_transducer_stateless5/exp \ + --lang-dir ./data/lang_char \ + --max-duration 800 \ + --use-averaged-model True + +## modified beam search +./pruned_transducer_stateless5/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir pruned_transducer_stateless5/exp \ + --lang-dir ./data/lang_char \ + --max-duration 800 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --use-averaged-model True + +## fast beam search +./pruned_transducer_stateless5/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --lang-dir ./data/lang_char \ + --max-duration 
1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --use-averaged-model True +``` + +A pre-trained model and decoding logs can be found at diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 9dcd6fa4b..7aa53ddda 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, @@ -222,7 +222,7 @@ class Aishell4AsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest_lazy( + cuts_musan = load_manifest( self.args.manifest_dir / "musan_cuts.jsonl.gz" ) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 619534519..705e34647 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -17,43 +18,37 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" -Usage: +When use-averaged-model=True, usage: (1) greedy search ./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ + --iter 36000 \ + --avg 8 \ --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method greedy_search + --max-duration 800 \ + --decoding-method greedy_search \ + --use-averaged-model True -(2) beam search (not recommended) +(2) modified beam search ./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ + --iter 36000 \ + --avg 8 \ --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ + --max-duration 800 \ --decoding-method modified_beam_search \ - --beam-size 4 + --beam-size 4 \ + --use-averaged-model True -(4) fast beam search +(3) fast beam search ./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ - --avg 15 \ + --iter 36000 \ + --avg 8 \ --exp-dir ./pruned_transducer_stateless5/exp \ - --max-duration 600 \ + --max-duration 800 \ --decoding-method fast_beam_search \ --beam 4 \ --max-contexts 4 \ - --max-states 8 + --max-states 8 \ + --use-averaged-model True """ diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index f1269a4bd..f487a8ba5 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --lang-dir data/lang_char \ --epoch 20 \ --avg 10 @@ -34,21 +34,20 @@ you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt - cd /path/to/egs/librispeech/ASR + cd /path/to/egs/aishell4/ASR ./pruned_transducer_stateless5/decode.py \ 
--exp-dir ./pruned_transducer_stateless5/exp \ --epoch 9999 \ --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --lang-dir data/lang_char """ import argparse import logging from pathlib import Path -import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model @@ -58,6 +57,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import str2bool @@ -115,10 +115,13 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, ) parser.add_argument( @@ -157,12 +160,9 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table["<blk>"] + params.vocab_size = max(lexicon.tokens) + 1 logging.info(params) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..1fa893637 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -15,30 +15,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" -Usage: +When use-averaged-model=True, usage: (1) greedy search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ + --lang-dir data/lang_char \ + --decoding-method greedy_search \ + --use-averaged-model True \ /path/to/foo.wav \ /path/to/bar.wav (2) beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ + --lang-dir data/lang_char \ + --use-averaged-model True \ + --decoding-method beam_search \ --beam-size 4 \ /path/to/foo.wav \ /path/to/bar.wav -(3) modified beam search +(3) modified beam search (not recommended) ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ + --lang-dir data/lang_char \ + --use-averaged-model True \ + --decoding-method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ /path/to/bar.wav @@ -46,8 +49,9 @@ Usage: (4) fast beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ + --lang-dir data/lang_char \ + --use-averaged-model True \ + --decoding-method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ /path/to/bar.wav @@ -66,7 +70,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +82,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.lexicon import Lexicon + def get_parser(): parser = argparse.ArgumentParser( @@ -95,13 +100,14 @@ def get_parser(): ) parser.add_argument( - 
"--bpe-model", + "--lang-dir", type=str, - help="""Path to bpe.model.""", + help="""Path to lang. + """, ) parser.add_argument( - "--method", + "--decoding-method", type=str, default="greedy_search", help="""Possible values are: @@ -134,7 +140,7 @@ def get_parser(): type=int, default=4, help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or + frame. Used only when --decoding-method is beam_search or modified_beam_search.""", ) @@ -145,21 +151,21 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --method is fast_beam_search""", + Used only when --decoding-method is fast_beam_search""", ) parser.add_argument( "--max-contexts", type=int, default=4, - help="""Used only when --method is fast_beam_search""", + help="""Used only when --decoding-method is fast_beam_search""", ) parser.add_argument( "--max-states", type=int, default=8, - help="""Used only when --method is fast_beam_search""", + help="""Used only when --decoding-method is fast_beam_search""", ) parser.add_argument( @@ -174,7 +180,7 @@ def get_parser(): type=int, default=1, help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. + --decoding-method is greedy_search. 
""", ) @@ -216,13 +222,9 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.unk_id = sp.piece_to_id("<unk>") - params.vocab_size = sp.get_piece_size() + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table["<blk>"] + params.vocab_size = max(lexicon.tokens) + 1 logging.info(f"{params}") @@ -276,12 +278,12 @@ def main(): num_waves = encoder_out.size(0) hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": + msg = f"Using {params.decoding_method}" + if params.decoding_method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - if params.method == "fast_beam_search": + if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( model=model, @@ -292,9 +294,9 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, @@ -302,37 +304,41 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for i in 
range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) else: for i in range(num_waves): # fmt: off encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on - if params.method == "greedy_search": + if params.decoding_method == "greedy_search": hyp = greedy_search( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, ) - elif params.method == "beam_search": + elif params.decoding_method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, ) else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) + raise ValueError( + f"Unsupported decoding-method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py b/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py index 9aad32014..d42c3b4f4 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py @@ -19,8 +19,8 @@ """ To run this file, do: - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless4/test_model.py + cd icefall/egs/aishell4/ASR + python ./pruned_transducer_stateless5/test_model.py """ from train import get_params, get_transducer_model diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index c2cf5aa66..0a48b9059 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright 2021-2022 Xiaomi Corp. 
(authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) +# Mingshuang Luo, # Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors @@ -396,7 +396,7 @@ def get_params() -> AttributeDict: "feature_dim": 80, "subsampling_factor": 4, # parameters for Noam - "model_warm_step": 50, # arg given to model, not for lrate + "model_warm_step": 400, # arg given to model, not for lrate "env_info": get_env_info(), } )