Merge remote-tracking branch 'dan/master' into ctc-ali

This commit is contained in:
Fangjun Kuang 2021-09-13 11:06:51 +08:00
commit 5072e28afb
5 changed files with 47 additions and 5 deletions

View File

@ -53,6 +53,20 @@ jobs:
# icefall requirements
pip install -r requirements.txt
- name: Install graphviz
if: startsWith(matrix.os, 'ubuntu')
shell: bash
run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz
- name: Install graphviz
if: startsWith(matrix.os, 'macos')
shell: bash
run: |
python3 -m pip install -qq graphviz
brew install -q graphviz
- name: Run tests
if: startsWith(matrix.os, 'ubuntu')
run: |

View File

@ -21,6 +21,32 @@ To get more unique paths, we scaled the lattice.scores with 0.5 (see https://git
|test-clean|1.3|1.2|
|test-other|1.2|1.1|
You can use the following commands to reproduce our results:
```bash
git clone https://github.com/k2-fsa/icefall
cd icefall
# The results above were obtained with commit ef233486; you may not need to switch to it
# git checkout ef233486
cd egs/librispeech/ASR
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python conformer_ctc/train.py --bucketing-sampler True \
--concatenate-cuts False \
--max-duration 200 \
--full-libri True \
--world-size 4
python conformer_ctc/decode.py --lattice-score-scale 0.5 \
--epoch 34 \
--avg 20 \
--method attention-decoder \
--max-duration 20 \
--num-paths 100
```
### LibriSpeech training results (Tdnn-Lstm)
#### 2021-08-24

View File

@ -108,7 +108,7 @@ def get_parser():
parser.add_argument(
"--lattice-score-scale",
type=float,
default=1.0,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
@ -278,7 +278,8 @@ def decode_one_batch(
"attention-decoder",
]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":

View File

@ -82,14 +82,14 @@ class LibriSpeechAsrDataModule(DataModule):
group.add_argument(
"--max-duration",
type=int,
default=500.0,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=False,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)

View File

@ -206,7 +206,8 @@ def decode_one_batch(
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":