Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Commit: update codes for merging
Parent: c8cb425e51
This commit: 1c4987f6e6

README.md (18 lines changed)
@@ -23,6 +23,7 @@ We provide the following recipes:
 - [Aidatatang_200zh][aidatatang_200zh]
 - [WenetSpeech][wenetspeech]
 - [Alimeeting][alimeeting]
+- [Aishell4][aishell4]

 ### yesno
@@ -262,6 +263,21 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder

 We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: <https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing>

+### Aishell4
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
+
+The best CER(%) results:
+
+|                       | test  |
+|-----------------------|-------|
+| greedy search         | 29.89 |
+| fast beam search      | 28.91 |
+| modified beam search  | 29.08 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: <https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing>

 ## Deployment with C++

 Once you have trained a model in icefall, you may want to deploy it with C++,
@@ -290,6 +306,7 @@ Please see: [
 )
egs/aishell4/ASR/README.md (new file, 19 lines)

@@ -0,0 +1,19 @@

# Introduction

This recipe includes several ASR models trained on Aishell4 (which provides three training subsets: S, M and L).

[./RESULTS.md](./RESULTS.md) contains the latest results.

# Transducers

There are various folders in this recipe whose names contain `transducer`.
The following table lists the differences among them.

|                                | Encoder             | Decoder            | Comment                    |
|--------------------------------|---------------------|--------------------|----------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |

The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
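To make the "Embedding + Conv1d" decoder concrete, here is a minimal PyTorch sketch of that idea; the class name, dimensions and the grouped-convolution choice are illustrative assumptions, not the recipe's actual code.

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Toy stateless RNN-T prediction network: embedding + causal Conv1d."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The Conv1d mixes only the last `context_size` labels, so the decoder
        # keeps no recurrent state, unlike an LSTM prediction network.
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=context_size,
            padding=context_size - 1,
            groups=embed_dim,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, num_labels) of token ids
        emb = self.embedding(y).permute(0, 2, 1)   # (B, D, U)
        out = self.conv(emb)[:, :, : y.size(1)]    # trim right padding -> causal
        return out.permute(0, 2, 1)                # (B, U, D)


# Quick check with a 500-token vocabulary and 256-dim embeddings.
decoder = StatelessDecoder(vocab_size=500, embed_dim=256)
print(decoder(torch.randint(0, 500, (4, 10))).shape)  # torch.Size([4, 10, 256])
```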
egs/aishell4/ASR/RESULTS.md (new file, 117 lines)

@@ -0,0 +1,117 @@

## Results

### Aishell4 Char training results (Pruned Transducer Stateless5)

#### 2022-06-13

Using the code from this PR: https://github.com/k2-fsa/icefall/pull/399.

When `use-averaged-model=False`, the CERs are:

|                                     | test  | comment                                   |
|-------------------------------------|-------|-------------------------------------------|
| greedy search                       | 30.05 | --epoch 30, --avg 25, --max-duration 800  |
| modified beam search (beam size 4)  | 29.16 | --epoch 30, --avg 25, --max-duration 800  |
| fast beam search (set as default)   | 29.20 | --epoch 30, --avg 25, --max-duration 1500 |

When `use-averaged-model=True`, the CERs are:

|                                     | test  | comment                                                               |
|-------------------------------------|-------|-----------------------------------------------------------------------|
| greedy search                       | 29.89 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True   |
| modified beam search (beam size 4)  | 28.91 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True   |
| fast beam search (set as default)   | 29.08 | --iter 36000, --avg 8, --max-duration 1500 --use-averaged-model=True  |
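The `--epoch`/`--iter` and `--avg` pairs above select which checkpoints are averaged before decoding. Conceptually, checkpoint averaging is just an element-wise mean over the saved model parameters; the sketch below only illustrates that idea and assumes each checkpoint stores its weights under a "model" key (icefall's own implementation in `icefall.checkpoint` differs in detail, especially for `--use-averaged-model`).

```python
import torch


def average_checkpoints(paths):
    """Element-wise mean of the 'model' state dicts in `paths` (illustrative only)."""
    avg = None
    for p in paths:
        state = torch.load(p, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}


# Hypothetical usage with the last two epoch checkpoints of this recipe's exp dir:
# averaged = average_checkpoints(
#     ["pruned_transducer_stateless5/exp/epoch-29.pt",
#      "pruned_transducer_stateless5/exp/epoch-30.pt"]
# )
```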
The command for reproducing the training is given below:

```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"

./pruned_transducer_stateless5/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir data/lang_char \
  --max-duration 220 \
  --save-every-n 4000
```

The tensorboard training log can be found at
https://tensorboard.dev/experiment/tjaVRKERS8C10SzhpBcxSQ/#scalars

When `use-averaged-model=False`, the decoding commands are:

```bash
epoch=30
avg=25

## greedy search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800

## modified beam search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800 \
  --decoding-method modified_beam_search \
  --beam-size 4

## fast beam search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 1500 \
  --decoding-method fast_beam_search \
  --beam 4 \
  --max-contexts 4 \
  --max-states 8
```

When `use-averaged-model=True`, the decoding commands are:

```bash
iter=36000
avg=8

## greedy search
./pruned_transducer_stateless5/decode.py \
  --iter $iter \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800 \
  --use-averaged-model True

## modified beam search
./pruned_transducer_stateless5/decode.py \
  --iter $iter \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800 \
  --decoding-method modified_beam_search \
  --beam-size 4 \
  --use-averaged-model True

## fast beam search
./pruned_transducer_stateless5/decode.py \
  --iter $iter \
  --avg $avg \
  --exp-dir ./pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 1500 \
  --decoding-method fast_beam_search \
  --beam 4 \
  --max-contexts 4 \
  --max-states 8 \
  --use-averaged-model True
```

A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aishell4_pruned_transducer_stateless5>
@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -222,7 +222,7 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest_lazy(
+        cuts_musan = load_manifest(
             self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )
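The hunk above switches the MUSAN cuts from lazy to eager loading. As a hedged sketch of the difference (file paths here are illustrative, not necessarily the recipe's actual manifest locations): `load_manifest` reads the whole manifest into memory as an eager `CutSet`, while `load_manifest_lazy` returns a `CutSet` that is iterated lazily from disk.

```python
from lhotse import load_manifest, load_manifest_lazy

# Eager: the full CutSet lives in memory and supports repeated random access.
cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")

# Lazy: cuts are streamed from the .jsonl.gz file as they are iterated.
cuts_train = load_manifest_lazy("data/fbank/train_cuts.jsonl.gz")
```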
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
 # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
-#                                                 Zengwei Yao)
+#                                                 Zengwei Yao,
+#                                                 Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -17,43 +18,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Usage:
+When use-averaged-model=True, usage:
 (1) greedy search
 ./pruned_transducer_stateless5/decode.py \
-    --epoch 28 \
-    --avg 15 \
+    --iter 36000 \
+    --avg 8 \
     --exp-dir ./pruned_transducer_stateless5/exp \
-    --max-duration 600 \
-    --decoding-method greedy_search
+    --max-duration 800 \
+    --decoding-method greedy_search \
+    --use-averaged-model True

-(2) beam search (not recommended)
+(2) modified beam search
 ./pruned_transducer_stateless5/decode.py \
-    --epoch 28 \
-    --avg 15 \
+    --iter 36000 \
+    --avg 8 \
     --exp-dir ./pruned_transducer_stateless5/exp \
-    --max-duration 600 \
-    --decoding-method beam_search \
-    --beam-size 4
-
-(3) modified beam search
-./pruned_transducer_stateless5/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./pruned_transducer_stateless5/exp \
-    --max-duration 600 \
+    --max-duration 800 \
     --decoding-method modified_beam_search \
-    --beam-size 4
+    --beam-size 4 \
+    --use-averaged-model True

-(4) fast beam search
+(3) fast beam search
 ./pruned_transducer_stateless5/decode.py \
-    --epoch 28 \
-    --avg 15 \
+    --iter 36000 \
+    --avg 8 \
     --exp-dir ./pruned_transducer_stateless5/exp \
-    --max-duration 600 \
+    --max-duration 800 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
-    --max-states 8
+    --max-states 8 \
+    --use-averaged-model True
 """
@@ -22,7 +22,7 @@
 Usage:
 ./pruned_transducer_stateless5/export.py \
     --exp-dir ./pruned_transducer_stateless5/exp \
-    --bpe-model data/lang_bpe_500/bpe.model \
+    --lang-dir data/lang_char \
     --epoch 20 \
     --avg 10

@@ -34,21 +34,20 @@ you can do:
 cd /path/to/exp_dir
 ln -s pretrained.pt epoch-9999.pt

-cd /path/to/egs/librispeech/ASR
+cd /path/to/egs/aishell4/ASR
 ./pruned_transducer_stateless5/decode.py \
     --exp-dir ./pruned_transducer_stateless5/exp \
     --epoch 9999 \
     --avg 1 \
     --max-duration 600 \
     --decoding-method greedy_search \
-    --bpe-model data/lang_bpe_500/bpe.model
+    --lang-dir data/lang_char
 """

 import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
 import torch
 from train import add_model_arguments, get_params, get_transducer_model

@@ -58,6 +57,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import str2bool

@@ -115,10 +115,13 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
     )

     parser.add_argument(
@@ -157,12 +160,9 @@ def main():

     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1

     logging.info(params)
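The switch from `--bpe-model` to `--lang-dir` means the blank id and vocabulary size are now derived from a character lexicon rather than a SentencePiece model. A minimal sketch of what that derivation amounts to, assuming a `tokens.txt` file in the lang dir that maps each symbol to an integer id (this is a simplification for illustration, not the actual `icefall.lexicon.Lexicon` implementation):

```python
from pathlib import Path


def load_token_table(lang_dir: str) -> dict:
    """Read one 'symbol id' pair per line from <lang_dir>/tokens.txt."""
    table = {}
    for line in Path(lang_dir, "tokens.txt").read_text(encoding="utf-8").splitlines():
        sym, idx = line.split()
        table[sym] = int(idx)
    return table


token_table = load_token_table("data/lang_char")   # path taken from the diff above
blank_id = token_table["<blk>"]                    # typically 0
vocab_size = max(token_table.values()) + 1         # mirrors max(lexicon.tokens) + 1
```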
@@ -15,30 +15,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Usage:
+When use-averaged-model=True, usage:

 (1) greedy search
 ./pruned_transducer_stateless5/pretrained.py \
     --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method greedy_search \
+    --lang-dir data/lang_char \
+    --decoding-method greedy_search \
+    --use-averaged-model True \
     /path/to/foo.wav \
     /path/to/bar.wav

 (2) beam search
 ./pruned_transducer_stateless5/pretrained.py \
     --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method beam_search \
+    --lang-dir data/lang_char \
+    --use-averaged-model True \
+    --decoding-method beam_search \
     --beam-size 4 \
     /path/to/foo.wav \
     /path/to/bar.wav

-(3) modified beam search
+(3) modified beam search (not recommended)
 ./pruned_transducer_stateless5/pretrained.py \
     --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method modified_beam_search \
+    --lang-dir data/lang_char \
+    --use-averaged-model True \
+    --decoding-method modified_beam_search \
     --beam-size 4 \
     /path/to/foo.wav \
     /path/to/bar.wav
@@ -46,8 +49,9 @@ Usage:
 (4) fast beam search
 ./pruned_transducer_stateless5/pretrained.py \
     --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method fast_beam_search \
+    --lang-dir data/lang_char \
+    --use-averaged-model True \
+    --decoding-method fast_beam_search \
     --beam-size 4 \
     /path/to/foo.wav \
     /path/to/bar.wav
@@ -66,7 +70,6 @@ from typing import List

 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (
@@ -79,6 +82,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall.lexicon import Lexicon
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -95,13 +100,14 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        help="""Path to bpe.model.""",
+        help="""Path to the lang dir, e.g., data/lang_char.""",
     )

     parser.add_argument(
-        "--method",
+        "--decoding-method",
         type=str,
         default="greedy_search",
         help="""Possible values are:
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""An integer indicating how many candidates we will keep for each
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
frame. Used only when --method is beam_search or
|
frame. Used only when --decoding-method is beam_search or
|
||||||
modified_beam_search.""",
|
modified_beam_search.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -145,21 +151,21 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --method is fast_beam_search""",
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --method is fast_beam_search""",
|
help="""Used only when --decoding-method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-states",
|
"--max-states",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --method is fast_beam_search""",
|
help="""Used only when --decoding-method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -174,7 +180,7 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="""Maximum number of symbols per frame. Used only when
|
help="""Maximum number of symbols per frame. Used only when
|
||||||
--method is greedy_search.
|
--decoding-method is greedy_search.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -216,13 +222,9 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
lexicon = Lexicon(params.lang_dir)
|
||||||
sp.load(params.bpe_model)
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
|
||||||
params.vocab_size = sp.get_piece_size()
|
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@@ -276,12 +278,12 @@ def main():

     num_waves = encoder_out.size(0)
     hyps = []
-    msg = f"Using {params.method}"
-    if params.method == "beam_search":
+    msg = f"Using {params.decoding_method}"
+    if params.decoding_method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

-    if params.method == "fast_beam_search":
+    if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_tokens = fast_beam_search_one_best(
             model=model,
@@ -292,9 +294,9 @@ def main():
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
-    elif params.method == "modified_beam_search":
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
@@ -302,37 +304,41 @@ def main():
             beam=params.beam_size,
         )

-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
-    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     else:
         for i in range(num_waves):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
             # fmt: on
-            if params.method == "greedy_search":
+            if params.decoding_method == "greedy_search":
                 hyp = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
                 )
-            elif params.method == "beam_search":
+            elif params.decoding_method == "beam_search":
                 hyp = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported method: {params.method}")
-
-            hyps.append(sp.decode(hyp).split())
+                raise ValueError(
+                    f"Unsupported decoding-method: {params.decoding_method}"
+                )
+            hyps.append([lexicon.token_table[idx] for idx in hyp])

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
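The hunks above replace `sp.decode(...)` with a direct lookup in the lexicon's token table: each hypothesis is now a list of integer token ids mapped back to characters one id at a time. A minimal, self-contained sketch of that post-processing step (the toy table and ids below are invented for illustration; in the script the table is `lexicon.token_table`):

```python
# Toy id -> symbol table standing in for lexicon.token_table.
token_table = {0: "<blk>", 1: "你", 2: "好", 3: "世", 4: "界"}

# Hypothetical decoder output: one list of token ids per utterance.
hyp_tokens = [[1, 2], [3, 4]]

hyps = [[token_table[idx] for idx in hyp] for hyp in hyp_tokens]
texts = ["".join(hyp) for hyp in hyps]  # Chinese characters join without spaces
print(texts)  # ['你好', '世界']
```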
@@ -19,8 +19,8 @@
 """
 To run this file, do:

-    cd icefall/egs/librispeech/ASR
-    python ./pruned_transducer_stateless4/test_model.py
+    cd icefall/egs/aishell4/ASR
+    python ./pruned_transducer_stateless5/test_model.py
 """

 from train import get_params, get_transducer_model

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
-#                                            Mingshuang Luo,)
+#                                            Mingshuang Luo,
 #                                            Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors

@@ -396,7 +396,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 50,  # arg given to model, not for lrate
+            "model_warm_step": 400,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )