From a0063829418c6bdd5b0d5c090bb11c0738e40140 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 23 Oct 2023 13:29:31 +0800 Subject: [PATCH 01/26] Create prepare_lm_data.sh --- egs/multi_zh-hans/ASR/prepare_lm_data.sh | 93 ++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 egs/multi_zh-hans/ASR/prepare_lm_data.sh diff --git a/egs/multi_zh-hans/ASR/prepare_lm_data.sh b/egs/multi_zh-hans/ASR/prepare_lm_data.sh new file mode 100644 index 000000000..241fc65e7 --- /dev/null +++ b/egs/multi_zh-hans/ASR/prepare_lm_data.sh @@ -0,0 +1,93 @@ +for subset in train dev test; do + gunzip -c aidatatang_200zh/aidatatang_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > aidatatang_${subset}_text +done + +for subset in train dev test; do + gunzip -c aishell/aishell_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > aishell_${subset}_text +done + +for subset in train dev test; do + gunzip -c aishell2/aishell2_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > aishell2_${subset}_text +done + +for subset in train_L train_M train_S test; do + gunzip -c aishell4/aishell4_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > aishell4_${subset}_text +done + +for subset in train test eval; do + gunzip -c alimeeting/alimeeting-far_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > alimeeting-far_${subset}_text +done + +for subset in dev_phase1 dev_phase2 test train_phase1 train_phase2; do + gunzip -c kespeech/kespeech-asr_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > kespeech_${subset}_text +done + +for subset in train test dev; do + gunzip -c magicdata/magicdata_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > magicdata_${subset}_text +done + +for subset in train ; do + gunzip -c stcmds/stcmds_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > stcmds_${subset}_text +done + +for subset in train ; do + gunzip -c primewords/primewords_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > primewords_${subset}_text +done + +for subset in train test dev ; do + gunzip -c thchs30/thchs_30_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > thchs30_${subset}_text +done + +for subset in L DEV TEST_MEETING TEST_NET ; do + gunzip -c wenetspeech/wenetspeech_supervisions_${subset}.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../../local/tokenize_for_lm_training.py -t "char" \ + > wenetspeech_${subset}_text +done + +cat aidatatang_train_text aishell2_train_text aishell4_train_L_text \ + aishell4_train_M_text aishell4_train_S_text aishell_train_text \ + alimeeting-far_train_text kespeech_train_phase1_text kespeech_train_phase2_text \ + magicdata_train_text primewords_train_text stcmds_train_text \ + thchs30_train_text wenetspeech_L_text > lm_training_text From 
1a11440014eb0313f5c03f4fd6c9d954bba6c23b Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 09:57:57 +0800 Subject: [PATCH 02/26] minor updates --- .../ASR/local/prepare_for_bpe_model.py | 2 - .../ASR/local/prepare_lm_training_data.py | 161 ++++++++++++++++++ .../ASR/local/sort_lm_training_data.py | 1 + .../ASR/zipformer/asr_datamodule.py | 11 +- .../ASR/zipformer/multi_dataset.py | 4 +- 5 files changed, 166 insertions(+), 13 deletions(-) create mode 100755 egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py create mode 120000 egs/multi_zh-hans/ASR/local/sort_lm_training_data.py diff --git a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py index 020800c15..2b8708272 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py +++ b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py @@ -22,8 +22,6 @@ import argparse from pathlib import Path -from tqdm.auto import tqdm - from icefall.utils import tokenize_by_CJK_char diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py new file mode 100755 index 000000000..45cc4b8d7 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey +# Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a `bpe.model` and a text file such as +./download/lm/librispeech-lm-norm.txt +and outputs the LM training data to a supplied directory such +as data/lm_training_bpe_500. The format is as follows: + +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a +representation of a dict with the following format: + + 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 + containing the BPE representations of each word, indexed by + integer word ID. (These integer word IDS are present in + 'lm_data'). The sentencepiece object can be used to turn the + words and BPE units into string form. + 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype + torch.int32 containing all the sentences, as word-ids (we don't + output the string form of this directly but it can be worked out + together with 'words' and the bpe.model). + 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing + number of BPE tokens of each sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + +from icefall.utils import tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--bpe-model", + default="data/lang_bpe_2000/bpe.model", + type=str, + help="Input BPE model, e.g. data/bpe_500/bpe.model", + ) + parser.add_argument( + "--lm-data", + type=str, + help="""Input LM training data as text, e.g. 
+ download/pb.train.txt""", + ) + parser.add_argument( + "--lm-archive", + type=str, + help="""Path to output archive, e.g. data/bpe_500/lm_data.pt; + look at the source of this script to see the format.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + if Path(args.lm_archive).exists(): + logging.warning(f"{args.lm_archive} exists - skipping") + return + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of BPE pieces. + word2index = dict() + + word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. + sentences = [] # Will be a list-of-list-of-int, representing word-ids. + + num_lines_in_total = None + step = 500000 + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + line = tokenize_by_CJK_char(line) + if line == "": + break + + if step and processed % step == 0: + logging.info( + f"Processed number of lines: {processed} " + f"({processed/num_lines_in_total*100: .3f}%)" + ) + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_bpe = sp.encode(w) + word2index[w] = len(word2bpe) + word2bpe.append(w_bpe) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2bpe) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py b/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..1d6ccbe33 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 02cfa1346..ae1264659 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -19,25 +19,20 @@ import argparse import inspect import logging -from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures +from lhotse import CutSet, 
Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( CutConcatenate, CutMix, DynamicBucketingSampler, K2SpeechRecognitionDataset, - PrecomputedFeatures, SimpleCutSampler, SpecAugment, ) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py index b1920e62e..ac90dc73e 100644 --- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -15,11 +15,9 @@ # limitations under the License. -import glob import logging -import re from pathlib import Path -from typing import Dict, List +from typing import Dict import lhotse from lhotse import CutSet, load_manifest_lazy From 94f963baf8e0ecaa51040857c5b6469fd9925a36 Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 10:05:29 +0800 Subject: [PATCH 03/26] Update prepare_lm_training_data.py --- egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py index 45cc4b8d7..e931086fb 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -92,7 +92,6 @@ def main(): word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. sentences = [] # Will be a list-of-list-of-int, representing word-ids. - num_lines_in_total = None step = 500000 processed = 0 @@ -105,10 +104,7 @@ def main(): break if step and processed % step == 0: - logging.info( - f"Processed number of lines: {processed} " - f"({processed/num_lines_in_total*100: .3f}%)" - ) + logging.info(f"Processed number of lines: {processed} ") processed += 1 line_words = line.split() From 86c3dbec0ed9e1581d5ebbbcb8851951083e4339 Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 10:07:32 +0800 Subject: [PATCH 04/26] Update prepare_lm_training_data.py --- egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py index e931086fb..3c1cc295d 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -101,7 +101,7 @@ def main(): line = f.readline() line = tokenize_by_CJK_char(line) if line == "": - break + continue if step and processed % step == 0: logging.info(f"Processed number of lines: {processed} ") From 7f53f59776b59a326401b0110f30934d0c2abbca Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 10:14:08 +0800 Subject: [PATCH 05/26] Update prepare_lm_training_data.py --- egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py index 3c1cc295d..e931086fb 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -101,7 +101,7 @@ def main(): line = f.readline() line = tokenize_by_CJK_char(line) if line == "": - continue + break if step and processed % step 
== 0: logging.info(f"Processed number of lines: {processed} ") From 403e2e52ac3038d61e362dcb9dbd616ad1280648 Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 10:20:10 +0800 Subject: [PATCH 06/26] Update prepare_lm_training_data.py --- egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py index e931086fb..41638cd82 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -99,9 +99,9 @@ def main(): with open(args.lm_data) as f: while True: line = f.readline() - line = tokenize_by_CJK_char(line) if line == "": break + line = tokenize_by_CJK_char(line) if step and processed % step == 0: logging.info(f"Processed number of lines: {processed} ") From d29efb7345db9dd61dd98c8116e4aa704226783f Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 10:20:56 +0800 Subject: [PATCH 07/26] Update prepare_lm_training_data.py --- egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py index 41638cd82..95ee982bd 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -102,6 +102,8 @@ def main(): if line == "": break line = tokenize_by_CJK_char(line) + if line == "": + continue if step and processed % step == 0: logging.info(f"Processed number of lines: {processed} ") From 817413f8990f928d0cd2a2870cdedcb967904503 Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 10:53:34 +0800 Subject: [PATCH 08/26] minor updates --- egs/multi_zh-hans/ASR/prepare.sh | 69 +++++++++++++++++++++++ egs/ptb/LM/local/sort_lm_training_data.py | 3 +- 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index c09b9c1de..abd8cafb5 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -370,4 +370,73 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then done fi +if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then + log "Stage 16: Prepare LM data" + + ./prepare_lm_data.sh + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + + mkdir $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data ./data/lm_training_data/lm_training_text \ + --lm-archive $out_dir/lm_training_data.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data ./data/lm_dev_data/lm_dev_text \ + --lm-archive $out_dir/lm_dev_data.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data ./data/lm_test_data/lm_test_text \ + --lm-archive $out_dir/lm_test_data.pt + done +fi + +if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then + log "Stage 17: Sort LM data" + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_training_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_dev_data.pt \ + --out-lm-data $out_dir/sorted_lm_data-dev.pt \ + --out-statistics 
$out_dir/statistics-dev.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_test_data.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi + +if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then + log "Stage 18: Train RNN LM model" + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 1 \ + --use-fp16 0 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --batch-size 400 \ + --exp-dir rnnlm_bpe_${vocab_size}/exp \ + --lm-data $out_dir/sorted_lm_data.pt \ + --lm-data-valid $out_dir/sorted_lm_data-dev.pt \ + --vocab-size $vocab_size \ + --master-port 12345 + done +fi diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py index bed3856e4..31f36691f 100755 --- a/egs/ptb/LM/local/sort_lm_training_data.py +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -31,6 +31,7 @@ from pathlib import Path import k2 import numpy as np import torch +from tqdm import tqdm def get_args(): @@ -87,7 +88,7 @@ def main(): ) cur = None - for i in range(num_sentences): + for i in tqdm(range(num_sentences)): word_ids = sorted_sentences[i] token_ids = words2bpe[word_ids] if isinstance(token_ids, k2.RaggedTensor): From 3f89cb380aaaea0bd45d87d2caa8f5c47f59e14a Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 11:36:36 +0800 Subject: [PATCH 09/26] minor updates --- egs/multi_zh-hans/ASR/prepare.sh | 5 +- egs/multi_zh-hans/ASR/prepare_lm_data.sh | 310 ++++++++++++++++------- 2 files changed, 225 insertions(+), 90 deletions(-) diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index abd8cafb5..0704451f5 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -426,7 +426,7 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then out_dir=data/lm_training_bpe_${vocab_size} python ../../../icefall/rnn_lm/train.py \ --start-epoch 0 \ - --world-size 1 \ + --world-size 2 \ --use-fp16 0 \ --embedding-dim 2048 \ --hidden-dim 2048 \ @@ -435,8 +435,7 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then --exp-dir rnnlm_bpe_${vocab_size}/exp \ --lm-data $out_dir/sorted_lm_data.pt \ --lm-data-valid $out_dir/sorted_lm_data-dev.pt \ - --vocab-size $vocab_size \ - --master-port 12345 + --vocab-size $vocab_size done fi diff --git a/egs/multi_zh-hans/ASR/prepare_lm_data.sh b/egs/multi_zh-hans/ASR/prepare_lm_data.sh index 241fc65e7..289c5c52c 100644 --- a/egs/multi_zh-hans/ASR/prepare_lm_data.sh +++ b/egs/multi_zh-hans/ASR/prepare_lm_data.sh @@ -1,93 +1,229 @@ -for subset in train dev test; do - gunzip -c aidatatang_200zh/aidatatang_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > aidatatang_${subset}_text +cd data/ + +log "Preparing LM data..." 
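+# Every corpus below goes through the same pipeline:
+#   gunzip -c <manifest>.jsonl.gz decompresses the lhotse supervision manifest,
+#   jq '.text' extracts the transcript field as a quoted JSON string,
+#   sed 's/"//g' strips the surrounding quotes, and
+#   ../local/tokenize_for_lm_training.py -t "char" tokenizes by CJK character.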
+mkdir -p lm_training_data +mkdir -p lm_dev_data +mkdir -p lm_test_data + +log "aidatatang_200zh" +gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aidatatang_train_text + +gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/aidatatang_dev_text + +gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aidatatang_test_text + +log "aishell" +gunzip -c manifests/aishell/aishell_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell_train_text + +gunzip -c manifests/aishell/aishell_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/aishell_dev_text + +gunzip -c manifests/aishell/aishell_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aishell_test_text + +log "aishell2" +gunzip -c manifests/aishell2/aishell2_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell2_train_text + +gunzip -c manifests/aishell2/aishell2_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/aishell2_dev_text + +gunzip -c manifests/aishell2/aishell2_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aishell2_test_text + +log "aishell4" +gunzip -c manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell4_train_L_text + +gunzip -c manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell4_train_M_text + +gunzip -c manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell4_train_S_text + +gunzip -c manifests/aishell4/aishell4_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aishell4_test_text + +log "alimeeting" +gunzip -c manifests/alimeeting/alimeeting-far_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/alimeeting-far_train_text + +gunzip -c manifests/alimeeting/alimeeting-far_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/alimeeting-far_test_text + +gunzip -c manifests/alimeeting/alimeeting-far_supervisions_eval.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/alimeeting-far_eval_text + +log "kespeech" +gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase1.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | 
../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/kespeech_dev_phase1_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase2.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/kespeech_dev_phase2_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/kespeech_test_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase1.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/kespeech_train_phase1_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase2.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/kespeech_train_phase2_text + +log "magicdata" +gunzip -c manifests/magicdata/magicdata_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/magicdata_train_text + +gunzip -c manifests/magicdata/magicdata_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/magicdata_test_text + +gunzip -c manifests/magicdata/magicdata_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/magicdata_dev_text + +log "stcmds" +gunzip -c manifests/stcmds/stcmds_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/stcmds_train_text + +log "primewords" +gunzip -c manifests/primewords/primewords_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/primewords_train_text + +log "thchs30" +gunzip -c manifests/thchs30/thchs_30_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/thchs30_train_text + +gunzip -c manifests/thchs30/thchs_30_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/thchs30_test_text + +gunzip -c manifests/thchs30/thchs_30_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/thchs30_dev_text + +log "wenetspeech" +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_L.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/wenetspeech_L_text + +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_DEV.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/wenetspeech_DEV_text + +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_MEETING.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/wenetspeech_TEST_MEETING_text + +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_NET.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/wenetspeech_TEST_NET_text + +for f in aidatatang_train_text aishell2_train_text aishell4_train_L_text aishell4_train_M_text 
aishell4_train_S_text aishell_train_text alimeeting-far_train_text kespeech_train_phase1_text kespeech_train_phase2_text magicdata_train_text primewords_train_text stcmds_train_text thchs30_train_text wenetspeech_L_text; do + cat lm_training_data/$f >> lm_training_data/lm_training_text done -for subset in train dev test; do - gunzip -c aishell/aishell_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > aishell_${subset}_text +for f in aidatatang_test_text aishell4_test_text alimeeting-far_test_text thchs30_test_text wenetspeech_TEST_NET_text aishell2_test_text aishell_test_text kespeech_test_text magicdata_test_text wenetspeech_TEST_MEETING_text; do + cat lm_test_data/$f >> lm_test_data/lm_test_text done -for subset in train dev test; do - gunzip -c aishell2/aishell2_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > aishell2_${subset}_text +for f in aidatatang_dev_text aishell_dev_text kespeech_dev_phase1_text thchs30_dev_text aishell2_dev_text alimeeting-far_eval_text kespeech_dev_phase2_text magicdata_dev_text wenetspeech_DEV_text; do + cat lm_dev_data/$f >> lm_dev_data/lm_dev_text done -for subset in train_L train_M train_S test; do - gunzip -c aishell4/aishell4_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > aishell4_${subset}_text -done - -for subset in train test eval; do - gunzip -c alimeeting/alimeeting-far_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > alimeeting-far_${subset}_text -done - -for subset in dev_phase1 dev_phase2 test train_phase1 train_phase2; do - gunzip -c kespeech/kespeech-asr_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > kespeech_${subset}_text -done - -for subset in train test dev; do - gunzip -c magicdata/magicdata_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > magicdata_${subset}_text -done - -for subset in train ; do - gunzip -c stcmds/stcmds_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > stcmds_${subset}_text -done - -for subset in train ; do - gunzip -c primewords/primewords_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > primewords_${subset}_text -done - -for subset in train test dev ; do - gunzip -c thchs30/thchs_30_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > thchs30_${subset}_text -done - -for subset in L DEV TEST_MEETING TEST_NET ; do - gunzip -c wenetspeech/wenetspeech_supervisions_${subset}.jsonl.gz \ - | jq '.text' \ - | sed 's/"//g' \ - | ../../local/tokenize_for_lm_training.py -t "char" \ - > wenetspeech_${subset}_text -done - -cat aidatatang_train_text aishell2_train_text aishell4_train_L_text \ - aishell4_train_M_text aishell4_train_S_text aishell_train_text \ - alimeeting-far_train_text kespeech_train_phase1_text kespeech_train_phase2_text \ - magicdata_train_text primewords_train_text stcmds_train_text \ - thchs30_train_text wenetspeech_L_text > lm_training_text +cd ../ \ No newline at end of file From 
aead3e0c658a10afef3fae01224e1845e9edde9e Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 11:42:28 +0800 Subject: [PATCH 10/26] Update sort_lm_training_data.py --- egs/ptb/LM/local/sort_lm_training_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py index 31f36691f..667fe600c 100755 --- a/egs/ptb/LM/local/sort_lm_training_data.py +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -31,7 +31,7 @@ from pathlib import Path import k2 import numpy as np import torch -from tqdm import tqdm +from tqdm.auto import tqdm def get_args(): From c54fdf9ff9e1cb61f04ff9d27403b57235ca7aba Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 11:42:46 +0800 Subject: [PATCH 11/26] Update prepare_lm_data.sh --- egs/multi_zh-hans/ASR/prepare_lm_data.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/prepare_lm_data.sh b/egs/multi_zh-hans/ASR/prepare_lm_data.sh index 289c5c52c..0022deda4 100644 --- a/egs/multi_zh-hans/ASR/prepare_lm_data.sh +++ b/egs/multi_zh-hans/ASR/prepare_lm_data.sh @@ -226,4 +226,4 @@ for f in aidatatang_dev_text aishell_dev_text kespeech_dev_phase1_text thchs30_d cat lm_dev_data/$f >> lm_dev_data/lm_dev_text done -cd ../ \ No newline at end of file +cd ../ From 3694e419fb1c269a4b49639a23773859bd1264ce Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 8 Nov 2023 11:52:01 +0800 Subject: [PATCH 12/26] Update prepare_lm_training_data.py --- egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py index 95ee982bd..065e4f4df 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey -# Fangjun Kuang) +# Fangjun Kuang, +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # From 4c4c26fbb7dc6c099215a4cf0edc472adda200ec Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 10:40:33 +0800 Subject: [PATCH 13/26] Update decode.py --- egs/multi_zh-hans/ASR/zipformer/decode.py | 214 ++++++++++++++++++++++ 1 file changed, 214 insertions(+) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index f501c3c30..adfd751b6 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -97,6 +97,7 @@ Usage: import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -115,11 +116,16 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from lhotse.cut import Cut from multi_dataset import MultiDataset from train import add_model_arguments, get_model, get_params +from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -212,6 +218,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - modified_beam_search_LODR - fast_beam_search - fast_beam_search_nbest - 
fast_beam_search_nbest_oracle
@@ -303,6 +310,47 @@ def get_parser():
       fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
+    parser.add_argument(
+        "--use-shallow-fusion",
+        type=str2bool,
+        default=False,
+        help="""Use neural network LM for shallow fusion.
+        If you want to use LODR, you will also need to set this to true.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN lm",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM.
+        Used only when `--use-shallow-fusion` is set to True.
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=2,
+        help="""The order of the n-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -315,6 +363,10 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
+    LM: Optional[LmScorer] = None,
+    ngram_lm=None,
+    ngram_lm_scale: float = 0.0,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -343,6 +395,12 @@
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      LM:
+        A neural network language model.
+      ngram_lm:
+        An n-gram language model.
+      ngram_lm_scale:
+        The scale for the n-gram language model.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
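For orientation, modified_beam_search_LODR combines three scores per token: the transducer score, a neural LM score for shallow fusion, and a low-order n-gram score that is subtracted to approximately divide out the training-text prior. A minimal sketch of the per-token combination, with hypothetical names and typical scales rather than the actual beam_search.py code:

    # Hypothetical sketch of LODR-style scoring; the real logic lives in
    # modified_beam_search_LODR in beam_search.py and differs in detail.
    def lodr_token_score(
        am_logprob: float,  # transducer log-probability for the token
        nnlm_logprob: float,  # neural-LM log-probability (shallow fusion)
        ngram_logprob: float,  # low-order n-gram log-probability (LODR)
        lm_scale: float = 0.3,  # cf. --lm-scale above
        lodr_scale: float = -0.24,  # a typical negative --ngram-lm-scale
    ) -> float:
        # The negative n-gram scale subtracts the n-gram estimate of the
        # source-domain text prior from the fused neural-LM score.
        return am_logprob + lm_scale * nnlm_logprob + lodr_scale * ngram_logprob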
@@ -443,6 +501,51 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) else: batch_size = encoder_out.size(0) @@ -481,6 +584,22 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -492,6 +611,10 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
@@ -540,8 +663,12 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + context_graph=context_graph, word_table=word_table, batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) for name, hyps in hyps_dict.items(): @@ -624,9 +751,18 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", ) params.res_dir = params.exp_dir / params.decoding_method + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: @@ -653,10 +789,24 @@ def main(): params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -762,6 +912,54 @@ def main(): model.to(device) model.eval() + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. 
You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -779,6 +977,18 @@ def main(): decoding_graph = None word_table = None + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -813,6 +1023,10 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) save_results( From 8d20337d8a72addef369891f53b139fc1d6e814f Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 10:45:22 +0800 Subject: [PATCH 14/26] Update decode.py --- egs/multi_zh-hans/ASR/zipformer/decode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index adfd751b6..710d59553 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -737,6 +737,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) From 91da99ff5215b9d5e4b53a61b00fe09f1f0e6914 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:51:41 +0800 Subject: [PATCH 15/26] updated --- .../ASR/local/tokenize_for_lm_training.py | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100755 egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py diff --git a/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py b/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py new file mode 100755 index 000000000..ed7ead620 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) +# 2022 Xiaomi Corp. (authors: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import codecs
+import re
+import sys
+from typing import List
+
+from pypinyin import lazy_pinyin, pinyin
+from icefall.utils import str2bool, tokenize_by_CJK_char
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+    start_pos = None
+    end_pos = None
+    for pos in match_pos:
+        if pos[0] <= i < pos[1]:
+            start_pos = pos[0]
+            end_pos = pos[1]
+            break
+
+    return start_pos, end_pos
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="convert raw text to tokenized text",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
+    )
+    parser.add_argument(
+        "--non-lang-syms",
+        "-l",
+        default=None,
+        type=str,
+        help="list of non-linguistic symbols, e.g., <NOISE> etc.",
+    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "--trans_type",
+        "-t",
+        type=str,
+        default="char",
+        choices=["char", "pinyin", "lazy_pinyin"],
+        help="""Transcript type. char/pinyin/lazy_pinyin""",
+    )
+    return parser
+
+
+def token2id(
+    texts, token_table, token_type: str = "lazy_pinyin", oov: str = "<unk>"
+) -> List[List[int]]:
+    """Convert tokens to ids.
+    Args:
+      texts:
+        The input texts; here they refer to Chinese text.
+      token_table:
+        The token table built from "data/lang_xxx/token.txt".
+      token_type:
+        The type of token, such as "pinyin" and "lazy_pinyin".
+      oov:
+        Out-of-vocabulary token. When a word (token) in the transcript
+        does not exist in the token list, it is replaced with `oov`.
+
+    Returns:
+      The list of ids for the input texts.
+ """ + if texts is None: + raise ValueError("texts can't be None!") + else: + oov_id = token_table[oov] + ids: List[List[int]] = [] + for text in texts: + chars_list = list(str(text)) + if token_type == "lazy_pinyin": + text = lazy_pinyin(chars_list) + sub_ids = [ + token_table[txt] if txt in token_table else oov_id for txt in text + ] + ids.append(sub_ids) + else: # token_type = "pinyin" + text = pinyin(chars_list) + sub_ids = [ + token_table[txt[0]] if txt[0] in token_table else oov_id + for txt in text + ] + ids.append(sub_ids) + return ids + + +def main(): + parser = get_parser() + args = parser.parse_args() + + rs = [] + if args.non_lang_syms is not None: + with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: + nls = [x.rstrip() for x in f.readlines()] + rs = [re.compile(re.escape(x)) for x in nls] + + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")( + sys.stdout if is_python2 else sys.stdout.buffer + ) + line = f.readline() + while line: + x = line.split() + print(" ".join(x[: args.skip_ncols]), end=" ") + a = " ".join(x[args.skip_ncols :]) # noqa E203 + + a_flat = tokenize_by_CJK_char(a) + + # print("".join(a_chars)) + print(a_flat) + line = f.readline() + + +if __name__ == "__main__": + main() From 852f5a6153e52f5d0fea30af36de36a1364db4c8 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 10:56:48 +0800 Subject: [PATCH 16/26] isort formatted --- egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py b/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py index ed7ead620..f8f5b1be5 100755 --- a/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py +++ b/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py @@ -24,6 +24,7 @@ import sys from typing import List from pypinyin import lazy_pinyin, pinyin + from icefall.utils import str2bool, tokenize_by_CJK_char is_python2 = sys.version_info[0] == 2 From 7bd260fb5a7f648f08f2065e1dddab8f4c4ec3d2 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 11:01:21 +0800 Subject: [PATCH 17/26] Update decode.py --- egs/multi_zh-hans/ASR/zipformer/decode.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 710d59553..204524a8c 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -351,6 +351,28 @@ def get_parser(): help="ID of the backoff symbol in the ngram LM", ) + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. 
+ """, + ) + add_model_arguments(parser) return parser From b4d91d24accb19daf8d6379521ac9501436d0e70 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 11:02:36 +0800 Subject: [PATCH 18/26] Update asr_datamodule.py --- egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index ae1264659..7b2020309 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -24,11 +24,12 @@ from typing import Any, Dict, Optional import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest -from lhotse.dataset import ( +from lhotse.dataset import ( # noqa PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, K2SpeechRecognitionDataset, + PrecomputedFeatures, SimpleCutSampler, SpecAugment, ) From 16499a5ef627e0843ffa89a69bff54934d709810 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 11:37:18 +0800 Subject: [PATCH 19/26] Update decode.py --- egs/multi_zh-hans/ASR/zipformer/decode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 204524a8c..2d3510fc1 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -520,6 +520,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) From 73e1237c2d5842ab0b0d3b5ab474c948fd8ff019 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 11:50:39 +0800 Subject: [PATCH 20/26] Update decode.py --- egs/multi_zh-hans/ASR/zipformer/decode.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 2d3510fc1..acb70e388 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -125,7 +125,7 @@ from lhotse.cut import Cut from multi_dataset import MultiDataset from train import add_model_arguments, get_model, get_params -from icefall import ContextGraph, LmScorer, NgramLm +from icefall import ContextGraph, LmScorer, NgramLm, tokenize_by_CJK_char from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -462,7 +462,7 @@ def decode_one_batch( max_states=params.max_states, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(tokenize_by_CJK_char(hyp).split()) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -490,7 +490,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(tokenize_by_CJK_char(hyp).split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -505,7 +505,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(tokenize_by_CJK_char(hyp).split()) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, @@ -513,7 +513,7 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): - 
hyps.append(hyp.split()) + hyps.append(tokenize_by_CJK_char(hyp).split()) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -523,7 +523,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(tokenize_by_CJK_char(hyp).split()) elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, @@ -533,7 +533,7 @@ def decode_one_batch( LM=LM, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(tokenize_by_CJK_char(hyp).split()) elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -546,7 +546,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(tokenize_by_CJK_char(hyp).split()) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -616,7 +616,7 @@ def decode_one_batch( ans = dict() assert ans_dict is not None for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] + hyps = [tokenize_by_CJK_char(sp.decode(hyp)).split() for hyp in hyps] ans[f"{prefix}_{key}"] = hyps return ans else: @@ -678,7 +678,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [list(str(text).replace(" ", "")) for text in texts] + texts = [tokenize_by_CJK_char(text).split() for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( From a37408f663eb7392e6a7ea7937ff39be9c94501f Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 11:57:49 +0800 Subject: [PATCH 21/26] Revert "Update decode.py" This reverts commit 73e1237c2d5842ab0b0d3b5ab474c948fd8ff019. 
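The commit above reverts the previous one, which re-tokenized hypotheses and references with icefall's tokenize_by_CJK_char before scoring. As a rough illustration (not the icefall.utils source, which covers more codepoint ranges), that helper puts a space around every CJK character while leaving non-CJK tokens whole:

    import re

    # Illustrative approximation of icefall.utils.tokenize_by_CJK_char;
    # the real implementation matches additional CJK codepoint ranges.
    _CJK = re.compile(r"([\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff])")

    def tokenize_by_cjk_char_sketch(line: str) -> str:
        parts = _CJK.split(line.strip())
        return " ".join(p.strip() for p in parts if p.strip())

    # tokenize_by_cjk_char_sketch("你好HELLO世界") == "你 好 HELLO 世 界"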
--- egs/multi_zh-hans/ASR/zipformer/decode.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index acb70e388..2d3510fc1 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -125,7 +125,7 @@ from lhotse.cut import Cut from multi_dataset import MultiDataset from train import add_model_arguments, get_model, get_params -from icefall import ContextGraph, LmScorer, NgramLm, tokenize_by_CJK_char +from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -462,7 +462,7 @@ def decode_one_batch( max_states=params.max_states, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -490,7 +490,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -505,7 +505,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, @@ -513,7 +513,7 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -523,7 +523,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, @@ -533,7 +533,7 @@ def decode_one_batch( LM=LM, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -546,7 +546,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(tokenize_by_CJK_char(hyp).split()) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -616,7 +616,7 @@ def decode_one_batch( ans = dict() assert ans_dict is not None for key, hyps in ans_dict.items(): - hyps = [tokenize_by_CJK_char(sp.decode(hyp)).split() for hyp in hyps] + hyps = [sp.decode(hyp).split() for hyp in hyps] ans[f"{prefix}_{key}"] = hyps return ans else: @@ -678,7 +678,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [tokenize_by_CJK_char(text).split() for text in texts] + texts = [list(str(text).replace(" ", "")) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( From 
From 39a02f7c30e82bdb16452dc3ab1e686830cd84c0 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Fri, 17 Nov 2023 17:06:23 +0800
Subject: [PATCH 22/26] added blank penalty

---
 egs/multi_zh-hans/ASR/zipformer/decode.py | 24 +++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py
index 2d3510fc1..89e3dfa98 100755
--- a/egs/multi_zh-hans/ASR/zipformer/decode.py
+++ b/egs/multi_zh-hans/ASR/zipformer/decode.py
@@ -310,6 +310,18 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )
+
     parser.add_argument(
         "--use-shallow-fusion",
         type=str2bool,
@@ -460,6 +472,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -474,6 +487,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -488,6 +502,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -503,6 +518,7 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -511,6 +527,7 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -521,6 +538,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -531,6 +549,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             LM=LM,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -544,6 +563,7 @@ def decode_one_batch(
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -556,6 +576,7 @@ def decode_one_batch(
             beam=params.beam_size,
             LM=LM,
             lm_scale_list=lm_scale_list,
+            blank_penalty=params.blank_penalty,
         )
     elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
         lm_scale_list = [0.02 * i for i in range(2, 30)]
         ans_dict = modified_beam_search_lm_rescore_LODR(
@@ -568,6 +589,7 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             sp=sp,
             lm_scale_list=lm_scale_list,
+            blank_penalty=params.blank_penalty,
         )
     else:
         batch_size = encoder_out.size(0)
@@ -581,12 +603,14 @@ def decode_one_batch(
                 model=model,
                 encoder_out=encoder_out_i,
                 max_sym_per_frame=params.max_sym_per_frame,
+                blank_penalty=params.blank_penalty,
             )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    blank_penalty=params.blank_penalty,
                 )
             else:
                 raise ValueError(
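For context on the new flag: the `--blank-penalty` help text above fully
specifies the adjustment. Below is a minimal sketch of that logic, assuming
blank id 0 and a `(batch_size, vocab)` logits tensor as stated in the help
string; the real application happens inside icefall's search functions, so the
function and variable names here are illustrative only.

    import torch

    def apply_blank_penalty(logits: torch.Tensor, blank_penalty: float) -> torch.Tensor:
        # Subtract a positive penalty from the blank logit (index 0) so the
        # model emits blank less often, which typically trades a few extra
        # insertions for fewer deletions.
        if blank_penalty != 0.0:
            logits = logits.clone()
            logits[:, 0] -= blank_penalty
        return logits

    x = torch.zeros(2, 5)
    print(apply_blank_penalty(x, 1.5)[:, 0])  # tensor([-1.5000, -1.5000])

The default of 0.0 leaves decoding behavior unchanged.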
From 6097d7363d43ab181acbcd5b93aed6c388983017 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Wed, 20 Dec 2023 11:31:06 +0800
Subject: [PATCH 23/26] Create convert_transcript_words_to_tokens.py

---
 .../ASR/local/convert_transcript_words_to_tokens.py | 1 +
 1 file changed, 1 insertion(+)
 create mode 120000 egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py

diff --git a/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py b/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file

From ecfbd090af5df383a29b8c65cee2806130f0bcc1 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Wed, 20 Dec 2023 15:05:57 +0800
Subject: [PATCH 24/26] Delete convert_transcript_words_to_tokens.py

---
 .../ASR/local/convert_transcript_words_to_tokens.py | 1 -
 1 file changed, 1 deletion(-)
 delete mode 120000 egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py

diff --git a/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py b/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py
deleted file mode 120000
index 2ce13fd69..000000000
--- a/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
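The symlink created in PATCH 23 and removed in PATCH 24 pointed at the
LibriSpeech version of the script; PATCH 25 below replaces it with a local
copy that additionally tokenizes transcripts by CJK character. That copy
relies on `filter_multiple_pronunications` from `generate_unique_lexicon.py`
(linked in PATCH 26) to keep only the first pronunciation of each word. The
following is a rough sketch of that keep-first behavior, written purely for
illustration and not taken from the helper itself:

    from typing import List, Tuple

    Lexicon = List[Tuple[str, List[str]]]

    def keep_first_pronunciation(lexicon: Lexicon) -> Lexicon:
        # Drop every entry after the first occurrence of a word, matching
        # the "first pronunciation wins" rule described in the script's
        # docstring.
        seen = set()
        ans = []
        for word, tokens in lexicon:
            if word not in seen:
                seen.add(word)
                ans.append((word, tokens))
        return ans

    lexicon = [("hello", ["h", "e", "l", "l", "o", "2"]),
               ("hello", ["h", "e", "l", "l", "o"])]
    print(keep_first_pronunciation(lexicon))
    # [('hello', ['h', 'e', 'l', 'l', 'o', '2'])]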
From 2a1877486efbeeb1e6439edacf8deadb4054d15a Mon Sep 17 00:00:00 2001
From: jinzr
Date: Wed, 20 Dec 2023 16:51:45 +0800
Subject: [PATCH 25/26] Create convert_transcript_words_to_tokens.py

---
 .../convert_transcript_words_to_tokens.py     | 104 ++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100755 egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py

diff --git a/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py b/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py
new file mode 100755
index 000000000..0da036f35
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+"""
+Convert a transcript file containing words to a corpus file containing tokens
+for LM training with the help of a lexicon.
+
+If the lexicon contains phones, the resulting LM will be a phone LM; if the
+lexicon contains word pieces, the resulting LM will be a word-piece LM.
+
+If a word has multiple pronunciations, the one that appears first in the
+lexicon is kept; others are removed.
+
+If the input transcript is:
+
+    hello zoo world hello
+    world zoo
+    foo zoo world hellO
+
+and if the lexicon is
+
+    <UNK> SPN
+    hello h e l l o 2
+    hello h e l l o
+    world w o r l d
+    zoo z o o
+
+then the output is
+
+    h e l l o 2 z o o w o r l d h e l l o 2
+    w o r l d z o o
+    SPN z o o w o r l d SPN
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List
+
+from generate_unique_lexicon import filter_multiple_pronunications
+
+from icefall.lexicon import read_lexicon
+from icefall.utils import tokenize_by_CJK_char
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="The input transcript file. "
+        "We assume that the transcript file consists of "
+        "lines. Each line consists of space-separated words.",
+    )
+    parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
+    parser.add_argument("--oov", type=str, default="<UNK>", help="The OOV word.")
+
+    return parser.parse_args()
+
+
+def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None:
+    """
+    Args:
+      lexicon:
+        A dict containing pronunciations. Its keys are words and its values
+        are pronunciations (i.e., tokens).
+      line:
+        A line of transcript consisting of space-separated words.
+      oov_token:
+        The pronunciation of the OOV word if a word in `line` is not present
+        in the lexicon.
+    Returns:
+      Return None.
+    """
+    s = ""
+    words = tokenize_by_CJK_char(line).strip().split()
+    for w in words:
+        tokens = lexicon.get(w, oov_token)
+        s += " ".join(tokens)
+        s += " "
+    print(s.strip())
+
+
+def main():
+    args = get_args()
+    assert Path(args.lexicon).is_file()
+    assert Path(args.transcript).is_file()
+    assert len(args.oov) > 0
+
+    # Only the first pronunciation of a word is kept
+    lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
+
+    lexicon = dict(lexicon)
+
+    assert args.oov in lexicon
+
+    oov_token = lexicon[args.oov]
+
+    with open(args.transcript) as f:
+        for line in f:
+            process_line(lexicon=lexicon, line=line, oov_token=oov_token)
+
+
+if __name__ == "__main__":
+    main()

From c2cb70fc22ffd0a9cb8cbe107846ef3441a7d39c Mon Sep 17 00:00:00 2001
From: jinzr
Date: Wed, 20 Dec 2023 18:58:40 +0800
Subject: [PATCH 26/26] Create generate_unique_lexicon.py

---
 egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py | 1 +
 1 file changed, 1 insertion(+)
 create mode 120000 egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py

diff --git a/egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py b/egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
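To see what the new script produces end to end, here is a self-contained demo
that mirrors its lexicon lookup and OOV fallback using the toy example from
the docstring. It runs with plain Python and no icefall installation; the
dict-based lexicon is a simplification of the real `read_lexicon` output.

    lexicon = {
        "<UNK>": ["SPN"],
        "hello": ["h", "e", "l", "l", "o", "2"],  # first pronunciation wins
        "world": ["w", "o", "r", "l", "d"],
        "zoo": ["z", "o", "o"],
    }

    for line in ["hello zoo world hello", "world zoo", "foo zoo world hellO"]:
        tokens = []
        for word in line.split():
            # Unknown words ("foo", "hellO") fall back to the OOV pronunciation.
            tokens.extend(lexicon.get(word, lexicon["<UNK>"]))
        print(" ".join(tokens))

    # Output, matching the docstring:
    # h e l l o 2 z o o w o r l d h e l l o 2
    # w o r l d z o o
    # SPN z o o w o r l d SPN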