mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Shallow fusion for Aishell (#954)
* add shallow fusion and LODR for aishell * update RESULTS * add save by iterations
This commit is contained in:
parent
46bf6df62f
commit
d337398d29
@ -15,6 +15,8 @@ It uses pruned RNN-T.
|
|||||||
|------------------------|------|------|---------------------------------------|
|
|------------------------|------|------|---------------------------------------|
|
||||||
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
|
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
|
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
|
| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
|
| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
|
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
|
|
||||||
Training command is:
|
Training command is:
|
||||||
@ -73,6 +75,78 @@ for epoch in 29; do
|
|||||||
done
|
done
|
||||||
```
|
```
|
||||||
|
|
||||||
|
We provide the option of shallow fusion with an RNN language model. The pre-trained language model is
|
||||||
|
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
|
||||||
|
please use the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# download pre-trained model
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
|
||||||
|
|
||||||
|
aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/
|
||||||
|
|
||||||
|
pushd ${aishell_exp}/exp
|
||||||
|
ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
# download RNN LM
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm
|
||||||
|
rnnlm_dir=icefall-aishell-rnn-lm
|
||||||
|
|
||||||
|
# RNNLM shallow fusion
|
||||||
|
for lm_scale in $(seq 0.26 0.02 0.34); do
|
||||||
|
python ./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--lang-dir ${aishell_exp}/data/lang_char \
|
||||||
|
--exp-dir ${aishell_exp}/exp \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
|
--use-shallow-fusion 1 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-exp-dir ${rnnlm_dir}/exp \
|
||||||
|
--lm-epoch 99 \
|
||||||
|
--lm-scale $lm_scale \
|
||||||
|
--lm-avg 1 \
|
||||||
|
--rnn-lm-embedding-dim 2048 \
|
||||||
|
--rnn-lm-hidden-dim 2048 \
|
||||||
|
--rnn-lm-num-layers 2 \
|
||||||
|
--lm-vocab-size 4336
|
||||||
|
done
|
||||||
|
|
||||||
|
# RNNLM Low-order density ratio (LODR) with a 2-gram
|
||||||
|
|
||||||
|
cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt
|
||||||
|
|
||||||
|
for lm_scale in 0.48; do
|
||||||
|
for LODR_scale in -0.28; do
|
||||||
|
python ./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--lang-dir ${aishell_exp}/data/lang_char \
|
||||||
|
--exp-dir ${aishell_exp}/exp \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--decoding-method modified_beam_search_LODR \
|
||||||
|
--use-shallow-fusion 1 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-exp-dir ${rnnlm_dir}/exp \
|
||||||
|
--lm-epoch 99 \
|
||||||
|
--lm-scale $lm_scale \
|
||||||
|
--lm-avg 1 \
|
||||||
|
--rnn-lm-embedding-dim 2048 \
|
||||||
|
--rnn-lm-hidden-dim 2048 \
|
||||||
|
--rnn-lm-num-layers 2 \
|
||||||
|
--lm-vocab-size 4336 \
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--backoff-id 4336 \
|
||||||
|
--ngram-lm-scale $LODR_scale
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
Pretrained models, training logs, decoding logs, and decoding results
|
Pretrained models, training logs, decoding logs, and decoding results
|
||||||
are available at
|
are available at
|
||||||
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
|
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
|
||||||
|
0
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file → Executable file
0
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file → Executable file
1
egs/aishell/ASR/local/sort_lm_training_data.py
Symbolic link
1
egs/aishell/ASR/local/sort_lm_training_data.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/local/sort_lm_training_data.py
|
@ -230,12 +230,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
|||||||
if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
|
if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
|
||||||
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
|
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# training words
|
||||||
./local/prepare_char_lm_training_data.py \
|
./local/prepare_char_lm_training_data.py \
|
||||||
--lang-char data/lang_char \
|
--lang-char data/lang_char \
|
||||||
--lm-data $dl_dir/lm/aishell-train-word.txt \
|
--lm-data $dl_dir/lm/aishell-train-word.txt \
|
||||||
--lm-archive $out_dir/lm_data.pt
|
--lm-archive $out_dir/lm_data.pt
|
||||||
|
|
||||||
|
# valid words
|
||||||
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
|
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
|
||||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||||
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
|
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
|
||||||
@ -249,6 +251,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
|||||||
--lm-data $dl_dir/lm/aishell-valid-word.txt \
|
--lm-data $dl_dir/lm/aishell-valid-word.txt \
|
||||||
--lm-archive $out_dir/lm_data_valid.pt
|
--lm-archive $out_dir/lm_data_valid.pt
|
||||||
|
|
||||||
|
# test words
|
||||||
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
|
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
|
||||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||||
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
|
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
|
||||||
@ -303,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
|||||||
--hidden-dim 512 \
|
--hidden-dim 512 \
|
||||||
--num-layers 2 \
|
--num-layers 2 \
|
||||||
--batch-size 400 \
|
--batch-size 400 \
|
||||||
--exp-dir rnnlm_char/exp \
|
--exp-dir rnnlm_char/exp_aishell1_small \
|
||||||
--lm-data data/lm_training_char/sorted_lm_data.pt \
|
--lm-data data/lm_char/sorted_lm_data_aishell1.pt \
|
||||||
--lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
|
--lm-data-valid data/lm_char/sorted_lm_data_valid.pt \
|
||||||
--vocab-size 4336 \
|
--vocab-size 4336 \
|
||||||
--master-port 12345
|
--master-port 12345
|
||||||
fi
|
fi
|
||||||
|
@ -54,6 +54,40 @@ Usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) modified beam search (with LM shallow fusion)
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
|
--beam-size 4 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-scale 0.3 \
|
||||||
|
--lm-exp-dir /path/to/LM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1
|
||||||
|
|
||||||
|
(6) modified beam search with LM shallow fusion + LODR
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--decoding-method modified_beam_search_LODR \
|
||||||
|
--beam-size 4 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-scale 0.48 \
|
||||||
|
--lm-exp-dir /path/to/LM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1 \
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--ngram-lm-scale -0.28
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -74,9 +108,12 @@ from beam_search import (
|
|||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
|
modified_beam_search_lm_shallow_fusion,
|
||||||
|
modified_beam_search_LODR,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -212,6 +249,60 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-shallow-fusion",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Use neural network LM for shallow fusion.
|
||||||
|
If you want to use LODR, you will also need to set this to true
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens-ngram",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="""Token Ngram used for rescoring.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--backoff-id",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""ID of the backoff symbol.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -223,6 +314,9 @@ def decode_one_batch(
|
|||||||
token_table: k2.SymbolTable,
|
token_table: k2.SymbolTable,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -287,6 +381,24 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||||
|
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search_LODR":
|
||||||
|
hyp_tokens = modified_beam_search_LODR(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LODR_lm=ngram_lm,
|
||||||
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hyp_tokens = []
|
hyp_tokens = []
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
@ -334,6 +446,9 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
token_table: k2.SymbolTable,
|
token_table: k2.SymbolTable,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -379,6 +494,9 @@ def decode_dataset(
|
|||||||
token_table=token_table,
|
token_table=token_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -445,6 +563,7 @@ def save_results(
|
|||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
AsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
|
LmScorer.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
args.lang_dir = Path(args.lang_dir)
|
||||||
@ -458,6 +577,8 @@ def main():
|
|||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_LODR",
|
||||||
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
@ -479,6 +600,19 @@ def main():
|
|||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
if "ngram" in params.decoding_method:
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
if params.lm_type == "rnn":
|
||||||
|
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
|
||||||
|
elif params.lm_type == "transformer":
|
||||||
|
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -588,6 +722,35 @@ def main():
|
|||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
|
||||||
|
# only load N-gram LM when needed
|
||||||
|
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||||
|
lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"
|
||||||
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
|
ngram_lm = NgramLm(
|
||||||
|
lm_filename,
|
||||||
|
backoff_id=params.backoff_id,
|
||||||
|
is_binary=False,
|
||||||
|
)
|
||||||
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
ngram_lm_scale = params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
ngram_lm = None
|
||||||
|
ngram_lm_scale = None
|
||||||
|
|
||||||
|
# only load the neural network LM if doing shallow fusion
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -610,6 +773,9 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
token_table=lexicon.token_table,
|
token_table=lexicon.token_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -1 +0,0 @@
|
|||||||
/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
|
|
@ -550,7 +550,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -561,7 +560,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
@ -1863,7 +1863,6 @@ def modified_beam_search_LODR(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
sp: spm.SentencePieceProcessor,
|
|
||||||
LODR_lm: NgramLm,
|
LODR_lm: NgramLm,
|
||||||
LODR_lm_scale: float,
|
LODR_lm_scale: float,
|
||||||
LM: LmScorer,
|
LM: LmScorer,
|
||||||
@ -1883,8 +1882,6 @@ def modified_beam_search_LODR(
|
|||||||
encoder_out_lens (torch.Tensor):
|
encoder_out_lens (torch.Tensor):
|
||||||
A 1-D tensor of shape (N,), containing the number of
|
A 1-D tensor of shape (N,), containing the number of
|
||||||
valid frames in encoder_out before padding.
|
valid frames in encoder_out before padding.
|
||||||
sp:
|
|
||||||
Sentence piece generator.
|
|
||||||
LODR_lm:
|
LODR_lm:
|
||||||
A low order n-gram LM, whose score will be subtracted during shallow fusion
|
A low order n-gram LM, whose score will be subtracted during shallow fusion
|
||||||
LODR_lm_scale:
|
LODR_lm_scale:
|
||||||
@ -1912,7 +1909,7 @@ def modified_beam_search_LODR(
|
|||||||
)
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
sos_id = sp.piece_to_id("<sos/eos>")
|
sos_id = getattr(LM, "sos_id", 1)
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
@ -2137,7 +2134,6 @@ def modified_beam_search_lm_shallow_fusion(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
sp: spm.SentencePieceProcessor,
|
|
||||||
LM: LmScorer,
|
LM: LmScorer,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
return_timestamps: bool = False,
|
return_timestamps: bool = False,
|
||||||
@ -2176,7 +2172,7 @@ def modified_beam_search_lm_shallow_fusion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
sos_id = sp.piece_to_id("<sos/eos>")
|
sos_id = getattr(LM, "sos_id", 1)
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
|
@ -675,7 +675,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -686,7 +685,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
@ -586,7 +586,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -597,7 +596,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
@ -533,7 +533,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -544,7 +543,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
@ -40,8 +40,8 @@ from tqdm import tqdm
|
|||||||
# and 'data()' is only supported in static graph mode. So if you
|
# and 'data()' is only supported in static graph mode. So if you
|
||||||
# want to use this api, should call 'paddle.enable_static()' before
|
# want to use this api, should call 'paddle.enable_static()' before
|
||||||
# this api to enter static graph mode.
|
# this api to enter static graph mode.
|
||||||
paddle.enable_static()
|
# paddle.enable_static()
|
||||||
paddle.disable_signal_handler()
|
# paddle.disable_signal_handler()
|
||||||
jieba.enable_paddle()
|
jieba.enable_paddle()
|
||||||
|
|
||||||
|
|
||||||
|
@ -261,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
|||||||
log "Stage 18: Compile LG"
|
log "Stage 18: Compile LG"
|
||||||
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# prepare RNNLM data
|
||||||
|
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||||
|
log "Stage 19: Prepare LM training data"
|
||||||
|
|
||||||
|
log "Processing char based data"
|
||||||
|
text_out_dir=data/lm_char
|
||||||
|
|
||||||
|
mkdir -p $text_out_dir
|
||||||
|
|
||||||
|
log "Genearating training text data"
|
||||||
|
|
||||||
|
if [ ! -f $text_out_dir/lm_data.pt ]; then
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $lang_char_dir/text_words_segmentation \
|
||||||
|
--lm-archive $text_out_dir/lm_data.pt
|
||||||
|
fi
|
||||||
|
|
||||||
|
log "Generating DEV text data"
|
||||||
|
# prepare validation text data
|
||||||
|
if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
|
||||||
|
valid_text=${text_out_dir}/
|
||||||
|
|
||||||
|
gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
|
||||||
|
| jq '.text' | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $text_out_dir/valid_text
|
||||||
|
|
||||||
|
python3 ./local/text2segments.py \
|
||||||
|
--num-process $nj \
|
||||||
|
--input-file $text_out_dir/valid_text \
|
||||||
|
--output-file $text_out_dir/valid_text_words_segmentation
|
||||||
|
fi
|
||||||
|
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $text_out_dir/valid_text_words_segmentation \
|
||||||
|
--lm-archive $text_out_dir/lm_data_valid.pt
|
||||||
|
|
||||||
|
# prepare TEST text data
|
||||||
|
if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
|
||||||
|
log "Prepare text for test set."
|
||||||
|
for test_set in TEST_MEETING TEST_NET; do
|
||||||
|
gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \
|
||||||
|
| jq '.text' | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text
|
||||||
|
|
||||||
|
python3 ./local/text2segments.py \
|
||||||
|
--num-process $nj \
|
||||||
|
--input-file $text_out_dir/${test_set}_text \
|
||||||
|
--output-file $text_out_dir/${test_set}_text_words_segmentation
|
||||||
|
done
|
||||||
|
|
||||||
|
cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
|
||||||
|
fi
|
||||||
|
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $text_out_dir/test_text_words_segmentation \
|
||||||
|
--lm-archive $text_out_dir/lm_data_test.pt
|
||||||
|
|
||||||
|
fi
|
||||||
|
|
||||||
|
# sort RNNLM data
|
||||||
|
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||||
|
text_out_dir=data/lm_char
|
||||||
|
|
||||||
|
log "Sort lm data"
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $text_out_dir/lm_data.pt \
|
||||||
|
--out-lm-data $text_out_dir/sorted_lm_data.pt \
|
||||||
|
--out-statistics $text_out_dir/statistics.txt
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $text_out_dir/lm_data_valid.pt \
|
||||||
|
--out-lm-data $text_out_dir/sorted_lm_data-valid.pt \
|
||||||
|
--out-statistics $text_out_dir/statistics-valid.txt
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $text_out_dir/lm_data_test.pt \
|
||||||
|
--out-lm-data $text_out_dir/sorted_lm_data-test.pt \
|
||||||
|
--out-statistics $text_out_dir/statistics-test.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
|
||||||
|
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
|
||||||
|
log "Stage 21: Train RNN LM model"
|
||||||
|
python ../../../icefall/rnn_lm/train.py \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 20 \
|
||||||
|
--use-fp16 0 \
|
||||||
|
--embedding-dim 2048 \
|
||||||
|
--hidden-dim 2048 \
|
||||||
|
--num-layers 2 \
|
||||||
|
--batch-size 400 \
|
||||||
|
--exp-dir rnnlm_char/exp \
|
||||||
|
--lm-data data/lm_char/sorted_lm_data.pt \
|
||||||
|
--lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
|
||||||
|
--vocab-size 4336 \
|
||||||
|
--master-port 12340
|
||||||
|
fi
|
@ -2,6 +2,7 @@
|
|||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -91,6 +92,22 @@ When training with the L subset, the streaming usage:
|
|||||||
--causal-convolution 1 \
|
--causal-convolution 1 \
|
||||||
--decode-chunk-size 16 \
|
--decode-chunk-size 16 \
|
||||||
--left-context 64
|
--left-context 64
|
||||||
|
|
||||||
|
(4) modified beam search with RNNLM shallow fusion
|
||||||
|
./pruned_transducer_stateless5/decode.py \
|
||||||
|
--epoch 35 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
|
--beam-size 4 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-scale 0.3 \
|
||||||
|
--lm-exp-dir /path/to/LM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -111,9 +128,12 @@ from beam_search import (
|
|||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
|
modified_beam_search_lm_shallow_fusion,
|
||||||
|
modified_beam_search_LODR,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -224,6 +244,16 @@ def get_parser():
|
|||||||
Used only when --decoding-method is fast_beam_search""",
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
@ -277,6 +307,50 @@ def get_parser():
|
|||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-shallow-fusion",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Use neural network LM for shallow fusion.
|
||||||
|
If you want to use LODR, you will also need to set this to true
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens-ngram",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="""Token Ngram used for rescoring.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring, or LODR
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--backoff-id",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""ID of the backoff symbol.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring, or LODR""",
|
||||||
|
)
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -288,6 +362,9 @@ def decode_one_batch(
|
|||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -374,6 +451,28 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for i in range(encoder_out.size(0)):
|
for i in range(encoder_out.size(0)):
|
||||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||||
|
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "modified_beam_search_LODR":
|
||||||
|
hyp_tokens = modified_beam_search_LODR(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LODR_lm=ngram_lm,
|
||||||
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -419,6 +518,9 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -432,6 +534,8 @@ def decode_dataset(
|
|||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is fast_beam_search.
|
||||||
|
LM:
|
||||||
|
A neural network LM, used during shallow fusion
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -449,7 +553,7 @@ def decode_dataset(
|
|||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 100
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 20
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -463,6 +567,9 @@ def decode_dataset(
|
|||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -524,6 +631,7 @@ def save_results(
|
|||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
WenetSpeechAsrDataModule.add_arguments(parser)
|
WenetSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
LmScorer.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
@ -535,6 +643,8 @@ def main():
|
|||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
|
"modified_beam_search_LODR",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
@ -549,6 +659,22 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if "ngram" in params.decoding_method:
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
if params.lm_type == "rnn":
|
||||||
|
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
|
||||||
|
elif params.lm_type == "transformer":
|
||||||
|
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -558,6 +684,7 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
@ -652,6 +779,37 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
# only load N-gram LM when needed
|
||||||
|
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||||
|
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||||
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
|
ngram_lm = NgramLm(
|
||||||
|
str(params.lang_dir / lm_filename),
|
||||||
|
backoff_id=params.backoff_id,
|
||||||
|
is_binary=False,
|
||||||
|
)
|
||||||
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
ngram_lm_scale = params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
ngram_lm = None
|
||||||
|
ngram_lm_scale = None
|
||||||
|
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
|
# only load the neural network LM if doing shallow fusion
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in LM.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
@ -684,6 +842,9 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
save_results(
|
save_results(
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -50,7 +50,7 @@ class LmScorer(torch.nn.Module):
|
|||||||
def add_arguments(cls, parser):
|
def add_arguments(cls, parser):
|
||||||
# LM general arguments
|
# LM general arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-size",
|
"--lm-vocab-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=500,
|
default=500,
|
||||||
)
|
)
|
||||||
|
@ -33,7 +33,7 @@ import torch
|
|||||||
from dataset import get_dataloader
|
from dataset import get_dataloader
|
||||||
from model import RnnLmModel
|
from model import RnnLmModel
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
@ -49,6 +49,7 @@ def get_parser():
|
|||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
@ -58,6 +59,16 @@ def get_parser():
|
|||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -154,7 +165,14 @@ def main():
|
|||||||
|
|
||||||
params = AttributeDict(vars(args))
|
params = AttributeDict(vars(args))
|
||||||
|
|
||||||
setup_logger(f"{params.exp_dir}/log-ppl/")
|
if params.iter > 0:
|
||||||
|
setup_logger(
|
||||||
|
f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
setup_logger(
|
||||||
|
f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
)
|
||||||
logging.info("Computing perplexity started")
|
logging.info("Computing perplexity started")
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -173,19 +191,39 @@ def main():
|
|||||||
tie_weights=params.tie_weights,
|
tie_weights=params.tie_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.avg == 1:
|
if params.iter > 0:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
start = params.epoch - params.avg + 1
|
||||||
filenames = []
|
filenames = []
|
||||||
for i in range(start, params.epoch + 1):
|
for i in range(start, params.epoch + 1):
|
||||||
if start >= 0:
|
if i >= 0:
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
num_param_requires_grad = sum(
|
num_param_requires_grad = sum(
|
||||||
|
@ -25,7 +25,7 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from model import RnnLmModel
|
from model import RnnLmModel
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import AttributeDict, load_averaged_model, str2bool
|
from icefall.utils import AttributeDict, load_averaged_model, str2bool
|
||||||
|
|
||||||
|
|
||||||
@ -51,6 +51,16 @@ def get_parser():
|
|||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-size",
|
"--vocab-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -133,11 +143,36 @@ def main():
|
|||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if params.avg == 1:
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
elif params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
model = load_averaged_model(
|
start = params.epoch - params.avg + 1
|
||||||
params.exp_dir, model, params.epoch, params.avg, device
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
)
|
)
|
||||||
|
|
||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
|
@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
|
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -178,6 +179,33 @@ def get_parser():
|
|||||||
help="The seed for random generators intended for reproducibility",
|
help="The seed for random generators intended for reproducibility",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr",
|
||||||
|
type=float,
|
||||||
|
default=1e-3,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sent-len",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Maximum number of tokens in a sentence. This is used
|
||||||
|
to adjust batch-size dynamically""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-every-n",
|
||||||
|
type=int,
|
||||||
|
default=2000,
|
||||||
|
help="""Save checkpoint after processing this number of batches
|
||||||
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||||
|
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||||
|
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||||
|
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -190,16 +218,15 @@ def get_params() -> AttributeDict:
|
|||||||
"sos_id": 1,
|
"sos_id": 1,
|
||||||
"eos_id": 1,
|
"eos_id": 1,
|
||||||
"blank_id": 0,
|
"blank_id": 0,
|
||||||
"lr": 1e-3,
|
|
||||||
"weight_decay": 1e-6,
|
"weight_decay": 1e-6,
|
||||||
"best_train_loss": float("inf"),
|
"best_train_loss": float("inf"),
|
||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
"best_valid_epoch": -1,
|
"best_valid_epoch": -1,
|
||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 200,
|
"log_interval": 100,
|
||||||
"reset_interval": 2000,
|
"reset_interval": 2000,
|
||||||
"valid_interval": 5000,
|
"valid_interval": 200,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -382,6 +409,7 @@ def train_one_epoch(
|
|||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
|
rank: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Train the model for one epoch.
|
"""Train the model for one epoch.
|
||||||
|
|
||||||
@ -430,6 +458,19 @@ def train_one_epoch(
|
|||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
if (
|
||||||
|
params.batch_idx_train > 0
|
||||||
|
and params.batch_idx_train % params.save_every_n == 0
|
||||||
|
):
|
||||||
|
save_checkpoint_with_global_batch_idx(
|
||||||
|
out_dir=params.exp_dir,
|
||||||
|
global_batch_idx=params.batch_idx_train,
|
||||||
|
model=model,
|
||||||
|
params=params,
|
||||||
|
optimizer=optimizer,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
# Note: "frames" here means "num_tokens"
|
# Note: "frames" here means "num_tokens"
|
||||||
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
||||||
@ -580,6 +621,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
tb_writer=tb_writer,
|
tb_writer=tb_writer,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user