update codes for merging

luomingshuang 2022-06-14 11:44:03 +08:00
parent c8cb425e51
commit 1c4987f6e6
9 changed files with 241 additions and 86 deletions


@ -23,6 +23,7 @@ We provide the following recipes:
- [Aidatatang_200zh][aidatatang_200zh] - [Aidatatang_200zh][aidatatang_200zh]
- [WenetSpeech][wenetspeech] - [WenetSpeech][wenetspeech]
- [Alimeeting][alimeeting] - [Alimeeting][alimeeting]
- [Aishell4][aishell4]
### yesno ### yesno
@ -262,6 +263,21 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing) We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
### Aishell4
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
The best CER(%) results:
| | test |
|----------------------|--------|
| greedy search | 29.89 |
| fast beam search | 28.91 |
| modified beam search | 29.08 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
## Deployment with C++ ## Deployment with C++
Once you have trained a model in icefall, you may want to deploy it with C++, Once you have trained a model in icefall, you may want to deploy it with C++,
@ -290,6 +306,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2 [Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2 [WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2 [Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR [yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR [librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR [aishell]: egs/aishell/ASR
@ -299,5 +316,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[aidatatang_200zh]: egs/aidatatang_200zh/ASR [aidatatang_200zh]: egs/aidatatang_200zh/ASR
[wenetspeech]: egs/wenetspeech/ASR [wenetspeech]: egs/wenetspeech/ASR
[alimeeting]: egs/alimeeting/ASR [alimeeting]: egs/alimeeting/ASR
[aishell4]: egs/aishell4/ASR
[k2]: https://github.com/k2-fsa/k2 [k2]: https://github.com/k2-fsa/k2
) )


@ -0,0 +1,19 @@
# Introduction
This recipe includes several ASR models trained on Aishell4 (using all three subsets: S, M and L).
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.

egs/aishell4/ASR/RESULTS.md

@ -0,0 +1,117 @@
## Results
### Aishell4 Char training results (Pruned Transducer Stateless5)
#### 2022-06-13
Using the code from this PR: https://github.com/k2-fsa/icefall/pull/399.
When use-averaged-model=False, the CERs (%) are:
| | test | comment |
|------------------------------------|------------|------------------------------------------|
| greedy search | 30.05 | --epoch 30, --avg 25, --max-duration 800 |
| modified beam search (beam size 4) | 29.16 | --epoch 30, --avg 25, --max-duration 800 |
| fast beam search (set as default) | 29.20 | --epoch 30, --avg 25, --max-duration 1500|
When use-averaged-model=True, the CERs (%) are:
| | test | comment |
|------------------------------------|------------|----------------------------------------------------------------------|
| greedy search | 29.89 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
| modified beam search (beam size 4) | 28.91 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
| fast beam search (set as default) | 29.08 | --iter 36000, --avg 8, --max-duration 1500 --use-averaged-model=True |
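
The figures in the tables above are character error rates. As a rough illustration of what such a number measures, the sketch below computes CER as the character-level edit distance divided by the total reference length, expressed in percent. The function names `edit_distance` and `cer` and the sample sentences are hypothetical; this is not the scoring code used by the recipe.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance (substitutions, insertions, deletions)."""
    # prev[j] holds the distance between the processed prefix of ref and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(
                min(
                    prev[j] + 1,  # deletion
                    curr[j - 1] + 1,  # insertion
                    prev[j - 1] + cost,  # substitution or match
                )
            )
        prev = curr
    return prev[-1]


def cer(refs: List[str], hyps: List[str]) -> float:
    """Character error rate in percent over a list of utterances."""
    errors = sum(edit_distance(list(r), list(h)) for r, h in zip(refs, hyps))
    total = sum(len(r) for r in refs)
    return 100.0 * errors / total


# One substituted character out of five reference characters.
print(cer(["今天天气好"], ["今天天汽好"]))
```

Because errors are summed over all utterances before dividing, long utterances weigh more than short ones in the final figure.
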
The training command to reproduce these results is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 220 \
--save-every-n 4000
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/tjaVRKERS8C10SzhpBcxSQ/#scalars
When use-averaged-model=False, the decoding command is:
```
epoch=30
avg=25
## greedy search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800
## modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4
## fast beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
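
The `--avg` option in the commands above combines several saved checkpoints into a single model before decoding. A minimal sketch of plain parameter averaging follows, assuming each checkpoint is a flat mapping from parameter name to a scalar. Real checkpoints hold tensors, and the `--use-averaged-model True` path may compute its average differently; `average_checkpoints` here is a hypothetical helper, not the recipe's implementation.

```python
from typing import Dict, List


def average_checkpoints(state_dicts: List[Dict[str, float]]) -> Dict[str, float]:
    """Return the element-wise mean of several parameter dictionaries."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint to average")
    n = len(state_dicts)
    # Assumes every checkpoint has the same parameter names as the first one.
    return {
        name: sum(sd[name] for sd in state_dicts) / n
        for name in state_dicts[0]
    }


# Two toy "checkpoints" with two scalar parameters each.
ckpts = [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 2.0}]
print(average_checkpoints(ckpts))
```
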
When use-averaged-model=True, the decoding command is:
```
iter=36000
avg=8
## greedy search
./pruned_transducer_stateless5/decode.py \
--iter $iter \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--use-averaged-model True
## modified beam search
./pruned_transducer_stateless5/decode.py \
--iter $iter \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-averaged-model True
## fast beam search
./pruned_transducer_stateless5/decode.py \
--iter $iter \
--avg $avg \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--use-averaged-model True
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aishell4_pruned_transducer_stateless5>


@ -23,7 +23,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate, CutConcatenate,
CutMix, CutMix,
@ -222,7 +222,7 @@ class Aishell4AsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest_lazy( cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz" self.args.manifest_dir / "musan_cuts.jsonl.gz"
) )


@ -1,7 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# #
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao) # Zengwei Yao,
# Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -17,43 +18,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Usage: When use-averaged-model=True, usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --iter 36000 \
--avg 15 \ --avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --max-duration 800 \
--decoding-method greedy_search --decoding-method greedy_search \
--use-averaged-model True
(2) beam search (not recommended) (2) modified beam search
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --iter 36000 \
--avg 15 \ --avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --max-duration 800 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4 \
--use-averaged-model True
(4) fast beam search (3) fast beam search
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --iter 36000 \
--avg 15 \ --avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --max-duration 800 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8 \
--use-averaged-model True
""" """


@ -22,7 +22,7 @@
Usage: Usage:
./pruned_transducer_stateless5/export.py \ ./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -34,21 +34,20 @@ you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/aishell4/ASR
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--max-duration 600 \ --max-duration 600 \
--decoding-method greedy_search \ --decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model --lang-dir data/lang_char
""" """
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -58,6 +57,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon
from icefall.utils import str2bool from icefall.utils import str2bool
@ -115,10 +115,13 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_char",
help="Path to the BPE model", help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
) )
parser.add_argument( parser.add_argument(
@ -157,12 +160,9 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() lexicon = Lexicon(params.lang_dir)
sp.load(params.bpe_model) params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)
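
The hunk above derives `blank_id` and `vocab_size` from a token table instead of a BPE model. A minimal sketch of that idea follows, assuming the table is a plain symbol-to-id mapping and that disambiguation symbols (written `#0`, `#1`, ... here) do not count as output tokens. `token_table`, `real_token_ids` and the sample symbols are illustrative only and are not icefall's `Lexicon` API.

```python
from typing import Dict, List

# Hypothetical character-level token table (symbol -> integer id).
token_table: Dict[str, int] = {
    "<blk>": 0,
    "<unk>": 1,
    "你": 2,
    "好": 3,
    "#0": 4,  # assumed disambiguation symbol, not a real output token
}


def real_token_ids(table: Dict[str, int]) -> List[int]:
    """Return the sorted ids of all symbols that are not disambiguation symbols."""
    return sorted(i for sym, i in table.items() if not sym.startswith("#"))


blank_id = token_table["<blk>"]
# The output layer needs one slot per id in [0, max_id], hence the +1.
vocab_size = max(real_token_ids(token_table)) + 1
print(blank_id, vocab_size)
```

Taking `max(...) + 1` rather than the number of entries keeps the output dimension correct even if the id range has gaps.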


@ -15,30 +15,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Usage: When use-averaged-model=True, usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--method greedy_search \ --decoding-method greedy_search \
--use-averaged-model True \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
(2) beam search (2) beam search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--method beam_search \ --use-averaged-model True \
--decoding-method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
(3) modified beam search (3) modified beam search (not recommended)
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--method modified_beam_search \ --use-averaged-model True \
--decoding-method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -46,8 +49,9 @@ Usage:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--method fast_beam_search \ --use-averaged-model True \
--decoding-method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -66,7 +70,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -79,6 +82,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -95,13 +100,14 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to lang.
""",
) )
parser.add_argument( parser.add_argument(
"--method", "--decoding-method",
type=str, type=str,
default="greedy_search", default="greedy_search",
help="""Possible values are: help="""Possible values are:
@ -134,7 +140,7 @@ def get_parser():
type=int, type=int,
default=4, default=4,
help="""An integer indicating how many candidates we will keep for each help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or frame. Used only when --decoding-method is beam_search or
modified_beam_search.""", modified_beam_search.""",
) )
@ -145,21 +151,21 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --method is fast_beam_search""", Used only when --decoding-method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
"--max-contexts", "--max-contexts",
type=int, type=int,
default=4, default=4,
help="""Used only when --method is fast_beam_search""", help="""Used only when --decoding-method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
"--max-states", "--max-states",
type=int, type=int,
default=8, default=8,
help="""Used only when --method is fast_beam_search""", help="""Used only when --decoding-method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -174,7 +180,7 @@ def get_parser():
type=int, type=int,
default=1, default=1,
help="""Maximum number of symbols per frame. Used only when help="""Maximum number of symbols per frame. Used only when
--method is greedy_search. --decoding-method is greedy_search.
""", """,
) )
@ -216,13 +222,9 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() lexicon = Lexicon(params.lang_dir)
sp.load(params.bpe_model) params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}") logging.info(f"{params}")
@ -276,12 +278,12 @@ def main():
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []
msg = f"Using {params.method}" msg = f"Using {params.decoding_method}"
if params.method == "beam_search": if params.decoding_method == "beam_search":
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
if params.method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
@ -292,9 +294,9 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -302,37 +304,41 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on # fmt: on
if params.method == "greedy_search": if params.decoding_method == "greedy_search":
hyp = greedy_search( hyp = greedy_search(
model=model, model=model,
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame, max_sym_per_frame=params.max_sym_per_frame,
) )
elif params.method == "beam_search": elif params.decoding_method == "beam_search":
hyp = beam_search( hyp = beam_search(
model=model, model=model,
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(
f"Unsupported decoding-method: {params.decoding_method}"
hyps.append(sp.decode(hyp).split()) )
hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
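
The decoding branches above replace `sp.decode(...)` with a direct lookup of each token id in the token table. The sketch below shows that lookup on a batch of hypotheses, assuming a plain id-to-symbol dictionary. `id_to_token` and `ids_to_chars` are hypothetical names for illustration and the table contents are made up.

```python
from typing import Dict, List

# Hypothetical id -> symbol table for a character-level model.
id_to_token: Dict[int, str] = {0: "<blk>", 1: "<unk>", 2: "你", 3: "好", 4: "吗"}


def ids_to_chars(
    hyp_tokens: List[List[int]], table: Dict[int, str]
) -> List[List[str]]:
    """Map each hypothesis (a list of token ids) to a list of characters."""
    return [[table[idx] for idx in hyp] for hyp in hyp_tokens]


# Two hypotheses of different lengths, as a batch decoder might return them.
hyps = ids_to_chars([[2, 3], [2, 3, 4]], id_to_token)
for hyp in hyps:
    print("".join(hyp))
```

With a character-level vocabulary each id maps to exactly one character, so no subword merging step is needed after the lookup.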


@ -19,8 +19,8 @@
""" """
To run this file, do: To run this file, do:
cd icefall/egs/librispeech/ASR cd icefall/egs/aishell4/ASR
python ./pruned_transducer_stateless4/test_model.py python ./pruned_transducer_stateless5/test_model.py
""" """
from train import get_params, get_transducer_model from train import get_params, get_transducer_model


@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang, # Wei Kang,
# Mingshuang Luo,) # Mingshuang Luo,
# Zengwei Yao) # Zengwei Yao)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
@ -396,7 +396,7 @@ def get_params() -> AttributeDict:
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
# parameters for Noam # parameters for Noam
"model_warm_step": 50, # arg given to model, not for lrate "model_warm_step": 400, # arg given to model, not for lrate
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )