resolve conflicts

This commit is contained in:
marcoyang 2022-11-04 11:37:42 +08:00
commit a2d7095c1c
14 changed files with 142 additions and 91 deletions

View File

@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
done
rm pruned_transducer_stateless2/exp/*.pt
rm -r data/lang_bpe_500
fi

View File

@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
done
rm pruned_transducer_stateless3/exp/*.pt
rm -r data/lang_bpe_500
fi

View File

@ -82,7 +82,7 @@ The WER for this model is:
|-----|------------|------------|
| WER | 6.59 | 17.69 |
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
#### Transducer: Conformer encoder + LSTM decoder
@ -162,7 +162,7 @@ The CER for this model is:
|-----|-------|
| CER | 10.16 |
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing)
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
### TIMIT

View File

@ -72,14 +72,14 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
```
### Tips:
1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/docker}:{/path/in/host/machine}`.
1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
Overall, your docker run command should look like this.
```bash
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/docker}:{/path/in/host/machine} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
```
You can explore more docker run options [here](https://docs.docker.com/engine/reference/commandline/run/) to suit your environment.

View File

@ -51,8 +51,9 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
RUN pip install kaldiio graphviz && \
conda install -y -c pytorch torchaudio
RUN conda install -y -c pytorch torchaudio=0.12 && \
pip install graphviz
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \

View File

@ -69,8 +69,8 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
RUN pip install kaldiio graphviz && \
conda install -y -c pytorch torchaudio=0.7.1
RUN conda install -y -c pytorch torchaudio=0.7.1 && \
pip install graphviz
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \

View File

@ -498,7 +498,7 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained
|aishell asr conformer ctc colab notebook|
.. |aishell asr conformer ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z
:target: https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing
**Congratulations!** You have finished the aishell ASR recipe with
TDNN-LSTM CTC models in ``icefall``.

View File

@ -398,7 +398,7 @@ We provide a colab notebook for decoding with pre-trained model.
|librispeech tdnn_lstm_ctc colab notebook|
.. |librispeech tdnn_lstm_ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd
:target: https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing
**Congratulations!** You have finished the TDNN-LSTM-CTC recipe on librispeech in ``icefall``.

View File

@ -206,6 +206,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -230,7 +231,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -241,7 +242,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -250,7 +251,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -259,7 +260,7 @@ def get_parser():
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -355,8 +356,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
@ -387,7 +388,10 @@ def decode_one_batch(
)
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -397,8 +401,12 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -526,8 +534,8 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@ -643,6 +651,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -737,7 +746,7 @@ def main():
model.device = device
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -212,6 +212,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -247,8 +248,8 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
Used only when --decoding_method is fast_beam_search_LG and
fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores.
""",
)
@ -256,7 +257,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -265,7 +266,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -363,9 +364,10 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -401,7 +403,10 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -411,8 +416,12 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -548,9 +557,10 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -663,6 +673,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -757,7 +768,7 @@ def main():
model.device = device
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -202,6 +202,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -226,7 +227,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -237,7 +238,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -246,7 +247,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -255,7 +256,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is, fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -440,8 +441,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
@ -483,7 +484,10 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -494,8 +498,12 @@ def decode_one_batch(
max_states=params.max_states,
temperature=params.temperature,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -714,8 +722,8 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
@ -901,6 +909,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -1002,7 +1011,7 @@ def main():
G = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -243,6 +243,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -267,7 +268,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -278,7 +279,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -287,7 +288,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -296,7 +297,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -394,8 +395,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result and timestamps. See above description for the
@ -430,7 +431,10 @@ def decode_one_batch(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
res = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -579,8 +583,8 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@ -742,6 +746,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -886,7 +891,7 @@ def main():
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -225,6 +225,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -250,7 +251,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -261,7 +262,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -284,7 +285,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -293,7 +294,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -444,8 +445,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
@ -482,7 +483,10 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -492,8 +496,12 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -636,8 +644,8 @@ def decode_dataset(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -753,6 +761,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -917,7 +926,7 @@ def main():
rnn_lm_model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp(
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -1388,6 +1389,7 @@ def parse_hyp_and_timestamp(
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -1400,7 +1402,10 @@ def parse_hyp_and_timestamp(
N = len(res.tokens)
assert len(res.timestamps) == N
use_word_table = False
if decoding_method == "fast_beam_search_nbest_LG":
if (
decoding_method == "fast_beam_search_nbest_LG"
and decoding_method == "fast_beam_search_LG"
):
assert word_table is not None
use_word_table = True