Training with byte level BPE (AIShell) (#986)

* copy files from zipformer librispeech
* Add byte bpe training for aishell
* compile LG graph
* Support LG decoding
* Minor fixes
* black
* Minor fixes
* export & fix pretrain.py
* fix black
* Update RESULTS.md
* Fix export.py

parent 61ec3a7a8f
commit 80156dda09
egs/aishell/ASR/RESULTS.md

@@ -2,6 +2,57 @@

### Aishell training result (Stateless Transducer)

#### Pruned transducer stateless 7 (zipformer)

See <https://github.com/k2-fsa/icefall/pull/986>

[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe)

**Note**: The modeling units are byte-level BPE tokens.
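Byte-level BPE means the sentencepiece model never sees raw Chinese characters: each character is first rewritten as a sequence of printable byte symbols, BPE is applied to that byte string, and decoding inverts the mapping. A minimal round-trip sketch (not part of this commit; it assumes a trained `bbpe.model` produced by `local/train_bbpe_model.py` below):

```
import sentencepiece as spm

from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bbpe.model")  # path assumed from this recipe

text = "深度学习"  # any Mandarin string
# Split into CJK characters, then map each one to printable byte symbols.
encoded = byte_encode(tokenize_by_CJK_char(text))
pieces = sp.encode(encoded, out_type=str)  # byte-level BPE pieces
# Decoding reverses the byte mapping and recovers the original characters.
restored = smart_byte_decode(sp.decode(pieces))
print(pieces, restored)
```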
The best results I have gotten are:

| Vocab size | Greedy search (dev & test) | Modified beam search (dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | Comments |
|--|--|--|--|--|--|
| 500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29 |

The training command:

```
export CUDA_VISIBLE_DEVICES="4,5,6,7"

./pruned_transducer_stateless7_bbpe/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --max-duration 800 \
  --bpe-model data/lang_bbpe_500/bbpe.model \
  --exp-dir pruned_transducer_stateless7_bbpe/exp \
  --lr-epochs 6 \
  --master-port 12535
```

The decoding command:

```
for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do
  ./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 48 \
    --avg 29 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-sym-per-frame 1 \
    --ngram-lm-scale 0.25 \
    --ilme-scale 0.2 \
    --bpe-model data/lang_bbpe_500/bbpe.model \
    --max-duration 2000 \
    --decoding-method $m
done
```

The pretrained model is available at: <https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe>
#### Pruned transducer stateless 3

See <https://github.com/k2-fsa/icefall/pull/436>

@@ -75,7 +126,7 @@ for epoch in 29; do
done
```

We provide the option of shallow fusion with an RNN language model. The pre-trained language model is
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
please use the following command:
egs/aishell/ASR/local/compile_lg.py (new symbolic link)

@@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py
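For reference, `compile_lg.py` (shared with the librispeech recipe via the symlink above) builds the `LG.pt` graph consumed by `fast_beam_search_LG`. A rough sketch of the core idea, assuming the lang-dir layout produced by this recipe (the linked script is the authoritative version and performs additional steps such as determinization and disambiguation-symbol handling):

```
import k2
import torch

lang_dir = "data/lang_bbpe_500"  # assumed layout from this recipe

# L_disambig.pt is written by local/prepare_lang_bbpe.py below;
# G_3_gram_char.fst.txt is written in the "Prepare G" stage of prepare.sh.
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
with open("data/lm/G_3_gram_char.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# k2 composition requires arc-sorted inputs.
LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
LG = k2.connect(LG)
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
```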
egs/aishell/ASR/local/prepare_char.py

@@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`:

  - tokens.txt
"""

import argparse
import re
from pathlib import Path
from typing import Dict, List

@@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
    return tokens


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the bpe.model and words.txt
        """,
    )

    return parser.parse_args()


def main():
    lang_dir = Path("data/lang_char")
    args = get_args()
    lang_dir = Path(args.lang_dir)
    text_file = lang_dir / "text"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
egs/aishell/ASR/local/prepare_lang_bbpe.py (new executable file, 267 lines)

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang
#                                                  Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""

This script takes as input `lang_dir`, which should contain::

    - lang_dir/bbpe.model,
    - lang_dir/words.txt

and generates the following files in the directory `lang_dir`:

    - lexicon.txt
    - lexicon_disambig.txt
    - L.pt
    - L_disambig.pt
    - tokens.txt
"""

import argparse
from pathlib import Path
from typing import Dict, List, Tuple

import k2
import sentencepiece as spm
import torch
from prepare_lang import (
    Lexicon,
    add_disambig_symbols,
    add_self_loops,
    write_lexicon,
    write_mapping,
)

from icefall.byte_utils import byte_encode
from icefall.utils import str2bool, tokenize_by_CJK_char


def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format).

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go

    arcs = []

    # The blank symbol <blk> is defined in local/train_bpe_model.py
    assert token2id["<blk>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        pieces = [token2id[i] for i in pieces]

        for i in range(len(pieces) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, pieces[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last piece of this word
        i = len(pieces) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, pieces[i], w, 0])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def generate_lexicon(
    model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
    """Generate a lexicon from a BPE model.

    Args:
      model_file:
        Path to a sentencepiece model.
      words:
        A list of strings representing words.
      oov:
        The out of vocabulary word in lexicon.
    Returns:
      Return a tuple with two elements:
        - A dict whose keys are words and values are the corresponding
          word pieces.
        - A dict representing the token symbol, mapping from tokens to IDs.
    """
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_file))

    # Convert word to word piece IDs instead of word piece strings
    # to avoid OOV tokens.
    encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
    words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)

    # Now convert word piece IDs back to word piece strings.
    words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]

    lexicon = []
    for word, pieces in zip(words, words_pieces):
        lexicon.append((word, pieces))

    lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))

    token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}

    return lexicon, token2id


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the bpe.model and words.txt
        """,
    )

    parser.add_argument(
        "--oov",
        type=str,
        default="<UNK>",
        help="The out of vocabulary word in lexicon.",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!

        See "test/test_bpe_lexicon.py" for usage.
        """,
    )

    return parser.parse_args()


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    model_file = lang_dir / "bbpe.model"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

    words = word_sym_table.symbols

    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]

    for w in excluded:
        if w in words:
            words.remove(w)

    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    next_token_id = max(token_sym_table.values()) + 1
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = next_token_id
        next_token_id += 1

    word_sym_table.add("#0")
    word_sym_table.add("<s>")
    word_sym_table.add("</s>")

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token_sym_table,
        word2id=word_sym_table,
    )

    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token_sym_table,
        word2id=word_sym_table,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if args.debug:
        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

        L.labels_sym = labels_sym
        L.aux_labels_sym = aux_labels_sym
        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")

        L_disambig.labels_sym = labels_sym
        L_disambig.aux_labels_sym = aux_labels_sym
        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")


if __name__ == "__main__":
    main()
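To make the lexicon construction above concrete: `generate_lexicon` byte-encodes every word before looking up its BPE pieces, so every word decomposes into in-vocabulary byte symbols. A small illustrative sketch (paths assumed from this recipe; the exact pieces depend on the trained model):

```
import sentencepiece as spm

from icefall.byte_utils import byte_encode
from icefall.utils import tokenize_by_CJK_char

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bbpe.model")  # assumed path from this recipe

word = "你好"
# This per-word mapping is what ends up as one line of lexicon.txt.
pieces = sp.encode(byte_encode(tokenize_by_CJK_char(word)), out_type=str)
print(word, pieces)
```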
egs/aishell/ASR/local/train_bbpe_model.py (new executable file, 113 lines)

@@ -0,0 +1,113 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# You can install sentencepiece via:
#
#  pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96

import argparse
import re
import shutil
import tempfile
from pathlib import Path

import sentencepiece as spm

from icefall import byte_encode, tokenize_by_CJK_char


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        The generated bpe.model is saved to this directory.
        """,
    )

    parser.add_argument(
        "--transcript",
        type=str,
        help="Training transcript.",
    )

    parser.add_argument(
        "--vocab-size",
        type=int,
        help="Vocabulary size for BPE training",
    )

    return parser.parse_args()


def _convert_to_bchar(in_path: str, out_path: str):
    with open(out_path, "w") as f:
        for line in open(in_path, "r").readlines():
            f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n")


def main():
    args = get_args()
    vocab_size = args.vocab_size
    lang_dir = Path(args.lang_dir)

    model_type = "unigram"

    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
    character_coverage = 1.0
    input_sentence_size = 100000000

    user_defined_symbols = ["<blk>", "<sos/eos>"]
    unk_id = len(user_defined_symbols)
    # Note: unk_id is fixed to 2.
    # If you change it, you should also change other
    # places that are using it.

    temp = tempfile.NamedTemporaryFile()
    train_text = temp.name

    _convert_to_bchar(args.transcript, train_text)

    model_file = Path(model_prefix + ".model")
    if not model_file.is_file():
        spm.SentencePieceTrainer.train(
            input=train_text,
            vocab_size=vocab_size,
            model_type=model_type,
            model_prefix=model_prefix,
            input_sentence_size=input_sentence_size,
            character_coverage=character_coverage,
            user_defined_symbols=user_defined_symbols,
            unk_id=unk_id,
            bos_id=-1,
            eos_id=-1,
        )
    else:
        print(f"{model_file} exists - skipping")
        return

    shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")


if __name__ == "__main__":
    main()
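To illustrate the transcript conversion done by `_convert_to_bchar` above: each line is split into CJK characters and rewritten as printable byte symbols before sentencepiece training ever sees it (a sketch; the byte symbols come from the encoding table in `icefall.byte_utils`):

```
from icefall import byte_encode, tokenize_by_CJK_char

line = "甚至 出现 交易 几乎 停滞 的 情况"  # an example transcript line
# CJK tokenization inserts spaces around every character; byte_encode then
# rewrites each character as a sequence of printable byte symbols, which is
# the text the sentencepiece trainer above consumes.
print(byte_encode(tokenize_by_CJK_char(line)))
```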
egs/aishell/ASR/prepare.sh

@@ -35,6 +35,15 @@ dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# vocab size for sentence piece models.
# It will generate data/lang_bbpe_xxx,
# data/lang_bbpe_yyy if the array contains xxx, yyy
vocab_sizes=(
  # 2000
  # 1000
  500
)

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

@@ -47,20 +56,6 @@ log() {

log "dl_dir: $dl_dir"

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "stage -1: Download LM"
  # We assume that you have installed the git-lfs, if not, you could install it
  # using: `sudo apt-get install git-lfs && git-lfs install`
  git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)

  if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
    git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
    pushd $dl_dir/lm
    git lfs pull --include "3-gram.unpruned.arpa"
    popd
  fi
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "stage 0: Download data"

@@ -134,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi

lang_phone_dir=data/lang_phone
lang_char_dir=data/lang_char
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Prepare phone based lang"
  mkdir -p $lang_phone_dir

@@ -183,45 +177,107 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  fi
fi

lang_char_dir=data/lang_char
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Prepare char based lang"
  mkdir -p $lang_char_dir
  # We reuse words.txt from phone based lexicon
  # so that the two can share G.pt later.
  cp $lang_phone_dir/words.txt $lang_char_dir

  # The transcripts in training set, generated in stage 5
  cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt

  cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
    cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text
    cut -d " " -f 2- > $lang_char_dir/text

  (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
    > $lang_char_dir/words.txt

  cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
    | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt

  num_lines=$(< $lang_char_dir/words.txt wc -l)
  (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
    >> $lang_char_dir/words.txt

  if [ ! -f $lang_char_dir/L_disambig.pt ]; then
    ./local/prepare_char.py
    ./local/prepare_char.py --lang-dir $lang_char_dir
  fi
fi

if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Prepare G"
  # We assume you have install kaldilm, if not, please install
  # it using: pip install kaldilm
  log "Stage 7: Prepare Byte BPE based lang"

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bbpe_${vocab_size}
    mkdir -p $lang_dir

    cp $lang_char_dir/words.txt $lang_dir
    cp $lang_char_dir/text $lang_dir

    if [ ! -f $lang_dir/bbpe.model ]; then
      ./local/train_bbpe_model.py \
        --lang-dir $lang_dir \
        --vocab-size $vocab_size \
        --transcript $lang_dir/text
    fi

    if [ ! -f $lang_dir/L_disambig.pt ]; then
      ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
    fi
  done
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Prepare G"

  mkdir -p data/lm
  if [ ! -f data/lm/G_3_gram.fst.txt ]; then

    # Train LM on transcripts
    if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
      python3 ./shared/make_kn_lm.py \
        -ngram-order 3 \
        -text $lang_char_dir/transcript_words.txt \
        -lm data/lm/3-gram.unpruned.arpa
    fi

    # We assume you have install kaldilm, if not, please install
    # it using: pip install kaldilm
    if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
      # It is used in building HLG
      python3 -m kaldilm \
        --read-symbol-table="$lang_phone_dir/words.txt" \
        --disambig-symbol='#0' \
        --max-order=3 \
        $dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
        data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt

      python3 -m kaldilm \
        --read-symbol-table="$lang_char_dir/words.txt" \
        --disambig-symbol='#0' \
        --max-order=3 \
        data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
    fi
  fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Compile HLG"
  ./local/compile_hlg.py --lang-dir $lang_phone_dir
  ./local/compile_hlg.py --lang-dir $lang_char_dir
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Compile LG & HLG"
  ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
  ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bbpe_${vocab_size}
    ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
  done

  ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
  ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bbpe_${vocab_size}
    ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
  done
fi

if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Generate LM training data"
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
  log "Stage 10: Generate LM training data"

  log "Processing char based data"
  out_dir=data/lm_training_char

@@ -267,8 +323,8 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi


if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
  log "Stage 10: Sort LM training data"
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
  log "Stage 11: Sort LM training data"
  # Sort LM training data by sentence length in descending order
  # for ease of training.
  #

@@ -295,7 +351,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
    --out-statistics $out_dir/statistics-test.txt
fi

if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
  log "Stage 11: Train RNN LM model"
  python ../../../icefall/rnn_lm/train.py \
    --start-epoch 0 \
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py (new symbolic link)

@@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py (new symbolic link)

@@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py (new executable file, 819 lines)

@@ -0,0 +1,819 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao,
#                                                 Xiaoyu Yang,
#                                                 Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-duration 600 \
    --decoding-method greedy_search

(2) beam search (not recommended)
./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-duration 600 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search \
    --beam-size 4

(4) fast beam search (one best)
./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64

(5) fast beam search (nbest)
./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest_oracle \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(7) fast beam search (with LG)
./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_LG \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64
"""

import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import (
    beam_search,
    fast_beam_search_nbest,
    fast_beam_search_nbest_oracle,
    fast_beam_search_one_best,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall import (
    LmScorer,
    NgramLm,
    byte_encode,
    smart_byte_decode,
    tokenize_by_CJK_char,
)
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bbpe_500/bbpe.model",
        help="Path to the byte BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bbpe_500",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
        - greedy_search
        - beam_search
        - modified_beam_search
        - fast_beam_search
        - fast_beam_search_LG
        - fast_beam_search_nbest
        - fast_beam_search_nbest_oracle
        If you use fast_beam_search_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_LG,
        and fast_beam_search_nbest_oracle
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.25,
        help="""
        Used only when --decoding_method is fast_beam_search_LG.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--ilme-scale",
        type=float,
        default=0.2,
        help="""
        Used only when --decoding_method is fast_beam_search_LG.
        It specifies the scale for the internal language model estimation.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=200,
        help="""Number of paths for nbest decoding.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    add_model_arguments(parser)

    return parser


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For example,
           if greedy_search is used, it would be "greedy_search"
           If beam search with a beam size of 7 is used, it would be
           "beam_7"
    - value: It contains the decoding result. `len(value)` equals to
             batch size. `value[i]` is the decoding result for the i-th
             utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

    hyps = []

    if params.decoding_method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    elif params.decoding_method == "fast_beam_search_LG":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            subtract_ilme=True,
            ilme_scale=params.ilme_scale,
        )
        for hyp in hyp_tokens:
            hyps.append([word_table[i] for i in hyp])
    elif params.decoding_method == "fast_beam_search_nbest":
        hyp_tokens = fast_beam_search_nbest(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    elif params.decoding_method == "fast_beam_search_nbest_oracle":
        ref_texts = []
        for tx in supervisions["text"]:
            ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))

        hyp_tokens = fast_beam_search_nbest_oracle(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            ref_texts=sp.encode(ref_texts),
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    else:
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(smart_byte_decode(sp.decode(hyp)).split())

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in params.decoding_method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
        if "LG" in params.decoding_method:
            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
            key += f"_ilme_scale_{params.ilme_scale}"

        return {key: hyps}
    else:
        return {f"beam_size_{params.beam_size}": hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
      LM:
        A neural network LM, used during shallow fusion
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"

        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))

        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_LG",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
        if "LG" in params.decoding_method:
            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
            params.suffix += f"-ilme-scale-{params.ilme_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bbpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_LG":
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None
        word_table = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)

    test_cuts = aishell.test_cuts()
    dev_cuts = aishell.valid_cuts()

    test_dl = aishell.test_dataloaders(test_cuts)
    dev_dl = aishell.test_dataloaders(dev_cuts)

    test_sets = ["test", "dev"]
    test_dls = [test_dl, dev_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py (new symbolic link)

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py

egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py (new symbolic link)

@@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py (new executable file, 320 lines)

@@ -0,0 +1,320 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.
"""

Usage:

(1) Export to torchscript model using torch.jit.script()

./pruned_transducer_stateless7_bbpe/export.py \
  --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
  --bpe-model data/lang_bbpe_500/bbpe.model \
  --epoch 30 \
  --avg 9 \
  --jit 1

It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.

Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.

Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.

(2) Export `model.state_dict()`

./pruned_transducer_stateless7_bbpe/export.py \
  --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
  --bpe-model data/lang_bbpe_500/bbpe.model \
  --epoch 20 \
  --avg 10

It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.

To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`,
you can do:

    cd /path/to/exp_dir
    ln -s pretrained.pt epoch-9999.pt

    cd /path/to/egs/aishell/ASR
    ./pruned_transducer_stateless7_bbpe/decode.py \
        --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
        --epoch 9999 \
        --avg 1 \
        --max-duration 600 \
        --decoding-method greedy_search \
        --bpe-model data/lang_bbpe_500/bbpe.model

Check ./pretrained.py for its usage.

Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at

https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe

with the following commands:

    sudo apt-get install git-lfs
    git lfs install
    git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
    # You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp
"""

import argparse
import logging
from pathlib import Path

import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7_bbpe/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bbpe_500/bbpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate a file named cpu_jit.pt

        Check ./jit_pretrained.py for how to use it.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    add_model_arguments(parser)

    return parser


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bbpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
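The `pretrained.pt` written by the `else` branch above is a plain `{"model": state_dict}` dict. For reference, a minimal sketch of loading it back in Python, mirroring what `pretrained.py` below does (the exp path is a placeholder, and `params` is assumed to be populated with `blank_id`, `vocab_size`, and the model arguments exactly as in `main()` above):

```
import torch

from train import get_params, get_transducer_model  # this recipe's helpers

params = get_params()
# Assumption: blank_id, vocab_size, and model args set as in export.py's main().
model = get_transducer_model(params)

checkpoint = torch.load(
    "pruned_transducer_stateless7_bbpe/exp/pretrained.pt", map_location="cpu"
)
model.load_state_dict(checkpoint["model"], strict=False)
model.eval()
```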
274  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py  Executable file
@ -0,0 +1,274 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`,
and uses them to decode waves.
You can use the following command to get the exported models:

./pruned_transducer_stateless7_bbpe/export.py \
  --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
  --bpe-model data/lang_bbpe_500/bbpe.model \
  --epoch 49 \
  --avg 28 \
  --jit 1

Usage of this script:

./pruned_transducer_stateless7_bbpe/jit_pretrained.py \
  --nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \
  --bpe-model ./data/lang_bbpe_500/bbpe.model \
  /path/to/foo.wav \
  /path/to/bar.wav
"""

import argparse
import logging
import math
from typing import List

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

from icefall import smart_byte_decode


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model cpu_jit.pt",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0])
    return ans


def greedy_search(
    model: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,).
    Returns:
      Return the decoded results for each utterance.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    context_size = model.decoder.context_size
    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = model.decoder(
        decoder_input,
        need_pad=torch.tensor([False]),
    ).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = packed_encoder_out.data[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        decoder_out = decoder_out[:batch_size]

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
        )
        # logits' shape: (batch_size, vocab_size)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(
                decoder_input,
                need_pad=torch.tensor([False]),
            )
            decoder_out = decoder_out.squeeze(1)

    sorted_ans = [h[context_size:] for h in hyps]
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.nn_model_filename)

    model.eval()

    model.to(device)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.encoder(
        x=features,
        x_lens=feature_lengths,
    )

    hyps = greedy_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
    )
    s = "\n"
    for filename, hyp in zip(args.sound_files, hyps):
        words = smart_byte_decode(sp.decode(hyp))
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
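Beyond the CLI usage in the docstring, the torchscript model can also be exercised directly from Python as a quick sanity check. A small sketch (random features stand in for real fbank output; shapes follow the script above):

```
import torch

# Assumption: cpu_jit.pt was produced by export.py with --jit 1.
model = torch.jit.load("pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt")
model.eval()

features = torch.randn(1, 100, 80)     # (N, T, num_mel_bins)
feature_lengths = torch.tensor([100])  # (N,)

encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
print(encoder_out.shape, encoder_out_lens)
```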
1  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
1  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/model.py
1  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
345  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py  Executable file
@ -0,0 +1,345 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:

./pruned_transducer_stateless7_bbpe/export.py \
  --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
  --bpe-model data/lang_bbpe_500/bbpe.model \
  --epoch 48 \
  --avg 29

Usage of this script:

(1) greedy search
./pruned_transducer_stateless7_bbpe/pretrained.py \
  --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
  --bpe-model ./data/lang_bbpe_500/bbpe.model \
  --method greedy_search \
  /path/to/foo.wav \
  /path/to/bar.wav

(2) modified beam search
./pruned_transducer_stateless7_bbpe/pretrained.py \
  --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
  --bpe-model ./data/lang_bbpe_500/bbpe.model \
  --method modified_beam_search \
  --beam-size 4 \
  /path/to/foo.wav \
  /path/to/bar.wav

(3) fast beam search
./pruned_transducer_stateless7_bbpe/pretrained.py \
  --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
  --bpe-model ./data/lang_bbpe_500/bbpe.model \
  --method fast_beam_search \
  --beam-size 4 \
  /path/to/foo.wav \
  /path/to/bar.wav

Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by
./pruned_transducer_stateless7_bbpe/export.py
"""


import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
    beam_search,
    fast_beam_search_one_best,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model

from icefall import smart_byte_decode
from icefall.utils import str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame. Used only when
        --method is greedy_search.
        """,
    )

    add_model_arguments(parser)

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0])
    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()

    params.update(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bbpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_transducer_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()
    model.device = device

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)

    num_waves = encoder_out.size(0)
    hyps = []
    msg = f"Using {params.method}"
    if params.method == "beam_search":
        msg += f" with beam size {params.beam_size}"
    logging.info(msg)

    if params.method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    elif params.method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )

        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(smart_byte_decode(hyp).split())
    else:
        for i in range(num_waves):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(f"Unsupported method: {params.method}")

            hyps.append(smart_byte_decode(sp.decode(hyp)).split())

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
1  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
1  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
1  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
1261  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py  Executable file
File diff suppressed because it is too large
1  egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@ -21,7 +21,7 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import List
from typing import Any, Dict, List, Optional

from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
@ -181,7 +181,16 @@ class AishellAsrDataModule:
            "with training dataset. ",
        )

    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
    def train_dataloaders(
        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

@ -277,6 +286,10 @@ class AishellAsrDataModule:
        )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
@ -325,7 +338,7 @@ class AishellAsrDataModule:
        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        logging.info("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
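The new `sampler_state_dict` argument exists so that a resumed run can restore the sampler's position in the training data. A hedged sketch of the call site in a training script, assuming (as in icefall's train scripts) that checkpoints store the sampler state under a `"sampler"` key:

```
# `checkpoints` as returned by icefall's load_checkpoint_if_available();
# it may be None on a fresh run.
sampler_state_dict = None
if checkpoints is not None and "sampler" in checkpoints:
    sampler_state_dict = checkpoints["sampler"]

aishell = AishellAsrDataModule(args)
train_dl = aishell.train_dataloaders(
    train_cuts, sampler_state_dict=sampler_state_dict
)
```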
@ -45,11 +45,18 @@ def get_args():
        help="""Input and output directory.
        """,
    )
    parser.add_argument(
        "--lm",
        type=str,
        default="G_3_gram",
        help="""Stem name for LM used in HLG compiling.
        """,
    )

    return parser.parse_args()


def compile_LG(lang_dir: str) -> k2.Fsa:
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
    """
    Args:
      lang_dir:
@ -61,15 +68,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
    lexicon = Lexicon(lang_dir)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
    if Path(f"data/lm/{lm}.pt").is_file():
        logging.info(f"Loading pre-compiled {lm}")
        d = torch.load(f"data/lm/{lm}.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
        logging.info(f"Loading {lm}.fst.txt")
        with open(f"data/lm/{lm}.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
            torch.save(G.as_dict(), f"data/lm/{lm}.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]
@ -126,7 +133,7 @@ def main():

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    LG = compile_LG(lang_dir, args.lm)
    logging.info(f"Saving LG.pt to {lang_dir}")
    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
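With the new `--lm` option the LM stem is no longer hard-coded to `G_3_gram`. A usage sketch (the lang dir is illustrative; the script expects `data/lm/<stem>.fst.txt` or a pre-compiled `data/lm/<stem>.pt` to exist):

```
./local/compile_lg.py \
  --lang-dir data/lang_bbpe_500 \
  --lm G_3_gram
```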
@ -47,6 +47,8 @@ def fast_beam_search_one_best(
    max_states: int,
    max_contexts: int,
    temperature: float = 1.0,
    subtract_ilme: bool = False,
    ilme_scale: float = 0.1,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """It limits the maximum number of symbols per frame to 1.
@ -88,6 +90,8 @@ def fast_beam_search_one_best(
        max_states=max_states,
        max_contexts=max_contexts,
        temperature=temperature,
        subtract_ilme=subtract_ilme,
        ilme_scale=ilme_scale,
    )

    best_path = one_best_decoding(lattice)
@ -428,6 +432,8 @@ def fast_beam_search(
    max_states: int,
    max_contexts: int,
    temperature: float = 1.0,
    subtract_ilme: bool = False,
    ilme_scale: float = 0.1,
) -> k2.Fsa:
    """It limits the maximum number of symbols per frame to 1.

@ -498,6 +504,17 @@ def fast_beam_search(
        )
        logits = logits.squeeze(1).squeeze(1)
        log_probs = (logits / temperature).log_softmax(dim=-1)
        if subtract_ilme:
            ilme_logits = model.joiner(
                torch.zeros_like(
                    current_encoder_out, device=current_encoder_out.device
                ).unsqueeze(2),
                decoder_out.unsqueeze(1),
                project_input=False,
            )
            ilme_logits = ilme_logits.squeeze(1).squeeze(1)
            ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
            log_probs -= ilme_scale * ilme_log_probs
        decoding_streams.advance(log_probs)
    decoding_streams.terminate_and_flush_to_streams()
    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
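A worked view of the ILME branch above: the joiner is run a second time with the encoder output zeroed, which approximates the internal language model that the decoder/joiner pair has learned, and with `T = temperature` each frame's scores become

    log_probs = log_softmax(logits / T) - ilme_scale * log_softmax(ilme_logits / T)

so a larger `ilme_scale` discounts the internal LM more strongly, which is typically useful when an external LM (such as the LG graph) contributes scores during decoding.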
@ -8,6 +8,12 @@ from . import (
    utils
)

from .byte_utils import (
    byte_decode,
    byte_encode,
    smart_byte_decode,
)

from .checkpoint import (
    average_checkpoints,
    find_checkpoints,
@ -49,6 +55,7 @@ from .utils import (
    get_alignments,
    get_executor,
    get_texts,
    is_cjk,
    is_jit_tracing,
    is_module_available,
    l1_norm,
@ -64,6 +71,7 @@ from .utils import (
    store_transcripts,
    str2bool,
    subsequent_chunk_mask,
    tokenize_by_CJK_char,
    write_error_stats,
)
311  icefall/byte_utils.py  Normal file
@ -0,0 +1,311 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py

import re
import unicodedata


WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)

PRINTABLE_BASE_CHARS = [
    256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271,
    272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287,
    32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
    48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
    64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
    80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
    96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
    112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
    288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303,
    304, 305,
    308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318,
    321, 322, 323, 324, 325, 326, 327, 328,
    330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345,
    346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361,
    362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377,
    378, 379, 380, 381, 382,
    384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399,
    400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415,
    416, 417, 418, 419, 420, 421, 422,
]

for c in PRINTABLE_BASE_CHARS:
    assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c

BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}


def byte_encode(x: str) -> str:
    normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
    return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])


def byte_decode(x: str) -> str:
    try:
        return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
    except ValueError:
        return ""


def smart_byte_decode(x: str) -> str:
    output = byte_decode(x)
    if output == "":
        # DP the best recovery (max valid chars) if it's broken
        n_bytes = len(x)
        f = [0 for _ in range(n_bytes + 1)]
        pt = [0 for _ in range(n_bytes + 1)]
        for i in range(1, n_bytes + 1):
            f[i], pt[i] = f[i - 1], i - 1
            for j in range(1, min(4, i) + 1):
                if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
                    f[i], pt[i] = f[i - j] + 1, i - j
        cur_pt = n_bytes
        while cur_pt > 0:
            if f[cur_pt] == f[pt[cur_pt]] + 1:
                output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
            cur_pt = pt[cur_pt]
    return output
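A small usage sketch of these helpers (the recovered string follows from the DP above; "你好世界" is four 3-byte UTF-8 characters):

```
from icefall import byte_decode, byte_encode, smart_byte_decode

s = "你好世界"
enc = byte_encode(s)          # each UTF-8 byte becomes one printable char
assert byte_decode(enc) == s  # lossless round trip

broken = enc[:-1]             # drop one byte of the last character
assert byte_decode(broken) == ""             # plain decode fails closed
assert smart_byte_decode(broken) == "你好世"  # DP recovers the decodable part
```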
@ -87,7 +87,7 @@ class RnnLmModel(torch.nn.Module):
        h0: torch.Tensor,
        c0: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        '''
        """
        Args:
          x:
            A 2-D tensor of shape (N, L). We won't prepend it with SOS.
@ -103,7 +103,7 @@ class RnnLmModel(torch.nn.Module):
          - negative loglike (nll), a 1-D tensor of shape (N,)
          - next_h0, a 3-D tensor with the same shape as h0
          - next_c0, a 3-D tensor with the same shape as c0
        '''
        """
        assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
        assert x.shape == y.shape, (x.shape, y.shape)
@ -1306,6 +1306,31 @@ def tokenize_by_bpe_model(
    return txt_with_bpe


def tokenize_by_CJK_char(line: str) -> str:
    """
    Tokenize a line of text with CJK chars.

    Note: All returned characters will be upper case.

    Example:
      input = "你好世界是 hello world 的中文"
      output = "你 好 世 界 是 HELLO WORLD 的 中 文"

    Args:
      line:
        The input text.

    Return:
      A new string tokenized by CJK chars.
    """
    # The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
    pattern = re.compile(
        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
    )
    chars = pattern.split(line.strip().upper())
    return " ".join([w.strip() for w in chars if w.strip()])


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
@ -1764,3 +1789,34 @@ def parse_fsa_timestamps_and_texts(
        utt_time_pairs.append(list(zip(start, end)))

    return utt_time_pairs, utt_words


# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
def is_cjk(character):
    """
    Python port of Moses' code to check for CJK character.

    >>> is_cjk(u'\u33fe')
    True
    >>> is_cjk(u'\uFE5F')
    False

    :param character: The character that needs to be checked.
    :type character: char
    :return: bool
    """
    return any(
        [
            start <= ord(character) <= end
            for start, end in [
                (4352, 4607),
                (11904, 42191),
                (43072, 43135),
                (44032, 55215),
                (63744, 64255),
                (65072, 65103),
                (65381, 65500),
                (131072, 196607),
            ]
        ]
    )