Update results and move rnn directory
commit 2eb21668e0 (parent 71a9c33bca)
@@ -1139,17 +1139,18 @@ You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3x
 #### 2021-11-09
 
-The best WER, as of 2021-11-09, for the librispeech test dataset is below
-(using HLG decoding + n-gram LM rescoring + attention decoder rescoring):
+The best WER, as of 2022-06-20, for the librispeech test dataset is below
+(using HLG decoding + n-gram LM rescoring + attention decoder rescoring + rnn lm rescoring):
 
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 2.42       | 5.73       |
+| WER | 2.32       | 5.39       |
 
 Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
-| ngram_lm_scale | attention_scale |
-|----------------|-----------------|
-| 2.0            | 2.0             |
+| ngram_lm_scale | attention_scale | rnn_lm_scale |
+|----------------|-----------------|--------------|
+| 0.3            | 2.1             | 2.2          |
 
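The three scales weight the extra log-probabilities that n-best rescoring adds on top of the acoustic score; roughly, the path with the highest combined score wins. A minimal sketch of that combination (illustrative names, not the exact code in `conformer_ctc/decode.py`):

```python
# Sketch: how the tuned scales combine during n-best rescoring.
# All names here are illustrative, not the icefall API.
def total_score(
    am_score: float,         # acoustic/lattice score of one path
    ngram_lm_score: float,   # n-gram LM log-probability
    attention_score: float,  # attention-decoder log-probability
    rnn_lm_score: float,     # RNN-LM log-probability
    ngram_lm_scale: float = 0.3,
    attention_scale: float = 2.1,
    rnn_lm_scale: float = 2.2,
) -> float:
    """The path with the highest total score becomes the decoding output."""
    return (
        am_score
        + ngram_lm_scale * ngram_lm_score
        + attention_scale * attention_score
        + rnn_lm_scale * rnn_lm_score
    )
```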
 To reproduce the above result, use the following commands for training:
@@ -1170,11 +1171,27 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 0 \
   --num-epochs 90
 # Note: It trains for 90 epochs, but the best WER is at epoch-77.pt
+
+# Train the RNN-LM
+cd icefall
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./rnn_lm/train.py \
+  --exp-dir rnn_lm/exp_2048_3_tied \
+  --start-epoch 0 \
+  --world-size 4 \
+  --num-epochs 30 \
+  --use-fp16 1 \
+  --embedding-dim 2048 \
+  --hidden-dim 2048 \
+  --num-layers 3 \
+  --batch-size 500 \
+  --tie-weights true
 ```
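`--tie-weights true` shares the input embedding matrix with the output projection, which is why `--embedding-dim` and `--hidden-dim` are both 2048. A minimal sketch of the idea, using a plain LSTM LM rather than the actual `rnn_lm/model.py`:

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Illustration of weight tying; sizes mirror the command above."""

    def __init__(self, vocab_size: int = 500, dim: int = 2048, num_layers: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.rnn = nn.LSTM(dim, dim, num_layers=num_layers, batch_first=True)
        self.output = nn.Linear(dim, vocab_size)
        # Weight tying: the output projection reuses the embedding matrix,
        # which is only possible when embedding_dim == hidden_dim.
        self.output.weight = self.embedding.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        out, _ = self.rnn(self.embedding(tokens))
        return self.output(out)  # (N, T, vocab_size) logits
```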
 
 and the following command for decoding
 
 ```
+rnn_dir=$(git rev-parse --show-toplevel)/icefall/rnn_lm
 ./conformer_ctc/decode.py \
   --exp-dir conformer_ctc/exp_500_att0.8 \
   --lang-dir data/lang_bpe_500 \
@@ -1184,8 +1201,15 @@ and the following command for decoding
   --num-paths 1000 \
   --epoch 77 \
   --avg 55 \
-  --method attention-decoder \
-  --nbest-scale 0.5
+  --nbest-scale 0.5 \
+  --rnn-lm-exp-dir ${rnn_dir}/exp_2048_3_tied \
+  --rnn-lm-epoch 29 \
+  --rnn-lm-avg 3 \
+  --rnn-lm-embedding-dim 2048 \
+  --rnn-lm-hidden-dim 2048 \
+  --rnn-lm-num-layers 3 \
+  --rnn-lm-tie-weights true \
+  --method rnn-lm
 ```
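`--epoch 77 --avg 55` averages the parameters of the last 55 epoch checkpoints before decoding, and `--rnn-lm-epoch 29 --rnn-lm-avg 3` does the same for the RNN-LM. A small illustration of which files are selected, following the `start = epoch - avg + 1` convention used in the scripts below:

```python
# Which checkpoints "--epoch 77 --avg 55" refers to (illustration only).
epoch, avg = 77, 55
exp_dir = "conformer_ctc/exp_500_att0.8"
filenames = [f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
# -> epoch-23.pt ... epoch-77.pt; their parameters are averaged into one model.
print(filenames[0], "...", filenames[-1])
```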
 
 You can find the pre-trained model by visiting
@@ -43,6 +43,7 @@ from icefall.decode import (
 )
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -208,19 +209,6 @@ def get_parser():
     return parser
 
 
-def get_rnn_lm_model(params: AttributeDict):
-    from rnn_lm.model import RnnLmModel
-
-    rnn_lm_model = RnnLmModel(
-        vocab_size=params.num_classes,
-        embedding_dim=params.rnn_lm_embedding_dim,
-        hidden_dim=params.rnn_lm_hidden_dim,
-        num_layers=params.rnn_lm_num_layers,
-        tie_weights=params.rnn_lm_tie_weights,
-    )
-    return rnn_lm_model
-
-
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
@@ -768,7 +756,13 @@ def main():
 
     rnn_lm_model = None
     if params.method == "rnn-lm":
-        rnn_lm_model = get_rnn_lm_model(params)
+        rnn_lm_model = RnnLmModel(
+            vocab_size=params.num_classes,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
         if params.rnn_lm_avg == 1:
             load_checkpoint(
                 f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
@@ -41,9 +41,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  # 5000
-  # 2000
-  # 1000
+  5000
+  2000
+  1000
   500
 )
 
@@ -277,6 +277,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
     ./local/compile_lg.py --lang-dir $lang_dir
   done
 fi
 
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Generate LM training data"
@@ -33,8 +33,8 @@ import torch
 from rnn_lm.dataset import get_dataloader
 from rnn_lm.model import RnnLmModel
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.utils import AttributeDict, setup_logger
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, load_averaged_model, setup_logger
 
 
 def get_parser():
@@ -165,17 +165,12 @@ def main():
 
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
         model.to(device)
     else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model = load_averaged_model(
+            params.exp_dir, model, params.epoch, params.avg, device
+        )
 
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
     num_param_requires_grad = sum(
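`load_averaged_model` wraps the logic that the removed lines spelled out: collect the last `avg` epoch checkpoints and average their parameters element-wise. A self-contained sketch of that averaging step, assuming each `.pt` file stores its weights under a `"model"` key (see `icefall.checkpoint.average_checkpoints` for the real implementation):

```python
from typing import Dict, List

import torch


def average_state_dicts(filenames: List[str], device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Element-wise mean of the parameters saved in several checkpoints."""
    avg: Dict[str, torch.Tensor] = {}
    for name in filenames:
        # Assumption: each checkpoint is a dict with the weights under "model".
        state = torch.load(name, map_location=device)["model"]
        if not avg:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    for k in avg:
        avg[k] /= len(filenames)
    return avg
```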
@@ -40,7 +40,7 @@ def test_rnn_lm_model():
     )
     lengths = torch.tensor([4, 3, 2])
 
     nll_loss = model(x, y, lengths)
 
     print(nll_loss)
     """
     tensor([[1.1180, 1.3059, 1.2426, 1.7773],
             [1.4231, 1.2783, 1.7321, 0.0000],
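The `0.0000` entries in the expected tensor are padded positions beyond each value in `lengths`. A sketch of the kind of length-masked per-token NLL the test checks, assuming standard cross-entropy masking rather than the exact code in `rnn_lm/model.py`:

```python
import torch
import torch.nn.functional as F


def masked_nll(logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Per-token negative log-likelihood, zeroed beyond each sequence length.

    logits: (N, T, V), targets: (N, T), lengths: (N,)
    """
    nll = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")  # (N, T)
    mask = torch.arange(targets.size(1), device=targets.device)[None, :] < lengths[:, None]
    return nll * mask
```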
@@ -120,13 +120,6 @@ def get_parser():
         default=50,
     )
 
-    parser.add_argument(
-        "--use-ddp-launch",
-        type=str2bool,
-        default=False,
-        help="True if using torch.distributed.launch",
-    )
-
     parser.add_argument(
         "--lm-data",
         type=str,
@@ -165,7 +158,7 @@ def get_parser():
     parser.add_argument(
         "--num-layers",
         type=int,
-        default=4,
+        default=3,
         help="Number of RNN layers the model",
     )
 
@@ -206,7 +199,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 200,
             "reset_interval": 2000,
-            "valid_interval": 10000,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )