diff --git a/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py b/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 000000000..0da036f35 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon +from icefall.utils import tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument("--oov", type=str, default="", help="The OOV word.") + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. + oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = tokenize_by_CJK_char(line).strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py b/egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py index 020800c15..2b8708272 100755 --- a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py +++ b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py @@ -22,8 +22,6 @@ import argparse from pathlib import Path -from tqdm.auto import tqdm - from icefall.utils import tokenize_by_CJK_char diff --git a/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py new file mode 100755 index 000000000..065e4f4df --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/prepare_lm_training_data.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey +# Fangjun Kuang, +# Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a `bpe.model` and a text file such as +./download/lm/librispeech-lm-norm.txt +and outputs the LM training data to a supplied directory such +as data/lm_training_bpe_500. The format is as follows: + +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a +representation of a dict with the following format: + + 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 + containing the BPE representations of each word, indexed by + integer word ID. (These integer word IDS are present in + 'lm_data'). The sentencepiece object can be used to turn the + words and BPE units into string form. + 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype + torch.int32 containing all the sentences, as word-ids (we don't + output the string form of this directly but it can be worked out + together with 'words' and the bpe.model). + 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing + number of BPE tokens of each sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + +from icefall.utils import tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--bpe-model", + default="data/lang_bpe_2000/bpe.model", + type=str, + help="Input BPE model, e.g. data/bpe_500/bpe.model", + ) + parser.add_argument( + "--lm-data", + type=str, + help="""Input LM training data as text, e.g. + download/pb.train.txt""", + ) + parser.add_argument( + "--lm-archive", + type=str, + help="""Path to output archive, e.g. data/bpe_500/lm_data.pt; + look at the source of this script to see the format.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + if Path(args.lm_archive).exists(): + logging.warning(f"{args.lm_archive} exists - skipping") + return + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of BPE pieces. + word2index = dict() + + word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. + sentences = [] # Will be a list-of-list-of-int, representing word-ids. + + step = 500000 + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == "": + break + line = tokenize_by_CJK_char(line) + if line == "": + continue + + if step and processed % step == 0: + logging.info(f"Processed number of lines: {processed} ") + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_bpe = sp.encode(w) + word2index[w] = len(word2bpe) + word2bpe.append(w_bpe) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2bpe) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py b/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..1d6ccbe33 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py b/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py new file mode 100755 index 000000000..f8f5b1be5 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/tokenize_for_lm_training.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) +# 2022 Xiaomi Corp. (authors: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import codecs +import re +import sys +from typing import List + +from pypinyin import lazy_pinyin, pinyin + +from icefall.utils import str2bool, tokenize_by_CJK_char + +is_python2 = sys.version_info[0] == 2 + + +def exist_or_not(i, match_pos): + start_pos = None + end_pos = None + for pos in match_pos: + if pos[0] <= i < pos[1]: + start_pos = pos[0] + end_pos = pos[1] + break + + return start_pos, end_pos + + +def get_parser(): + parser = argparse.ArgumentParser( + description="convert raw text to tokenized text", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" + ) + parser.add_argument( + "--non-lang-syms", + "-l", + default=None, + type=str, + help="list of non-linguistic symobles, e.g., etc.", + ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "--trans_type", + "-t", + type=str, + default="char", + choices=["char", "pinyin", "lazy_pinyin"], + help="""Transcript type. char/pinyin/lazy_pinyin""", + ) + return parser + + +def token2id( + texts, token_table, token_type: str = "lazy_pinyin", oov: str = "" +) -> List[List[int]]: + """Convert token to id. + Args: + texts: + The input texts, it refers to the chinese text here. + token_table: + The token table is built based on "data/lang_xxx/token.txt" + token_type: + The type of token, such as "pinyin" and "lazy_pinyin". + oov: + Out of vocabulary token. When a word(token) in the transcript + does not exist in the token list, it is replaced with `oov`. + + Returns: + The list of ids for the input texts. + """ + if texts is None: + raise ValueError("texts can't be None!") + else: + oov_id = token_table[oov] + ids: List[List[int]] = [] + for text in texts: + chars_list = list(str(text)) + if token_type == "lazy_pinyin": + text = lazy_pinyin(chars_list) + sub_ids = [ + token_table[txt] if txt in token_table else oov_id for txt in text + ] + ids.append(sub_ids) + else: # token_type = "pinyin" + text = pinyin(chars_list) + sub_ids = [ + token_table[txt[0]] if txt[0] in token_table else oov_id + for txt in text + ] + ids.append(sub_ids) + return ids + + +def main(): + parser = get_parser() + args = parser.parse_args() + + rs = [] + if args.non_lang_syms is not None: + with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: + nls = [x.rstrip() for x in f.readlines()] + rs = [re.compile(re.escape(x)) for x in nls] + + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")( + sys.stdout if is_python2 else sys.stdout.buffer + ) + line = f.readline() + while line: + x = line.split() + print(" ".join(x[: args.skip_ncols]), end=" ") + a = " ".join(x[args.skip_ncols :]) # noqa E203 + + a_flat = tokenize_by_CJK_char(a) + + # print("".join(a_chars)) + print(a_flat) + line = f.readline() + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index c09b9c1de..0704451f5 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -370,4 +370,72 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then done fi +if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then + log "Stage 16: Prepare LM data" + + ./prepare_lm_data.sh + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + + mkdir $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data ./data/lm_training_data/lm_training_text \ + --lm-archive $out_dir/lm_training_data.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data ./data/lm_dev_data/lm_dev_text \ + --lm-archive $out_dir/lm_dev_data.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data ./data/lm_test_data/lm_test_text \ + --lm-archive $out_dir/lm_test_data.pt + done +fi + +if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then + log "Stage 17: Sort LM data" + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_training_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_dev_data.pt \ + --out-lm-data $out_dir/sorted_lm_data-dev.pt \ + --out-statistics $out_dir/statistics-dev.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_test_data.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi + +if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then + log "Stage 18: Train RNN LM model" + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --use-fp16 0 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --batch-size 400 \ + --exp-dir rnnlm_bpe_${vocab_size}/exp \ + --lm-data $out_dir/sorted_lm_data.pt \ + --lm-data-valid $out_dir/sorted_lm_data-dev.pt \ + --vocab-size $vocab_size + done +fi diff --git a/egs/multi_zh-hans/ASR/prepare_lm_data.sh b/egs/multi_zh-hans/ASR/prepare_lm_data.sh new file mode 100644 index 000000000..0022deda4 --- /dev/null +++ b/egs/multi_zh-hans/ASR/prepare_lm_data.sh @@ -0,0 +1,229 @@ +cd data/ + +log "Preparing LM data..." +mkdir -p lm_training_data +mkdir -p lm_dev_data +mkdir -p lm_test_data + +log "aidatatang_200zh" +gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aidatatang_train_text + +gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/aidatatang_dev_text + +gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aidatatang_test_text + +log "aishell" +gunzip -c manifests/aishell/aishell_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell_train_text + +gunzip -c manifests/aishell/aishell_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/aishell_dev_text + +gunzip -c manifests/aishell/aishell_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aishell_test_text + +log "aishell2" +gunzip -c manifests/aishell2/aishell2_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell2_train_text + +gunzip -c manifests/aishell2/aishell2_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/aishell2_dev_text + +gunzip -c manifests/aishell2/aishell2_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aishell2_test_text + +log "aishell4" +gunzip -c manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell4_train_L_text + +gunzip -c manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell4_train_M_text + +gunzip -c manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/aishell4_train_S_text + +gunzip -c manifests/aishell4/aishell4_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/aishell4_test_text + +log "alimeeting" +gunzip -c manifests/alimeeting/alimeeting-far_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/alimeeting-far_train_text + +gunzip -c manifests/alimeeting/alimeeting-far_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/alimeeting-far_test_text + +gunzip -c manifests/alimeeting/alimeeting-far_supervisions_eval.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/alimeeting-far_eval_text + +log "kespeech" +gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase1.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/kespeech_dev_phase1_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase2.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/kespeech_dev_phase2_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/kespeech_test_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase1.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/kespeech_train_phase1_text + +gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase2.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/kespeech_train_phase2_text + +log "magicdata" +gunzip -c manifests/magicdata/magicdata_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/magicdata_train_text + +gunzip -c manifests/magicdata/magicdata_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/magicdata_test_text + +gunzip -c manifests/magicdata/magicdata_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/magicdata_dev_text + +log "stcmds" +gunzip -c manifests/stcmds/stcmds_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/stcmds_train_text + +log "primewords" +gunzip -c manifests/primewords/primewords_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/primewords_train_text + +log "thchs30" +gunzip -c manifests/thchs30/thchs_30_supervisions_train.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/thchs30_train_text + +gunzip -c manifests/thchs30/thchs_30_supervisions_test.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/thchs30_test_text + +gunzip -c manifests/thchs30/thchs_30_supervisions_dev.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/thchs30_dev_text + +log "wenetspeech" +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_L.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_training_data/wenetspeech_L_text + +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_DEV.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_dev_data/wenetspeech_DEV_text + +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_MEETING.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/wenetspeech_TEST_MEETING_text + +gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_NET.jsonl.gz \ + | jq '.text' \ + | sed 's/"//g' \ + | ../local/tokenize_for_lm_training.py -t "char" \ + > lm_test_data/wenetspeech_TEST_NET_text + +for f in aidatatang_train_text aishell2_train_text aishell4_train_L_text aishell4_train_M_text aishell4_train_S_text aishell_train_text alimeeting-far_train_text kespeech_train_phase1_text kespeech_train_phase2_text magicdata_train_text primewords_train_text stcmds_train_text thchs30_train_text wenetspeech_L_text; do + cat lm_training_data/$f >> lm_training_data/lm_training_text +done + +for f in aidatatang_test_text aishell4_test_text alimeeting-far_test_text thchs30_test_text wenetspeech_TEST_NET_text aishell2_test_text aishell_test_text kespeech_test_text magicdata_test_text wenetspeech_TEST_MEETING_text; do + cat lm_test_data/$f >> lm_test_data/lm_test_text +done + +for f in aidatatang_dev_text aishell_dev_text kespeech_dev_phase1_text thchs30_dev_text aishell2_dev_text alimeeting-far_eval_text kespeech_dev_phase2_text magicdata_dev_text wenetspeech_DEV_text; do + cat lm_dev_data/$f >> lm_dev_data/lm_dev_text +done + +cd ../ diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 341579acb..bc61a56c7 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -19,13 +19,12 @@ import argparse import inspect import logging -from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( # noqa PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, @@ -34,10 +33,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures SimpleCutSampler, SpecAugment, ) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index f501c3c30..89e3dfa98 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -97,6 +97,7 @@ Usage: import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -115,11 +116,16 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from lhotse.cut import Cut from multi_dataset import MultiDataset from train import add_model_arguments, get_model, get_params +from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -212,6 +218,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - modified_beam_search_LODR - fast_beam_search - fast_beam_search_nbest - fast_beam_search_nbest_oracle @@ -303,6 +310,81 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) return parser @@ -315,6 +397,10 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -343,6 +429,12 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -380,6 +472,7 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -394,6 +487,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for hyp in hyp_tokens: hyps.append([word_table[i] for i in hyp]) @@ -408,6 +502,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -423,6 +518,7 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -431,6 +527,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -440,9 +537,60 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + context_graph=context_graph, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + blank_penalty=params.blank_penalty, + ) else: batch_size = encoder_out.size(0) @@ -455,12 +603,14 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, ) elif params.decoding_method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + blank_penalty=params.blank_penalty, ) else: raise ValueError( @@ -481,6 +631,22 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -492,6 +658,10 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -540,8 +710,12 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + context_graph=context_graph, word_table=word_table, batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) for name, hyps in hyps_dict.items(): @@ -610,6 +784,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -624,9 +799,18 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", ) params.res_dir = params.exp_dir / params.decoding_method + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: @@ -653,10 +837,24 @@ def main(): params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -762,6 +960,54 @@ def main(): model.to(device) model.eval() + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -779,6 +1025,18 @@ def main(): decoding_graph = None word_table = None + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -813,6 +1071,10 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) save_results( diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py index b1920e62e..ac90dc73e 100644 --- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -15,11 +15,9 @@ # limitations under the License. -import glob import logging -import re from pathlib import Path -from typing import Dict, List +from typing import Dict import lhotse from lhotse import CutSet, load_manifest_lazy diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py index bed3856e4..667fe600c 100755 --- a/egs/ptb/LM/local/sort_lm_training_data.py +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -31,6 +31,7 @@ from pathlib import Path import k2 import numpy as np import torch +from tqdm.auto import tqdm def get_args(): @@ -87,7 +88,7 @@ def main(): ) cur = None - for i in range(num_sentences): + for i in tqdm(range(num_sentences)): word_ids = sorted_sentences[i] token_ids = words2bpe[word_ids] if isinstance(token_ids, k2.RaggedTensor):