From 8f79f6de007883be07fe2d9521c5c4efc57c0ab9 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Wed, 2 Nov 2022 23:36:07 +0800 Subject: [PATCH 1/6] Update tdnn_lstm_ctc.rst (#647) --- docs/source/recipes/aishell/tdnn_lstm_ctc.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst index 275931698..9eb3b11f7 100644 --- a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst @@ -498,7 +498,7 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained |aishell asr conformer ctc colab notebook| .. |aishell asr conformer ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg - :target: https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z + :target: https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing **Congratulations!** You have finished the aishell ASR recipe with TDNN-LSTM CTC models in ``icefall``. From 04671b44f85dbd1cb23a93015ccea3fbcd1156f6 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Wed, 2 Nov 2022 23:36:40 +0800 Subject: [PATCH 2/6] Update README.md (#649) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7213d8460..83ce0ac16 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ The WER for this model is: |-----|------------|------------| | WER | 6.59 | 17.69 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) +We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) #### Transducer: Conformer encoder + LSTM decoder @@ -162,7 +162,7 @@ The CER for this model is: |-----|-------| | CER | 10.16 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing) +We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing) ### TIMIT From 5d285625cf14f6f1d7108d650ad58856ac0f7cd6 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Wed, 2 Nov 2022 23:37:01 +0800 Subject: [PATCH 3/6] Update tdnn_lstm_ctc.rst (#648) --- docs/source/recipes/librispeech/tdnn_lstm_ctc.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst index ca477fbaa..aa380396a 100644 --- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst @@ -398,7 +398,7 @@ We provide a colab notebook for decoding with pre-trained model. |librispeech tdnn_lstm_ctc colab notebook| .. 
|librispeech tdnn_lstm_ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg - :target: https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd + :target: https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing **Congratulations!** You have finished the TDNN-LSTM-CTC recipe on librispeech in ``icefall``. From d2a1c65c5cc058e848d949cd99283e1981d499cc Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Thu, 3 Nov 2022 11:27:18 +0900 Subject: [PATCH 4/6] fix torchaudio version in dockerfile (#653) * fix torchaudio version in dockerfile * remove kaldiio --- docker/README.md | 4 ++-- docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile | 5 +++-- docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docker/README.md b/docker/README.md index 0a39b7a49..6f2314e96 100644 --- a/docker/README.md +++ b/docker/README.md @@ -72,14 +72,14 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall ``` ### Tips: -1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/docker}:{/path/in/host/machine}`. +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. Overall, your docker run command should look like this. ```bash -docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/docker}:{/path/in/host/machine} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 +docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 ``` You can explore more docker run options [here](https://docs.docker.com/engine/reference/commandline/run/) to suit your environment. 
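The tip above fixes the mount direction: with `-v`, the host path comes first and the container path second. As a cross-check for readers who drive containers from Python, here is a hedged sketch of the same run using the Docker SDK for Python (`pip install docker`); the image tag mirrors the README, while both paths and the proxy addresses are placeholders rather than values from this repository.

```python
import docker

# Sketch only: the volumes dict makes the fixed -v direction explicit --
# the HOST path is the key, the CONTAINER path is the bind target.
# "/home/me/data" and "/workspace/data" are hypothetical placeholders.
client = docker.from_env()
container = client.containers.run(
    "icefall/pytorch1.12.1",
    command="/bin/bash",
    runtime="nvidia",
    shm_size="2gb",
    volumes={
        "/home/me/data": {              # path on the host machine
            "bind": "/workspace/data",  # path inside the container
            "mode": "rw",
        }
    },
    environment=[                        # optional proxy, as in the README
        "http_proxy=http://aaa.bb.cc.net:8080",
        "https_proxy=http://aaa.bb.cc.net:8080",
    ],
    detach=True,
    tty=True,
)
```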
diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index db4dda864..524303fb8 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -51,8 +51,9 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - -RUN pip install kaldiio graphviz && \ - conda install -y -c pytorch torchaudio +RUN conda install -y -c pytorch torchaudio=0.12 && \ + pip install graphviz + #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 7a14a00ad..17a8215f9 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -69,8 +69,8 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - -RUN pip install kaldiio graphviz && \ - conda install -y -c pytorch torchaudio=0.7.1 +RUN conda install -y -c pytorch torchaudio=0.7.1 && \ + pip install graphviz #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ From 163d929601d0130f51bad266f107f9dfbaf6fde4 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 3 Nov 2022 16:29:30 +0800 Subject: [PATCH 5/6] Add fast_beam_search_LG (#622) * Add fast_beam_search_LG * add fast_beam_search_LG to commonly used recipes * fix ci * fix ci * Fix error --- ...pruned-transducer-stateless2-2022-04-29.sh | 1 + ...pruned-transducer-stateless3-2022-04-29.sh | 1 + .../ASR/pruned_transducer_stateless/decode.py | 33 ++++++++++------ .../beam_search.py | 4 +- .../pruned_transducer_stateless2/decode.py | 39 ++++++++++++------- .../pruned_transducer_stateless3/decode.py | 33 ++++++++++------ .../pruned_transducer_stateless4/decode.py | 25 +++++++----- .../pruned_transducer_stateless5/decode.py | 33 ++++++++++------ icefall/utils.py | 7 +++- 9 files changed, 113 insertions(+), 63 deletions(-) diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index ae2bb6822..c3d07dc0e 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless2/exp/*.pt + rm -r data/lang_bpe_500 fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 00580ca1f..22de3b45d 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless3/exp/*.pt + rm -r data/lang_bpe_500 fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 3977f8443..ab23a5a83 100755 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -206,6 +206,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -230,7 +231,7 @@ help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, ) @@ -241,7 +242,7 @@ type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -250,7 +251,7 @@ "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -259,7 +260,7 @@ "--max-states", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -355,8 +356,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of @@ -387,7 +388,10 @@ def decode_one_batch( ) hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -397,8 +401,12 @@ max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -526,8 +534,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -643,6 +651,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -737,7 +746,7 @@ def main(): model.device = device if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 0004a24eb..4f5016e94 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -15,7 +15,7 @@ # limitations under the License. import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, List, Optional, Union import k2 @@ -727,7 +727,7 @@ class Hypothesis: # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded - timestamp: List[int] + timestamp: List[int] = field(default_factory=list) state_cost: Optional[NgramLmStateCost] = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 3b834b919..99d4b5702 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -212,6 +212,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -247,8 +248,8 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. + Used only when --decoding_method is fast_beam_search_LG and + fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) @@ -256,7 +257,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -265,7 +266,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -363,9 +364,10 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, and + fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict.
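The `beam_search.py` hunk above is easy to miss but load-bearing: giving `timestamp` a `field(default_factory=list)` default lets callers construct a `Hypothesis` without supplying timestamps. A plain `= []` default is not an option, as the sketch below shows; the class here is a trimmed stand-in with an assumed subset of fields, not the real `Hypothesis` from icefall.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Hypothesis:
    ys: List[int] = field(default_factory=list)
    log_prob: float = 0.0
    # `timestamp: List[int] = []` would raise ValueError at class-creation
    # time: dataclasses reject mutable defaults because a single list would
    # be shared by every instance. default_factory builds a fresh list per
    # instance instead.
    timestamp: List[int] = field(default_factory=list)

a, b = Hypothesis(), Hypothesis()
a.timestamp.append(3)
assert b.timestamp == []  # each hypothesis owns its own list
```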
@@ -401,7 +403,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -411,8 +416,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -548,9 +557,10 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, and + fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -663,6 +673,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -757,7 +768,7 @@ def main(): model.device = device if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 0f30792e3..f34cf1e1f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -202,6 +202,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -226,7 +227,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -237,7 +238,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. 
""", ) @@ -246,7 +247,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -255,7 +256,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is, fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -440,8 +441,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. G: Optional. Used only when decoding method is fast_beam_search, @@ -483,7 +484,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -494,8 +498,12 @@ def decode_one_batch( max_states=params.max_states, temperature=params.temperature, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -714,8 +722,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. G: Optional. 
Used only when decoding method is fast_beam_search, @@ -901,6 +909,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -1002,7 +1011,7 @@ def main(): G = None if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 85097a01a..6afc21ce7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -243,6 +243,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -267,7 +268,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -278,7 +279,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -287,7 +288,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -296,7 +297,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -394,8 +395,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result and timestamps. See above description for the @@ -430,7 +431,10 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -579,8 +583,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. 
Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -742,6 +746,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -886,7 +891,7 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 632932214..c27d78e34 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -210,6 +210,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -234,7 +235,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -245,7 +246,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -254,7 +255,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -263,7 +264,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -361,8 +362,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. 
See above description for the format of @@ -399,7 +400,10 @@ hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -409,8 +413,12 @@ max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -538,8 +546,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -653,6 +661,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -797,7 +806,7 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/icefall/utils.py b/icefall/utils.py index 45a49fb5c..93dd0b967 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp( - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -1388,6 +1389,7 @@ "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -1400,7 +1402,10 @@ N = len(res.tokens) assert len(res.timestamps) == N use_word_table = False - if decoding_method == "fast_beam_search_nbest_LG": + if ( + decoding_method == "fast_beam_search_nbest_LG" + or decoding_method == "fast_beam_search_LG" + ): assert word_table is not None use_word_table = True From 64aed2cdeb8aae9ae98af8f6211f696fec2ee9d8 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 3 Nov 2022 23:12:35 +0800 Subject: [PATCH 6/6] Fix LG log file name (#657) --- egs/librispeech/ASR/pruned_transducer_stateless/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless3/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless4/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/decode.py | 8 ++++---- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ab23a5a83..7b6338948 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++
b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -504,8 +504,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -675,8 +675,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 99d4b5702..979a0e02e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -527,8 +527,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -697,8 +697,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index f34cf1e1f..8025d6be1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -686,8 +686,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: return { @@ -936,8 +936,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 6afc21ce7..7003e4764 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ 
-551,8 +551,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: (hyps, timestamps)} else: @@ -770,8 +770,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index c27d78e34..d251246b9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -516,8 +516,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -685,8 +685,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}"
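All five `decode.py` files in this series repeat one pattern: with the trivial graph, `fast_beam_search_one_best` yields BPE token IDs that SentencePiece decodes back to text, while with an LG graph the lattice labels are word IDs that must be resolved through the lexicon's `word_table` (which is why the `fast_beam_search_LG` branches skip `sp.decode`). The self-contained sketch below illustrates that distinction with toy lookup tables; they are stand-ins for illustration, not icefall objects.

```python
from typing import Dict, List

# Toy stand-ins: the real code uses a SentencePiece model and a k2 word table.
toy_bpe_pieces: Dict[int, str] = {5: "▁HE", 6: "LLO", 9: "▁WORLD"}
toy_word_table: Dict[int, str] = {17: "HELLO", 23: "WORLD"}

def hyp_to_words(hyp: List[int], decoding_method: str) -> List[str]:
    if decoding_method == "fast_beam_search":
        # BPE piece IDs -> text -> words ("▁" marks a word boundary)
        text = "".join(toy_bpe_pieces[i] for i in hyp).replace("▁", " ")
        return text.split()
    elif decoding_method == "fast_beam_search_LG":
        # LG lattice labels are already word IDs: one table lookup each
        return [toy_word_table[i] for i in hyp]
    raise ValueError(f"unsupported method: {decoding_method}")

assert hyp_to_words([5, 6, 9], "fast_beam_search") == ["HELLO", "WORLD"]
assert hyp_to_words([17, 23], "fast_beam_search_LG") == ["HELLO", "WORLD"]
```

The same direction of mapping explains why `main()` loads `LG.pt` and a `Lexicon` whenever the decoding method contains "LG": without the word table, the word IDs coming off the lattice could not be turned back into transcripts.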