Merge c2cb70fc22ffd0a9cb8cbe107846ef3441a7d39c into d9ae8c02a0abdeddc5a4cf9fad72293eda134de3

This commit is contained in: zr_jin, 2024-02-10 04:49:39 -07:00 (committed by GitHub)
Commit 4d047dc8b8
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)

12 changed files with 976 additions and 13 deletions

View File

@@ -0,0 +1,104 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
"""
Convert a transcript file containing words to a corpus file containing tokens
for LM training with the help of a lexicon.
If the lexicon contains phones, the resulting LM will be a phone LM; if the
lexicon contains word pieces, the resulting LM will be a word piece LM.
If a word has multiple pronunciations, the one that appears first in the lexicon
is kept; others are removed.
If the input transcript is:
hello zoo world hello
world zoo
foo zoo world hellO
and if the lexicon is
<UNK> SPN
hello h e l l o 2
hello h e l l o
world w o r l d
zoo z o o
Then the output is
h e l l o 2 z o o w o r l d h e l l o 2
w o r l d z o o
SPN z o o w o r l d SPN
"""
import argparse
from pathlib import Path
from typing import Dict, List
from generate_unique_lexicon import filter_multiple_pronunications
from icefall.lexicon import read_lexicon
from icefall.utils import tokenize_by_CJK_char
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transcript",
type=str,
help="The input transcript file."
"We assume that the transcript file consists of "
"lines. Each line consists of space separated words.",
)
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
parser.add_argument("--oov", type=str, default="<UNK>", help="The OOV word.")
return parser.parse_args()
def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: List[str]) -> None:
"""
Args:
lexicon:
A dict containing pronunciations. Its keys are words and values
are pronunciations (i.e., tokens).
line:
A line of the transcript consisting of space-separated words.
oov_token:
The pronunciation (token list) to use when a word in `line` is not
present in the lexicon.
Returns:
Return None.
"""
s = ""
words = tokenize_by_CJK_char(line).strip().split()
for i, w in enumerate(words):
tokens = lexicon.get(w, oov_token)
s += " ".join(tokens)
s += " "
print(s.strip())
def main():
args = get_args()
assert Path(args.lexicon).is_file()
assert Path(args.transcript).is_file()
assert len(args.oov) > 0
# Only the first pronunciation of a word is kept
lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
lexicon = dict(lexicon)
assert args.oov in lexicon
oov_token = lexicon[args.oov]
with open(args.transcript) as f:
for line in f:
process_line(lexicon=lexicon, line=line, oov_token=oov_token)
if __name__ == "__main__":
main()
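
For intuition, here is a minimal, self-contained sketch of the conversion rule above; the inline dict is the docstring's toy lexicon standing in for read_lexicon plus filter_multiple_pronunications, not a real lexicon file:

# Toy lexicon from the docstring; the first pronunciation is already selected.
lexicon = {
    "<UNK>": ["SPN"],
    "hello": ["h", "e", "l", "l", "o", "2"],
    "world": ["w", "o", "r", "l", "d"],
    "zoo": ["z", "o", "o"],
}

def convert(line: str) -> str:
    # Unknown words fall back to the pronunciation of "<UNK>".
    return " ".join(" ".join(lexicon.get(w, lexicon["<UNK>"])) for w in line.split())

print(convert("hello zoo world hello"))  # h e l l o 2 z o o w o r l d h e l l o 2
print(convert("foo zoo world hellO"))    # SPN z o o w o r l d SPN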

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/local/generate_unique_lexicon.py
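
The single line above is a git symlink (git stores the link target path as the blob content). It could be recreated from the recipe's local/ directory with:

ln -s ../../../librispeech/ASR/local/generate_unique_lexicon.py generate_unique_lexicon.py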

View File

@@ -22,8 +22,6 @@
import argparse
from pathlib import Path
from tqdm.auto import tqdm
from icefall.utils import tokenize_by_CJK_char

View File

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
# Fangjun Kuang,
# Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes a `bpe.model` and a text file such as
./download/lm/librispeech-lm-norm.txt
and outputs the LM training data to a supplied directory such
as data/lm_training_bpe_500. The format is as follows:
It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
representation of a dict with the following format:
'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32
containing the BPE representations of each word, indexed by
integer word ID. (These integer word IDS are present in
'lm_data'). The sentencepiece object can be used to turn the
words and BPE units into string form.
'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
torch.int32 containing all the sentences, as word-ids (we don't
output the string form of this directly but it can be worked out
together with 'words' and the bpe.model).
'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
the number of BPE tokens of each sentence.
"""
import argparse
import logging
from pathlib import Path
import k2
import sentencepiece as spm
import torch
from icefall.utils import tokenize_by_CJK_char
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
default="data/lang_bpe_2000/bpe.model",
type=str,
help="Input BPE model, e.g. data/bpe_500/bpe.model",
)
parser.add_argument(
"--lm-data",
type=str,
help="""Input LM training data as text, e.g.
download/pb.train.txt""",
)
parser.add_argument(
"--lm-archive",
type=str,
help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
look at the source of this script to see the format.""",
)
return parser.parse_args()
def main():
args = get_args()
if Path(args.lm_archive).exists():
logging.warning(f"{args.lm_archive} exists - skipping")
return
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
# word2index is a dictionary from words to integer ids. No need to reserve
# space for epsilon, etc.; the words are just used as a convenient way to
# compress the sequences of BPE pieces.
word2index = dict()
word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces.
sentences = [] # Will be a list-of-list-of-int, representing word-ids.
step = 500000
processed = 0
with open(args.lm_data) as f:
while True:
line = f.readline()
if line == "":
break
line = tokenize_by_CJK_char(line)
if line == "":
continue
if step and processed % step == 0:
logging.info(f"Processed number of lines: {processed} ")
processed += 1
line_words = line.split()
for w in line_words:
if w not in word2index:
w_bpe = sp.encode(w)
word2index[w] = len(word2bpe)
word2bpe.append(w_bpe)
sentences.append([word2index[w] for w in line_words])
logging.info("Constructing ragged tensors")
words = k2.ragged.RaggedTensor(word2bpe)
sentences = k2.ragged.RaggedTensor(sentences)
output = dict(words=words, sentences=sentences)
num_sentences = sentences.dim0
logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
if step and i % step == 0:
logging.info(
f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)"
)
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)
torch.save(output, args.lm_archive)
logging.info(f"Saved to {args.lm_archive}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
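
As a hedged illustration of the archive format described in the docstring (the path below is a placeholder), the saved file can be loaded and indexed like this:

import k2
import torch

d = torch.load("data/lm_training_bpe_500/lm_data.pt")
words = d["words"]  # k2.RaggedTensor [word][token]: BPE ids per word id
sentences = d["sentences"]  # k2.RaggedTensor [sentence][word]: word ids
lengths = d["sentence_lengths"]  # 1-D int32 tensor: BPE tokens per sentence

# Recover the BPE token ids of sentence 0 by indexing `words` with its
# word ids, mirroring the loop in main() above.
token_ids = words[sentences[0]]
if isinstance(token_ids, k2.RaggedTensor):
    token_ids = token_ids.values
assert token_ids.numel() == lengths[0]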

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/local/sort_lm_training_data.py

View File

@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe)
# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import codecs
import re
import sys
from typing import List
from pypinyin import lazy_pinyin, pinyin
from icefall.utils import str2bool, tokenize_by_CJK_char
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "pinyin", "lazy_pinyin"],
help="""Transcript type. char/pinyin/lazy_pinyin""",
)
return parser
def token2id(
texts, token_table, token_type: str = "lazy_pinyin", oov: str = "<unk>"
) -> List[List[int]]:
"""Convert token to id.
Args:
texts:
The input texts; here they refer to Chinese text.
token_table:
The token table, built from "data/lang_xxx/token.txt".
token_type:
The type of token, such as "pinyin" or "lazy_pinyin".
oov:
Out-of-vocabulary token. When a word (token) in the transcript
does not exist in the token list, it is replaced with `oov`.
Returns:
The list of ids for the input texts.
"""
if texts is None:
raise ValueError("texts can't be None!")
else:
oov_id = token_table[oov]
ids: List[List[int]] = []
for text in texts:
chars_list = list(str(text))
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
text = pinyin(chars_list)
sub_ids = [
token_table[txt[0]] if txt[0] in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
return ids
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
line = f.readline()
while line:
x = line.split()
print(" ".join(x[: args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols :]) # noqa E203
a_flat = tokenize_by_CJK_char(a)
# print("".join(a_chars))
print(a_flat)
line = f.readline()
if __name__ == "__main__":
main()
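
For reference, a small sketch of the pypinyin behavior that token2id() relies on: pinyin() returns one sublist per character (hence the txt[0] above), while lazy_pinyin() returns a flat, tone-less list.

from pypinyin import lazy_pinyin, pinyin

chars = list("你好")
print(lazy_pinyin(chars))  # ['ni', 'hao'] -- flat strings, no tone marks
print(pinyin(chars))       # [['nǐ'], ['hǎo']] -- one sublist per character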

View File

@@ -370,4 +370,72 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
done
fi
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
log "Stage 16: Prepare LM data"
./prepare_lm_data.sh
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data ./data/lm_training_data/lm_training_text \
--lm-archive $out_dir/lm_training_data.pt
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data ./data/lm_dev_data/lm_dev_text \
--lm-archive $out_dir/lm_dev_data.pt
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data ./data/lm_test_data/lm_test_text \
--lm-archive $out_dir/lm_test_data.pt
done
fi
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
log "Stage 17: Sort LM data"
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/lm_training_bpe_${vocab_size}
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_training_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_dev_data.pt \
--out-lm-data $out_dir/sorted_lm_data-dev.pt \
--out-statistics $out_dir/statistics-dev.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_test_data.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
done
fi
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
log "Stage 18: Train RNN LM model"
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/lm_training_bpe_${vocab_size}
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \
--world-size 2 \
--use-fp16 0 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--batch-size 400 \
--exp-dir rnnlm_bpe_${vocab_size}/exp \
--lm-data $out_dir/sorted_lm_data.pt \
--lm-data-valid $out_dir/sorted_lm_data-dev.pt \
--vocab-size $vocab_size
done
fi
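
These stages follow the script's existing stage gating, so they can be run selectively, e.g. (a usage sketch):

./prepare.sh --stage 16 --stop-stage 17  # prepare and sort the LM data only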

View File

@@ -0,0 +1,229 @@
cd data/
log "Preparing LM data..."
mkdir -p lm_training_data
mkdir -p lm_dev_data
mkdir -p lm_test_data
log "aidatatang_200zh"
gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/aidatatang_train_text
gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_dev.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/aidatatang_dev_text
gunzip -c manifests/aidatatang_200zh/aidatatang_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/aidatatang_test_text
log "aishell"
gunzip -c manifests/aishell/aishell_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/aishell_train_text
gunzip -c manifests/aishell/aishell_supervisions_dev.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/aishell_dev_text
gunzip -c manifests/aishell/aishell_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/aishell_test_text
log "aishell2"
gunzip -c manifests/aishell2/aishell2_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/aishell2_train_text
gunzip -c manifests/aishell2/aishell2_supervisions_dev.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/aishell2_dev_text
gunzip -c manifests/aishell2/aishell2_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/aishell2_test_text
log "aishell4"
gunzip -c manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/aishell4_train_L_text
gunzip -c manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/aishell4_train_M_text
gunzip -c manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/aishell4_train_S_text
gunzip -c manifests/aishell4/aishell4_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/aishell4_test_text
log "alimeeting"
gunzip -c manifests/alimeeting/alimeeting-far_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/alimeeting-far_train_text
gunzip -c manifests/alimeeting/alimeeting-far_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/alimeeting-far_test_text
gunzip -c manifests/alimeeting/alimeeting-far_supervisions_eval.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/alimeeting-far_eval_text
log "kespeech"
gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase1.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/kespeech_dev_phase1_text
gunzip -c manifests/kespeech/kespeech-asr_supervisions_dev_phase2.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/kespeech_dev_phase2_text
gunzip -c manifests/kespeech/kespeech-asr_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/kespeech_test_text
gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase1.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/kespeech_train_phase1_text
gunzip -c manifests/kespeech/kespeech-asr_supervisions_train_phase2.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/kespeech_train_phase2_text
log "magicdata"
gunzip -c manifests/magicdata/magicdata_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/magicdata_train_text
gunzip -c manifests/magicdata/magicdata_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/magicdata_test_text
gunzip -c manifests/magicdata/magicdata_supervisions_dev.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/magicdata_dev_text
log "stcmds"
gunzip -c manifests/stcmds/stcmds_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/stcmds_train_text
log "primewords"
gunzip -c manifests/primewords/primewords_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/primewords_train_text
log "thchs30"
gunzip -c manifests/thchs30/thchs_30_supervisions_train.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/thchs30_train_text
gunzip -c manifests/thchs30/thchs_30_supervisions_test.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/thchs30_test_text
gunzip -c manifests/thchs30/thchs_30_supervisions_dev.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/thchs30_dev_text
log "wenetspeech"
gunzip -c manifests/wenetspeech/wenetspeech_supervisions_L.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_training_data/wenetspeech_L_text
gunzip -c manifests/wenetspeech/wenetspeech_supervisions_DEV.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_dev_data/wenetspeech_DEV_text
gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_MEETING.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/wenetspeech_TEST_MEETING_text
gunzip -c manifests/wenetspeech/wenetspeech_supervisions_TEST_NET.jsonl.gz \
| jq '.text' \
| sed 's/"//g' \
| ../local/tokenize_for_lm_training.py -t "char" \
> lm_test_data/wenetspeech_TEST_NET_text
for f in aidatatang_train_text aishell2_train_text aishell4_train_L_text aishell4_train_M_text aishell4_train_S_text aishell_train_text alimeeting-far_train_text kespeech_train_phase1_text kespeech_train_phase2_text magicdata_train_text primewords_train_text stcmds_train_text thchs30_train_text wenetspeech_L_text; do
cat lm_training_data/$f >> lm_training_data/lm_training_text
done
for f in aidatatang_test_text aishell4_test_text alimeeting-far_test_text thchs30_test_text wenetspeech_TEST_NET_text aishell2_test_text aishell_test_text kespeech_test_text magicdata_test_text wenetspeech_TEST_MEETING_text; do
cat lm_test_data/$f >> lm_test_data/lm_test_text
done
for f in aidatatang_dev_text aishell_dev_text kespeech_dev_phase1_text thchs30_dev_text aishell2_dev_text alimeeting-far_eval_text kespeech_dev_phase2_text magicdata_dev_text wenetspeech_DEV_text; do
cat lm_dev_data/$f >> lm_dev_data/lm_dev_text
done
cd ../
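
Every dataset above repeats the same gunzip | jq | sed | tokenize pipeline. A hedged refactor into a single helper (the function name is an assumption) would shorten the script; note that jq -r emits raw strings, which also makes the separate sed quote-stripping step unnecessary for typical transcripts:

extract_text() {
  # $1: input supervisions manifest (.jsonl.gz); $2: output text file
  gunzip -c "$1" \
    | jq -r '.text' \
    | ../local/tokenize_for_lm_training.py -t "char" \
    > "$2"
}

extract_text manifests/aishell/aishell_supervisions_train.jsonl.gz \
  lm_training_data/aishell_train_text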

View File

@@ -19,13 +19,12 @@
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( # noqa PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
@@ -34,10 +33,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

View File

@@ -97,6 +97,7 @@ Usage:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -115,11 +116,16 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_rescore,
modified_beam_search_lm_rescore_LODR,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@@ -212,6 +218,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- modified_beam_search_LODR
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
@@ -303,6 +310,81 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=2,
help="""The order of the ngram lm.
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="ID of the backoff symbol in the ngram LM",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score of each token for the context biasing words/phrases.
Used only when --decoding-method is modified_beam_search or
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing list; one word/phrase per line.
Used only when --decoding-method is modified_beam_search or
modified_beam_search_LODR.
""",
)
add_model_arguments(parser)
return parser
@@ -315,6 +397,10 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -343,6 +429,12 @@
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network language model.
ngram_lm:
An n-gram language model.
ngram_lm_scale:
The scale for the n-gram language model.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@@ -380,6 +472,7 @@
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -394,6 +487,7 @@
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@@ -408,6 +502,7 @@
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -423,6 +518,7 @@
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -431,6 +527,7 @@
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -440,9 +537,60 @@
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
context_graph=context_graph,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
lm_scale_list=lm_scale_list,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
lm_scale_list = [0.02 * i for i in range(2, 30)]
ans_dict = modified_beam_search_lm_rescore_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
LODR_lm=ngram_lm,
sp=sp,
lm_scale_list=lm_scale_list,
blank_penalty=params.blank_penalty,
)
else:
batch_size = encoder_out.size(0)
@@ -455,12 +603,14 @@
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
blank_penalty=params.blank_penalty,
)
else:
raise ValueError(
@@ -481,6 +631,22 @@
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
elif "modified_beam_search" in params.decoding_method:
prefix = f"beam_size_{params.beam_size}"
if params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
):
ans = dict()
assert ans_dict is not None
for key, hyps in ans_dict.items():
hyps = [sp.decode(hyp).split() for hyp in hyps]
ans[f"{prefix}_{key}"] = hyps
return ans
else:
if params.has_contexts:
prefix += f"-context-score-{params.context_score}"
return {prefix: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@@ -492,6 +658,10 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -540,8 +710,12 @@
model=model,
sp=sp,
decoding_graph=decoding_graph,
context_graph=context_graph,
word_table=word_table,
batch=batch,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
for name, hyps in hyps_dict.items():
@@ -610,6 +784,7 @@
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -624,9 +799,18 @@
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
@@ -653,10 +837,24 @@
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search",
"modified_beam_search_LODR",
):
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_shallow_fusion:
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@@ -762,6 +960,54 @@
model.to(device)
model.eval()
# only load the neural network LM if required
if params.use_shallow_fusion or params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
):
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
# only load N-gram LM when needed
if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
try:
import kenlm
except ImportError:
print("Please install kenlm first. You can use")
print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
print("to install it")
import sys
sys.exit(-1)
ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
logging.info(f"lm filename: {ngram_file_name}")
ngram_lm = kenlm.Model(ngram_file_name)
ngram_lm_scale = None # use a list to search
elif params.decoding_method == "modified_beam_search_LODR":
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"Loading token level lm: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
@@ -779,6 +1025,18 @@
decoding_graph = None
word_table = None
if "modified_beam_search" in params.decoding_method:
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -813,6 +1071,10 @@
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
save_results(
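
Taken together, the options added in this file can be exercised roughly as follows (an illustrative invocation; the script path, epoch/avg values, and LM checkpoint setup are assumptions not shown in this diff):

./zipformer/decode.py \
  --epoch 30 --avg 9 \
  --decoding-method modified_beam_search_lm_shallow_fusion \
  --use-shallow-fusion 1 \
  --lm-type rnn \
  --lm-scale 0.3 \
  --beam-size 4 \
  --blank-penalty 0.0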

View File

@@ -15,11 +15,9 @@
# limitations under the License.
import glob
import logging
import re
from pathlib import Path
from typing import Dict, List
from typing import Dict
import lhotse
from lhotse import CutSet, load_manifest_lazy

View File

@@ -31,6 +31,7 @@ from pathlib import Path
import k2
import numpy as np
import torch
from tqdm.auto import tqdm
def get_args():
@@ -87,7 +88,7 @@ def main():
)
cur = None
for i in range(num_sentences):
for i in tqdm(range(num_sentences)):
word_ids = sorted_sentences[i]
token_ids = words2bpe[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
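
The words2bpe[word_ids] pattern above relies on k2's ragged indexing; a minimal sketch with toy values shows both result types the isinstance check guards against:

import k2
import torch

words2bpe = k2.RaggedTensor([[5, 6], [7], [8, 9, 10]])  # [word][BPE token]
word_ids = torch.tensor([2, 0], dtype=torch.int32)
token_ids = words2bpe[word_ids]  # k2.RaggedTensor: [[8, 9, 10], [5, 6]]
flat = token_ids.values          # 1-D tensor: [8, 9, 10, 5, 6]
single = words2bpe[1]            # indexing with an int gives a plain torch.Tensor: [7]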