mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
resolve conflicts
This commit is contained in:
commit
a2d7095c1c
@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
||||
done
|
||||
|
||||
rm pruned_transducer_stateless2/exp/*.pt
|
||||
rm -r data/lang_bpe_500
|
||||
fi
|
||||
|
@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
||||
done
|
||||
|
||||
rm pruned_transducer_stateless3/exp/*.pt
|
||||
rm -r data/lang_bpe_500
|
||||
fi
|
||||
|
@ -82,7 +82,7 @@ The WER for this model is:
|
||||
|-----|------------|------------|
|
||||
| WER | 6.59 | 17.69 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
|
||||
|
||||
|
||||
#### Transducer: Conformer encoder + LSTM decoder
|
||||
@ -162,7 +162,7 @@ The CER for this model is:
|
||||
|-----|-------|
|
||||
| CER | 10.16 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing)
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
|
||||
|
||||
### TIMIT
|
||||
|
||||
|
@ -72,14 +72,14 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
|
||||
```
|
||||
|
||||
### Tips:
|
||||
1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/docker}:{/path/in/host/machine}`.
|
||||
1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
|
||||
|
||||
2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
|
||||
|
||||
Overall, your docker run command should look like this.
|
||||
|
||||
```bash
|
||||
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/docker}:{/path/in/host/machine} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
|
||||
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
|
||||
```
|
||||
|
||||
You can explore more docker run options [here](https://docs.docker.com/engine/reference/commandline/run/) to suit your environment.
|
||||
|
@ -51,8 +51,9 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
|
||||
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
|
||||
cd -
|
||||
|
||||
RUN pip install kaldiio graphviz && \
|
||||
conda install -y -c pytorch torchaudio
|
||||
RUN conda install -y -c pytorch torchaudio=0.12 && \
|
||||
pip install graphviz
|
||||
|
||||
|
||||
#install k2 from source
|
||||
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
|
||||
|
@ -69,8 +69,8 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
|
||||
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
|
||||
cd -
|
||||
|
||||
RUN pip install kaldiio graphviz && \
|
||||
conda install -y -c pytorch torchaudio=0.7.1
|
||||
RUN conda install -y -c pytorch torchaudio=0.7.1 && \
|
||||
pip install graphviz
|
||||
|
||||
#install k2 from source
|
||||
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
|
||||
|
@ -498,7 +498,7 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained
|
||||
|aishell asr conformer ctc colab notebook|
|
||||
|
||||
.. |aishell asr conformer ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
|
||||
:target: https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z
|
||||
:target: https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing
|
||||
|
||||
**Congratulations!** You have finished the aishell ASR recipe with
|
||||
TDNN-LSTM CTC models in ``icefall``.
|
||||
|
@ -398,7 +398,7 @@ We provide a colab notebook for decoding with pre-trained model.
|
||||
|librispeech tdnn_lstm_ctc colab notebook|
|
||||
|
||||
.. |librispeech tdnn_lstm_ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
|
||||
:target: https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd
|
||||
:target: https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing
|
||||
|
||||
|
||||
**Congratulations!** You have finished the TDNN-LSTM-CTC recipe on librispeech in ``icefall``.
|
||||
|
@ -206,6 +206,7 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
@ -230,7 +231,7 @@ def get_parser():
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
@ -241,7 +242,7 @@ def get_parser():
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
@ -250,7 +251,7 @@ def get_parser():
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -259,7 +260,7 @@ def get_parser():
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -355,8 +356,8 @@ def decode_one_batch(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
@ -387,7 +388,10 @@ def decode_one_batch(
|
||||
)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
@ -397,8 +401,12 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
@ -526,8 +534,8 @@ def decode_dataset(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
@ -643,6 +651,7 @@ def main():
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
@ -737,7 +746,7 @@ def main():
|
||||
model.device = device
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
|
@ -212,6 +212,7 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
@ -247,8 +248,8 @@ def get_parser():
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
Used only when --decoding_method is fast_beam_search_LG and
|
||||
fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -256,7 +257,7 @@ def get_parser():
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -265,7 +266,7 @@ def get_parser():
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -363,9 +364,10 @@ def decode_one_batch(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
|
||||
fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
@ -401,7 +403,10 @@ def decode_one_batch(
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
@ -411,8 +416,12 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
@ -548,9 +557,10 @@ def decode_dataset(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
|
||||
fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -663,6 +673,7 @@ def main():
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
@ -757,7 +768,7 @@ def main():
|
||||
model.device = device
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
|
@ -202,6 +202,7 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
@ -226,7 +227,7 @@ def get_parser():
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
@ -237,7 +238,7 @@ def get_parser():
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
@ -246,7 +247,7 @@ def get_parser():
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -255,7 +256,7 @@ def get_parser():
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is, fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -440,8 +441,8 @@ def decode_one_batch(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
G:
|
||||
Optional. Used only when decoding method is fast_beam_search,
|
||||
@ -483,7 +484,10 @@ def decode_one_batch(
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
@ -494,8 +498,12 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
temperature=params.temperature,
|
||||
)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
@ -714,8 +722,8 @@ def decode_dataset(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
G:
|
||||
Optional. Used only when decoding method is fast_beam_search,
|
||||
@ -901,6 +909,7 @@ def main():
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
@ -1002,7 +1011,7 @@ def main():
|
||||
|
||||
G = None
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
|
@ -243,6 +243,7 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
@ -267,7 +268,7 @@ def get_parser():
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
@ -278,7 +279,7 @@ def get_parser():
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
@ -287,7 +288,7 @@ def get_parser():
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -296,7 +297,7 @@ def get_parser():
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -394,8 +395,8 @@ def decode_one_batch(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result and timestamps. See above description for the
|
||||
@ -430,7 +431,10 @@ def decode_one_batch(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
res = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
@ -579,8 +583,8 @@ def decode_dataset(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
@ -742,6 +746,7 @@ def main():
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
@ -886,7 +891,7 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
|
@ -225,6 +225,7 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
@ -250,7 +251,7 @@ def get_parser():
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
@ -261,7 +262,7 @@ def get_parser():
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
@ -284,7 +285,7 @@ def get_parser():
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -293,7 +294,7 @@ def get_parser():
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
@ -444,8 +445,8 @@ def decode_one_batch(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
@ -482,7 +483,10 @@ def decode_one_batch(
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
@ -492,8 +496,12 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
@ -636,8 +644,8 @@ def decode_dataset(
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -753,6 +761,7 @@ def main():
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
@ -917,7 +926,7 @@ def main():
|
||||
rnn_lm_model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
|
@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp(
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
@ -1388,6 +1389,7 @@ def parse_hyp_and_timestamp(
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
@ -1400,7 +1402,10 @@ def parse_hyp_and_timestamp(
|
||||
N = len(res.tokens)
|
||||
assert len(res.timestamps) == N
|
||||
use_word_table = False
|
||||
if decoding_method == "fast_beam_search_nbest_LG":
|
||||
if (
|
||||
decoding_method == "fast_beam_search_nbest_LG"
|
||||
and decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
assert word_table is not None
|
||||
use_word_table = True
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user