Shallow fusion for Aishell (#954)

* add shallow fusion and LODR for aishell

* update RESULTS

* add save by iterations
marcoyang1998 2023-04-03 16:20:29 +08:00 committed by GitHub
parent 46bf6df62f
commit d337398d29
18 changed files with 647 additions and 36 deletions

View File

@ -15,6 +15,8 @@ It uses pruned RNN-T.
|------------------------|------|------|---------------------------------------|
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 |
| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 |
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
Training command is:
@ -73,6 +75,78 @@ for epoch in 29; do
done
```
We provide the option of shallow fusion with an RNN language model. The pre-trained language model is
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
please use the following commands (a sketch of how the LM scores are combined is given after the code block):
```bash
# download pre-trained model
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/
pushd ${aishell_exp}/exp
ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt
popd
# download RNN LM
git lfs install
git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm
rnnlm_dir=icefall-aishell-rnn-lm
# RNNLM shallow fusion
for lm_scale in $(seq 0.26 0.02 0.34); do
python ./pruned_transducer_stateless3/decode.py \
--epoch 99 \
--avg 1 \
--lang-dir ${aishell_exp}/data/lang_char \
--exp-dir ${aishell_exp}/exp \
--use-averaged-model False \
--decoding-method modified_beam_search_lm_shallow_fusion \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir ${rnnlm_dir}/exp \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 2 \
--lm-vocab-size 4336
done
# RNNLM Low-order density ratio (LODR) with a 2-gram
cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt
for lm_scale in 0.48; do
for LODR_scale in -0.28; do
python ./pruned_transducer_stateless3/decode.py \
--epoch 99 \
--avg 1 \
--lang-dir ${aishell_exp}/data/lang_char \
--exp-dir ${aishell_exp}/exp \
--use-averaged-model False \
--decoding-method modified_beam_search_LODR \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir ${rnnlm_dir}/exp \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 2 \
--lm-vocab-size 4336 \
--tokens-ngram 2 \
--backoff-id 4336 \
--ngram-lm-scale $LODR_scale
done
done
```
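Shallow fusion adds the RNN LM's log-probabilities to the transducer's scores at every decoding step; LODR additionally subtracts a scaled low-order (here 2-gram) LM score, which serves as a rough estimate of the transducer's internal LM. Below is a minimal, hypothetical sketch of how the per-token scores are combined during beam search, using the scales from the LODR command above; it is an illustration only, not the actual `modified_beam_search_LODR` implementation.
```python
import torch


def combine_scores(
    asr_logits: torch.Tensor,       # (vocab,) joiner output for one hypothesis expansion
    rnnlm_logprobs: torch.Tensor,   # (vocab,) log-probs from the RNN LM
    ngram_logprobs: torch.Tensor,   # (vocab,) log-probs from the low-order n-gram LM
    lm_scale: float = 0.48,         # --lm-scale
    ngram_lm_scale: float = -0.28,  # --ngram-lm-scale; negative, so the 2-gram score is subtracted
) -> torch.Tensor:
    """Per-token scores used to rank beam-search expansions.

    With ngram_lm_scale = 0 this reduces to plain RNNLM shallow fusion.
    """
    asr_logprobs = torch.log_softmax(asr_logits, dim=-1)
    return asr_logprobs + lm_scale * rnnlm_logprobs + ngram_lm_scale * ngram_logprobs
```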
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>

egs/aishell/ASR/local/prepare_char_lm_training_data.py Normal file → Executable file
View File

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/sort_lm_training_data.py

View File

@ -231,11 +231,13 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
fi
# training words
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $dl_dir/lm/aishell-train-word.txt \
--lm-archive $out_dir/lm_data.pt
# valid words
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
@ -249,6 +251,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
--lm-data $dl_dir/lm/aishell-valid-word.txt \
--lm-archive $out_dir/lm_data_valid.pt
# test words
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
@ -303,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
--hidden-dim 512 \
--num-layers 2 \
--batch-size 400 \
--exp-dir rnnlm_char/exp \
--lm-data data/lm_training_char/sorted_lm_data.pt \
--lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
--exp-dir rnnlm_char/exp_aishell1_small \
--lm-data data/lm_char/sorted_lm_data_aishell1.pt \
--lm-data-valid data/lm_char/sorted_lm_data_valid.pt \
--vocab-size 4336 \
--master-port 12345
fi

View File

@ -54,6 +54,40 @@ Usage:
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) modified beam search (with LM shallow fusion)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.3 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
(6) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--max-duration 600 \
--exp-dir ./pruned_transducer_stateless3/exp \
--decoding-method modified_beam_search_LODR \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.48 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1 \
--tokens-ngram 2 \
--ngram-lm-scale -0.28
"""
@ -74,9 +108,12 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -212,6 +249,60 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=2,
help="""Token Ngram used for rescoring.
Used only when the decoding method is
modified_beam_search_ngram_rescoring or modified_beam_search_LODR""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="""ID of the backoff symbol.
Used only when the decoding method is
modified_beam_search_ngram_rescoring""",
)
add_model_arguments(parser)
return parser
@ -223,6 +314,9 @@ def decode_one_batch(
token_table: k2.SymbolTable,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -287,6 +381,24 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
)
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
else:
hyp_tokens = []
batch_size = encoder_out.size(0)
@ -334,6 +446,9 @@ def decode_dataset(
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -379,6 +494,9 @@ def decode_dataset(
token_table=token_table,
decoding_graph=decoding_graph,
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
for name, hyps in hyps_dict.items():
@ -445,6 +563,7 @@ def save_results(
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
@ -458,6 +577,8 @@ def main():
"beam_search",
"fast_beam_search",
"modified_beam_search",
"modified_beam_search_LODR",
"modified_beam_search_lm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -479,6 +600,19 @@ def main():
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
if "ngram" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if params.use_shallow_fusion:
if params.lm_type == "rnn":
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
elif params.lm_type == "transformer":
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -588,6 +722,35 @@ def main():
else:
decoding_graph = None
# only load N-gram LM when needed
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm(
lm_filename,
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
# only load the neural network LM if doing shallow fusion
if params.use_shallow_fusion:
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -610,6 +773,9 @@ def main():
model=model,
token_table=lexicon.token_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
save_results(

View File

@ -1 +0,0 @@
/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1

View File

@ -550,7 +550,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
@ -561,7 +560,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,

View File

@ -1863,7 +1863,6 @@ def modified_beam_search_LODR(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LODR_lm: NgramLm,
LODR_lm_scale: float,
LM: LmScorer,
@ -1883,8 +1882,6 @@ def modified_beam_search_LODR(
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
LODR_lm:
A low order n-gram LM, whose score will be subtracted during shallow fusion
LODR_lm_scale:
@ -1912,7 +1909,7 @@ def modified_beam_search_LODR(
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
sos_id = getattr(LM, "sos_id", 1)
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
@ -2137,7 +2134,6 @@ def modified_beam_search_lm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LM: LmScorer,
beam: int = 4,
return_timestamps: bool = False,
@ -2176,7 +2172,7 @@ def modified_beam_search_lm_shallow_fusion(
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
sos_id = getattr(LM, "sos_id", 1)
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device

View File

@ -675,7 +675,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
@ -686,7 +685,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,

View File

@ -586,7 +586,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
@ -597,7 +596,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,

View File

@ -533,7 +533,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
@ -544,7 +543,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,

View File

@ -40,8 +40,8 @@ from tqdm import tqdm
# and 'data()' is only supported in static graph mode. So if you
# want to use this api, should call 'paddle.enable_static()' before
# this api to enter static graph mode.
paddle.enable_static()
paddle.disable_signal_handler()
# paddle.enable_static()
# paddle.disable_signal_handler()
jieba.enable_paddle()

View File

@ -261,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
log "Stage 18: Compile LG"
python ./local/compile_lg.py --lang-dir $lang_char_dir
fi
# prepare RNNLM data
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
log "Stage 19: Prepare LM training data"
log "Processing char based data"
text_out_dir=data/lm_char
mkdir -p $text_out_dir
log "Genearating training text data"
if [ ! -f $text_out_dir/lm_data.pt ]; then
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $lang_char_dir/text_words_segmentation \
--lm-archive $text_out_dir/lm_data.pt
fi
log "Generating DEV text data"
# prepare validation text data
if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
valid_text=${text_out_dir}/
gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "char" > $text_out_dir/valid_text
python3 ./local/text2segments.py \
--num-process $nj \
--input-file $text_out_dir/valid_text \
--output-file $text_out_dir/valid_text_words_segmentation
fi
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $text_out_dir/valid_text_words_segmentation \
--lm-archive $text_out_dir/lm_data_valid.pt
# prepare TEST text data
if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
log "Prepare text for test set."
for test_set in TEST_MEETING TEST_NET; do
gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text
python3 ./local/text2segments.py \
--num-process $nj \
--input-file $text_out_dir/${test_set}_text \
--output-file $text_out_dir/${test_set}_text_words_segmentation
done
cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
fi
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $text_out_dir/test_text_words_segmentation \
--lm-archive $text_out_dir/lm_data_test.pt
fi
# sort RNNLM data
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
text_out_dir=data/lm_char
log "Sort lm data"
./local/sort_lm_training_data.py \
--in-lm-data $text_out_dir/lm_data.pt \
--out-lm-data $text_out_dir/sorted_lm_data.pt \
--out-statistics $text_out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $text_out_dir/lm_data_valid.pt \
--out-lm-data $text_out_dir/sorted_lm_data-valid.pt \
--out-statistics $text_out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $text_out_dir/lm_data_test.pt \
--out-lm-data $text_out_dir/sorted_lm_data-test.pt \
--out-statistics $text_out_dir/statistics-test.txt
fi
export CUDA_VISIBLE_DEVICES="0,1"
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
log "Stage 21: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \
--world-size 2 \
--num-epochs 20 \
--use-fp16 0 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 2 \
--batch-size 400 \
--exp-dir rnnlm_char/exp \
--lm-data data/lm_char/sorted_lm_data.pt \
--lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
--vocab-size 4336 \
--master-port 12340
fi
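Stage 21 trains a character-level RNN LM with a 2048-dimensional embedding and hidden state and 2 recurrent layers. The sketch below is only an illustration of how those sizes fit together; it is a hypothetical model, not the actual `RnnLmModel` used by `rnn_lm/train.py`.
```python
import torch
import torch.nn as nn


class TinyCharRnnLm(nn.Module):
    """Hypothetical embedding -> LSTM -> projection LM, for sizing illustration only."""

    def __init__(self, vocab_size=4336, embedding_dim=2048, hidden_dim=2048, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq_len) of character ids; returns per-token logits over the vocab
        x = self.embedding(tokens)
        x, _ = self.rnn(x)
        return self.output(x)
```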

View File

@ -2,6 +2,7 @@
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -91,6 +92,22 @@ When training with the L subset, the streaming usage:
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64
(4) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.3 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
"""
@ -111,9 +128,12 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -224,6 +244,16 @@ def get_parser():
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG or fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
@ -277,6 +307,50 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=3,
help="""Token Ngram used for rescoring.
Used only when the decoding method is
modified_beam_search_ngram_rescoring, or LODR
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="""ID of the backoff symbol.
Used only when the decoding method is
modified_beam_search_ngram_rescoring""",
)
add_model_arguments(parser)
return parser
@ -288,6 +362,9 @@ def decode_one_batch(
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -374,6 +451,28 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
@ -419,6 +518,9 @@ def decode_dataset(
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -432,6 +534,8 @@ def decode_dataset(
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
LM:
A neural network LM, used during shallow fusion
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -449,7 +553,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -463,6 +567,9 @@ def decode_dataset(
lexicon=lexicon,
decoding_graph=decoding_graph,
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
for name, hyps in hyps_dict.items():
@ -524,6 +631,7 @@ def save_results(
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -535,6 +643,8 @@ def main():
"beam_search",
"fast_beam_search",
"modified_beam_search",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -549,6 +659,22 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if "ngram" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if params.use_shallow_fusion:
if params.lm_type == "rnn":
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
elif params.lm_type == "transformer":
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -558,6 +684,7 @@ def main():
logging.info(f"Device: {device}")
# import pdb; pdb.set_trace()
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
@ -652,6 +779,37 @@ def main():
model.to(device)
model.eval()
model.device = device
# only load N-gram LM when needed
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
# import pdb; pdb.set_trace()
# only load the neural network LM if doing shallow fusion
if params.use_shallow_fusion:
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
num_param = sum([p.numel() for p in LM.parameters()])
logging.info(f"Number of model parameters: {num_param}")
else:
LM = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
@ -684,6 +842,9 @@ def main():
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
save_results(
params=params,

View File

@ -50,7 +50,7 @@ class LmScorer(torch.nn.Module):
def add_arguments(cls, parser):
# LM general arguments
parser.add_argument(
"--vocab-size",
"--lm-vocab-size",
type=int,
default=500,
)
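Renaming the option from `--vocab-size` to `--lm-vocab-size` makes it explicit that the value refers to the LM, and it also means `LmScorer.add_arguments()` cannot clash with a parser that already registers `--vocab-size`, since argparse refuses to add the same option string twice. A small, self-contained illustration of that argparse behaviour, using a hypothetical parser rather than icefall code:
```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--vocab-size", type=int, default=500)      # defined by the recipe's own parser
# parser.add_argument("--vocab-size", ...)                      # would raise argparse.ArgumentError:
#                                                               # conflicting option string: --vocab-size
parser.add_argument("--lm-vocab-size", type=int, default=500)   # the renamed LM option coexists with it

args = parser.parse_args(["--vocab-size", "4336", "--lm-vocab-size", "4336"])
print(args.vocab_size, args.lm_vocab_size)  # 4336 4336
```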

View File

@ -33,7 +33,7 @@ import torch
from dataset import get_dataloader
from model import RnnLmModel
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger, str2bool
@ -49,6 +49,7 @@ def get_parser():
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
@ -58,6 +59,16 @@ def get_parser():
"'--epoch'. ",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -154,7 +165,14 @@ def main():
params = AttributeDict(vars(args))
setup_logger(f"{params.exp_dir}/log-ppl/")
if params.iter > 0:
setup_logger(
f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}"
)
else:
setup_logger(
f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}"
)
logging.info("Computing perplexity started")
logging.info(params)
@ -173,19 +191,39 @@ def main():
tie_weights=params.tie_weights,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(

View File

@ -25,7 +25,7 @@ from pathlib import Path
import torch
from model import RnnLmModel
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, load_averaged_model, str2bool
@ -51,6 +51,16 @@ def get_parser():
"'--epoch'. ",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--vocab-size",
type=int,
@ -133,11 +143,36 @@ def main():
model.to(device)
if params.avg == 1:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")

View File

@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -178,6 +179,33 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--lr",
type=float,
default=1e-3,
)
parser.add_argument(
"--max-sent-len",
type=int,
default=200,
help="""Maximum number of tokens in a sentence. This is used
to adjust batch-size dynamically""",
)
parser.add_argument(
"--save-every-n",
type=int,
default=2000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
""",
)
return parser
@ -190,16 +218,15 @@ def get_params() -> AttributeDict:
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
"lr": 1e-3,
"weight_decay": 1e-6,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 200,
"log_interval": 100,
"reset_interval": 2000,
"valid_interval": 5000,
"valid_interval": 200,
"env_info": get_env_info(),
}
)
@ -382,6 +409,7 @@ def train_one_epoch(
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
@ -430,6 +458,19 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
params=params,
optimizer=optimizer,
rank=rank,
)
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
@ -580,6 +621,7 @@ def run(rank, world_size, args):
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
save_checkpoint(