Merge branch 'dev/lm_multi_zh-hans' of https://github.com/JinZr/icefall into dev/lm_multi_zh-hans

JinZr 2023-11-09 10:53:05 +08:00
commit de3daf6496
5 changed files with 443 additions and 92 deletions

View File

@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 # Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
-#                                                  Fangjun Kuang)
+#                                                  Fangjun Kuang,
+#                                                  Zengrui Jin)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

View File

@@ -426,7 +426,7 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
     out_dir=data/lm_training_bpe_${vocab_size}
     python ../../../icefall/rnn_lm/train.py \
       --start-epoch 0 \
-      --world-size 1 \
+      --world-size 2 \
       --use-fp16 0 \
       --embedding-dim 2048 \
       --hidden-dim 2048 \
@@ -435,8 +435,7 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
       --exp-dir rnnlm_bpe_${vocab_size}/exp \
       --lm-data $out_dir/sorted_lm_data.pt \
       --lm-data-valid $out_dir/sorted_lm_data-dev.pt \
-      --vocab-size $vocab_size \
-      --master-port 12345
+      --vocab-size $vocab_size
   done
 fi
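For context, the hunk above switches the RNN LM training in stage 18 to two DDP workers (--world-size 2) and drops the hard-coded --master-port. A minimal re-run of just this stage could look like the sketch below; the prepare.sh location under egs/multi_zh-hans/ASR is an assumption (file names are not shown in this view), and --world-size 2 implies two visible GPUs.

  cd egs/multi_zh-hans/ASR
  export CUDA_VISIBLE_DEVICES="0,1"   # two devices for --world-size 2
  ./prepare.sh --stage 18 --stop-stage 18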

View File

@@ -1,93 +1,229 @@
-for subset in train dev test; do
-  gunzip -c aidatatang_200zh/aidatatang_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > aidatatang_${subset}_text
-done
-
-for subset in train dev test; do
-  gunzip -c aishell/aishell_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > aishell_${subset}_text
-done
-
-for subset in train dev test; do
-  gunzip -c aishell2/aishell2_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > aishell2_${subset}_text
-done
-
-for subset in train_L train_M train_S test; do
-  gunzip -c aishell4/aishell4_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > aishell4_${subset}_text
-done
-
-for subset in train test eval; do
-  gunzip -c alimeeting/alimeeting-far_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > alimeeting-far_${subset}_text
-done
-
-for subset in dev_phase1 dev_phase2 test train_phase1 train_phase2; do
-  gunzip -c kespeech/kespeech-asr_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > kespeech_${subset}_text
-done
-
-for subset in train test dev; do
-  gunzip -c magicdata/magicdata_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > magicdata_${subset}_text
-done
-
-for subset in train ; do
-  gunzip -c stcmds/stcmds_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > stcmds_${subset}_text
-done
-
-for subset in train ; do
-  gunzip -c primewords/primewords_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > primewords_${subset}_text
-done
-
-for subset in train test dev ; do
-  gunzip -c thchs30/thchs_30_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > thchs30_${subset}_text
-done
-
-for subset in L DEV TEST_MEETING TEST_NET ; do
-  gunzip -c wenetspeech/wenetspeech_supervisions_${subset}.jsonl.gz \
-    | jq '.text' \
-    | sed 's/"//g' \
-    | ../../local/tokenize_for_lm_training.py -t "char" \
-    > wenetspeech_${subset}_text
-done
-
-cat aidatatang_train_text aishell2_train_text aishell4_train_L_text \
-  aishell4_train_M_text aishell4_train_S_text aishell_train_text \
-  alimeeting-far_train_text kespeech_train_phase1_text kespeech_train_phase2_text \
-  magicdata_train_text primewords_train_text stcmds_train_text \
-  thchs30_train_text wenetspeech_L_text > lm_training_text
+cd data/
+
+log "Preparing LM data..."
+mkdir -p lm_training_data
+mkdir -p lm_dev_data
+mkdir -p lm_test_data
+
+log "aidatatang_200zh"
+gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/aidatatang_train_text
+gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_dev.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/aidatatang_dev_text
+gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/aidatatang_test_text
+
+log "aishell"
+gunzip -c manifests/aishell/aishell_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/aishell_train_text
+gunzip -c manifests/aishell/aishell_supervisions_dev.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/aishell_dev_text
+gunzip -c manifests/aishell/aishell_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/aishell_test_text
+
+log "aishell2"
+gunzip -c manifests/aishell2/aishell2_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/aishell2_train_text
+gunzip -c manifests/aishell2/aishell2_supervisions_dev.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/aishell2_dev_text
+gunzip -c manifests/aishell2/aishell2_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/aishell2_test_text
+
+log "aishell4"
+gunzip -c manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/aishell4_train_L_text
+gunzip -c manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/aishell4_train_M_text
+gunzip -c manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/aishell4_train_S_text
+gunzip -c manifests/aishell4/aishell4_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/aishell4_test_text
+
+log "alimeeting"
+gunzip -c manifests/alimeeting/alimeeting-far_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/alimeeting-far_train_text
+gunzip -c manifests/alimeeting/alimeeting-far_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/alimeeting-far_test_text
+gunzip -c manifests/alimeeting/alimeeting-far_supervisions_eval.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/alimeeting-far_eval_text
+
+log "kespeech"
+gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase1.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/kespeech_dev_phase1_text
+gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase2.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/kespeech_dev_phase2_text
+gunzip -c manifests/kespeech/kespeech-asr_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/kespeech_test_text
+gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase1.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/kespeech_train_phase1_text
+gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase2.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/kespeech_train_phase2_text
+
+log "magicdata"
+gunzip -c manifests/magicdata/magicdata_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/magicdata_train_text
+gunzip -c manifests/magicdata/magicdata_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/magicdata_test_text
+gunzip -c manifests/magicdata/magicdata_supervisions_dev.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/magicdata_dev_text
+
+log "stcmds"
+gunzip -c manifests/stcmds/stcmds_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/stcmds_train_text
+
+log "primewords"
+gunzip -c manifests/primewords/primewords_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/primewords_train_text
+
+log "thchs30"
+gunzip -c manifests/thchs30/thchs_30_supervisions_train.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/thchs30_train_text
+gunzip -c manifests/thchs30/thchs_30_supervisions_test.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/thchs30_test_text
+gunzip -c manifests/thchs30/thchs_30_supervisions_dev.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/thchs30_dev_text
+
+log "wenetspeech"
+gunzip -c manifests/wenetspeech/wenetspeech_supervisions_L.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_training_data/wenetspeech_L_text
+gunzip -c manifests/wenetspeech/wenetspeech_supervisions_DEV.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_dev_data/wenetspeech_DEV_text
+gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_MEETING.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/wenetspeech_TEST_MEETING_text
+gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_NET.jsonl.gz \
+  | jq '.text' \
+  | sed 's/"//g' \
+  | ../local/tokenize_for_lm_training.py -t "char" \
+  > lm_test_data/wenetspeech_TEST_NET_text
+
+for f in aidatatang_train_text aishell2_train_text aishell4_train_L_text aishell4_train_M_text aishell4_train_S_text aishell_train_text alimeeting-far_train_text kespeech_train_phase1_text kespeech_train_phase2_text magicdata_train_text primewords_train_text stcmds_train_text thchs30_train_text wenetspeech_L_text; do
+  cat lm_training_data/$f >> lm_training_data/lm_training_text
+done
+
+for f in aidatatang_test_text aishell4_test_text alimeeting-far_test_text thchs30_test_text wenetspeech_TEST_NET_text aishell2_test_text aishell_test_text kespeech_test_text magicdata_test_text wenetspeech_TEST_MEETING_text; do
+  cat lm_test_data/$f >> lm_test_data/lm_test_text
+done
+
+for f in aidatatang_dev_text aishell_dev_text kespeech_dev_phase1_text thchs30_dev_text aishell2_dev_text alimeeting-far_eval_text kespeech_dev_phase2_text magicdata_dev_text wenetspeech_DEV_text; do
+  cat lm_dev_data/$f >> lm_dev_data/lm_dev_text
+done
+
+cd ../
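Every pipeline above pushes supervision text through local/tokenize_for_lm_training.py in char mode before the per-split files are concatenated into lm_training_text, lm_dev_text, and lm_test_text. A quick spot check of one manifest, assuming the tokenizer reads plain text on stdin and emits space-separated characters (the tokenizer itself is not shown in this commit), could look like:

  cd data/
  gunzip -c manifests/aishell/aishell_supervisions_dev.jsonl.gz \
    | jq '.text' \
    | sed 's/"//g' \
    | head -n 3 \
    | ../local/tokenize_for_lm_training.py -t "char"
  # after the full script has run, verify the combined splits are non-empty
  wc -l lm_training_data/lm_training_text lm_dev_data/lm_dev_text lm_test_data/lm_test_text
  cd ../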

View File

@@ -97,6 +97,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -115,11 +116,16 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_lm_rescore,
+    modified_beam_search_lm_rescore_LODR,
+    modified_beam_search_lm_shallow_fusion,
+    modified_beam_search_LODR,
 )
 from lhotse.cut import Cut
 from multi_dataset import MultiDataset
 from train import add_model_arguments, get_model, get_params

+from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -212,6 +218,7 @@ def get_parser():
          - greedy_search
          - beam_search
          - modified_beam_search
+         - modified_beam_search_LODR
          - fast_beam_search
          - fast_beam_search_nbest
          - fast_beam_search_nbest_oracle
@@ -303,6 +310,47 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

+    parser.add_argument(
+        "--use-shallow-fusion",
+        type=str2bool,
+        default=False,
+        help="""Use neural network LM for shallow fusion.
+        If you want to use LODR, you will also need to set this to true
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN lm",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM
+        Used only when `--use-shallow-fusion` is set to True.
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=2,
+        help="""The order of the ngram lm.
+        """,
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
     add_model_arguments(parser)

     return parser
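Together with the new modified_beam_search_* imports, the five options added above enable decoding with an external language model. A hedged example of an LODR run follows; the script path, experiment and lang directories, and the RNN-LM checkpoint flags (--lm-exp-dir, --lm-epoch, --lm-avg, registered by LmScorer.add_arguments further down) are assumptions, since file names are not shown in this view. In LODR the token-level n-gram acts as a subtractive prior, so --ngram-lm-scale is typically a small negative value.

  ./zipformer/decode.py \
    --epoch 30 --avg 9 \
    --exp-dir ./zipformer/exp \
    --lang-dir ./data/lang_bpe_2000 \
    --decoding-method modified_beam_search_LODR \
    --beam-size 4 \
    --use-shallow-fusion 1 \
    --lm-type rnn \
    --lm-scale 0.42 \
    --tokens-ngram 2 \
    --backoff-id 500 \
    --ngram-lm-scale -0.16 \
    --lm-exp-dir ./rnnlm_bpe_2000/exp \
    --lm-epoch 29 --lm-avg 1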
@@ -315,6 +363,10 @@
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
+    LM: Optional[LmScorer] = None,
+    ngram_lm=None,
+    ngram_lm_scale: float = 0.0,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -343,6 +395,12 @@
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      LM:
+        A neural network language model.
+      ngram_lm:
+        A ngram language model
+      ngram_lm_scale:
+        The scale for the ngram language model.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -443,6 +501,51 @@
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+        hyp_tokens = modified_beam_search_lm_shallow_fusion(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LM=LM,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_LODR":
+        hyp_tokens = modified_beam_search_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LODR_lm=ngram_lm,
+            LODR_lm_scale=ngram_lm_scale,
+            LM=LM,
+            context_graph=context_graph,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_lm_rescore":
+        lm_scale_list = [0.01 * i for i in range(10, 50)]
+        ans_dict = modified_beam_search_lm_rescore(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LM=LM,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+        lm_scale_list = [0.02 * i for i in range(2, 30)]
+        ans_dict = modified_beam_search_lm_rescore_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LM=LM,
+            LODR_lm=ngram_lm,
+            sp=sp,
+            lm_scale_list=lm_scale_list,
+        )
     else:
         batch_size = encoder_out.size(0)
@@ -481,6 +584,22 @@
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

         return {key: hyps}
+    elif "modified_beam_search" in params.decoding_method:
+        prefix = f"beam_size_{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search_lm_rescore",
+            "modified_beam_search_lm_rescore_LODR",
+        ):
+            ans = dict()
+            assert ans_dict is not None
+            for key, hyps in ans_dict.items():
+                hyps = [sp.decode(hyp).split() for hyp in hyps]
+                ans[f"{prefix}_{key}"] = hyps
+            return ans
+        else:
+            if params.has_contexts:
+                prefix += f"-context-score-{params.context_score}"
+            return {prefix: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -492,6 +611,10 @@
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
+    LM: Optional[LmScorer] = None,
+    ngram_lm=None,
+    ngram_lm_scale: float = 0.0,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -540,8 +663,12 @@
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             word_table=word_table,
             batch=batch,
+            LM=LM,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
         )

         for name, hyps in hyps_dict.items():
@@ -610,6 +737,7 @@
 def main():
     parser = get_parser()
     AsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -624,9 +752,18 @@
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "modified_beam_search_LODR",
+        "modified_beam_search_lm_shallow_fusion",
+        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
     )
     params.res_dir = params.exp_dir / params.decoding_method

+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
@@ -653,10 +790,24 @@
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

+    if params.use_shallow_fusion:
+        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+    if "LODR" in params.decoding_method:
+        params.suffix += (
+            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+        )
+
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -762,6 +913,54 @@
     model.to(device)
     model.eval()

+    # only load the neural network LM if required
+    if params.use_shallow_fusion or params.decoding_method in (
+        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
+        "modified_beam_search_lm_shallow_fusion",
+        "modified_beam_search_LODR",
+    ):
+        LM = LmScorer(
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
+        LM.to(device)
+        LM.eval()
+    else:
+        LM = None
+
+    # only load N-gram LM when needed
+    if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+        try:
+            import kenlm
+        except ImportError:
+            print("Please install kenlm first. You can use")
+            print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+            print("to install it")
+            import sys
+
+            sys.exit(-1)
+        ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+        logging.info(f"lm filename: {ngram_file_name}")
+        ngram_lm = kenlm.Model(ngram_file_name)
+        ngram_lm_scale = None  # use a list to search
+    elif params.decoding_method == "modified_beam_search_LODR":
+        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+        logging.info(f"Loading token level lm: {lm_filename}")
+        ngram_lm = NgramLm(
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {ngram_lm.lm.num_states}")
+        ngram_lm_scale = params.ngram_lm_scale
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None
+
     if "fast_beam_search" in params.decoding_method:
         if params.decoding_method == "fast_beam_search_nbest_LG":
             lexicon = Lexicon(params.lang_dir)
@@ -779,6 +978,18 @@
         decoding_graph = None
         word_table = None

+    if "modified_beam_search" in params.decoding_method:
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -813,6 +1024,10 @@
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
+            LM=LM,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
         )

         save_results(

View File

@@ -31,7 +31,7 @@ from pathlib import Path
 import k2
 import numpy as np
 import torch
-from tqdm import tqdm
+from tqdm.auto import tqdm


 def get_args():