Training with byte level BPE (AIShell) (#986)

* copy files from zipformer librispeech

* Add byte bpe training for aishell

* compile LG graph

* Support LG decoding

* Minor fixes

* black

* Minor fixes

* export & fix pretrain.py

* fix black

* Update RESULTS.md

* Fix export.py
This commit is contained in:
Wei Kang 2023-05-04 19:16:17 +08:00 committed by GitHub
parent 61ec3a7a8f
commit 80156dda09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 3992 additions and 47 deletions

View File

@ -2,6 +2,57 @@
### Aishell training result(Stateless Transducer)
#### Pruned transducer stateless 7 (zipformer)
See <https://github.com/k2-fsa/icefall/pull/986>
[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe)
**Note**: The modeling units are byte level BPEs
The best results I have gotten are:
Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments
-- | -- | -- | -- | -- | --
500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29
The training command:
```
export CUDA_VISIBLE_DEVICES="4,5,6,7"
./pruned_transducer_stateless7_bbpe/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--max-duration 800 \
--bpe-model data/lang_bbpe_500/bbpe.model \
--exp-dir pruned_transducer_stateless7_bbpe/exp \
--lr-epochs 6 \
--master-port 12535
```
The decoding command:
```
for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 48 \
--avg 29 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-sym-per-frame 1 \
--ngram-lm-scale 0.25 \
--ilme-scale 0.2 \
--bpe-model data/lang_bbpe_500/bbpe.model \
--max-duration 2000 \
--decoding-method $m
done
```
The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
#### Pruned transducer stateless 3
See <https://github.com/k2-fsa/icefall/pull/436>
@ -75,7 +126,7 @@ for epoch in 29; do
done
```
We provide the option of shallow fusion with a RNN language model. The pre-trained language model is
We provide the option of shallow fusion with a RNN language model. The pre-trained language model is
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
please use the following command:

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py

View File

@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`:
- tokens.txt
"""
import argparse
import re
from pathlib import Path
from typing import Dict, List
@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
return tokens
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
""",
)
return parser.parse_args()
def main():
lang_dir = Path("data/lang_char")
args = get_args()
lang_dir = Path(args.lang_dir)
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

View File

@ -0,0 +1,267 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/bbpe.model,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import sentencepiece as spm
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
from icefall.byte_utils import byte_encode
from icefall.utils import str2bool, tokenize_by_CJK_char
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def generate_lexicon(
model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
"""Generate a lexicon from a BPE model.
Args:
model_file:
Path to a sentencepiece model.
words:
A list of strings representing words.
oov:
The out of vocabulary word in lexicon.
Returns:
Return a tuple with two elements:
- A dict whose keys are words and values are the corresponding
word pieces.
- A dict representing the token symbol, mapping from tokens to IDs.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
# Convert word to word piece IDs instead of word piece strings
# to avoid OOV tokens.
encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)
# Now convert word piece IDs back to word piece strings.
words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
lexicon = []
for word, pieces in zip(words, words_pieces):
lexicon.append((word, pieces))
lexicon.append((oov, ["", sp.id_to_piece(sp.unk_id())]))
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
return lexicon, token2id
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
""",
)
parser.add_argument(
"--oov",
type=str,
default="<UNK>",
help="The out of vocabulary word in lexicon.",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
See "test/test_bpe_lexicon.py" for usage.
""",
)
return parser.parse_args()
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
model_file = lang_dir / "bbpe.model"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,113 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import re
import shutil
import tempfile
from pathlib import Path
import sentencepiece as spm
from icefall import byte_encode, tokenize_by_CJK_char
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
return parser.parse_args()
def _convert_to_bchar(in_path: str, out_path: str):
with open(out_path, "w") as f:
for line in open(in_path, "r").readlines():
f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n")
def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
temp = tempfile.NamedTemporaryFile()
train_text = temp.name
_convert_to_bchar(args.transcript, train_text)
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
else:
print(f"{model_file} exists - skipping")
return
shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
if __name__ == "__main__":
main()

View File

@ -35,6 +35,15 @@ dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bbpe_xxx,
# data/lang_bbpe_yyy if the array contains xxx, yyy
vocab_sizes=(
# 2000
# 1000
500
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
@ -47,20 +56,6 @@ log() {
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM"
# We assume that you have installed the git-lfs, if not, you could install it
# using: `sudo apt-get install git-lfs && git-lfs install`
git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
pushd $dl_dir/lm
git lfs pull --include "3-gram.unpruned.arpa"
popd
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
@ -134,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
lang_phone_dir=data/lang_phone
lang_char_dir=data/lang_char
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
mkdir -p $lang_phone_dir
@ -183,45 +177,107 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
fi
lang_char_dir=data/lang_char
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
mkdir -p $lang_char_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp $lang_phone_dir/words.txt $lang_char_dir
# The transcripts in training set, generated in stage 5
cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text
cut -d " " -f 2- > $lang_char_dir/text
(echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
> $lang_char_dir/words.txt
cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
| awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
num_lines=$(< $lang_char_dir/words.txt wc -l)
(echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
>> $lang_char_dir/words.txt
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py
./local/prepare_char.py --lang-dir $lang_char_dir
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
log "Stage 7: Prepare Byte BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
mkdir -p $lang_dir
cp $lang_char_dir/words.txt $lang_dir
cp $lang_char_dir/text $lang_dir
if [ ! -f $lang_dir/bbpe.model ]; then
./local/train_bbpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/text
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bbpe.py --lang-dir $lang_dir
fi
done
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare G"
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# Train LM on transcripts
if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
python3 ./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_char_dir/transcript_words.txt \
-lm data/lm/3-gram.unpruned.arpa
fi
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="$lang_phone_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt
python3 -m kaldilm \
--read-symbol-table="$lang_char_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Compile HLG"
./local/compile_hlg.py --lang-dir $lang_phone_dir
./local/compile_hlg.py --lang-dir $lang_char_dir
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile LG & HLG"
./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
done
./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
done
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Generate LM training data"
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Generate LM training data"
log "Processing char based data"
out_dir=data/lm_training_char
@ -267,8 +323,8 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Sort LM training data"
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Sort LM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
@ -295,7 +351,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
--out-statistics $out_dir/statistics-test.txt
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 11: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,819 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import (
LmScorer,
NgramLm,
byte_encode,
smart_byte_decode,
tokenize_by_CJK_char,
)
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bbpe_500/bbpe.model",
help="Path to the byte BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bbpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
If you use fast_beam_search_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.25,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "fast_beam_search_LG":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
subtract_ilme=True,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
ref_texts = []
for tx in supervisions["text"]:
ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(ref_texts),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
key += f"_ilme_scale_{params.ilme_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network LM, used during shallow fusion
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
params.suffix += f"-ilme-scale-{params.ilme_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bbpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
dev_cuts = aishell.valid_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
dev_dl = aishell.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_dls = [test_dl, dev_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,320 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/aishell/ASR
./pruned_transducer_stateless7_bbpe/decode.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bbpe_500/bbpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
# You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_bbpe/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bbpe_500/bbpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bbpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,274 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 49 \
--avg 28 \
--jit 1
Usage of this script:
./pruned_transducer_stateless7_bbpe/jit_pretrained.py \
--nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall import smart_byte_decode
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
model: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features,
x_lens=feature_lengths,
)
hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = smart_byte_decode(sp.decode(hyp))
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1,345 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 48 \
--avg 29
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_bbpe/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./pruned_transducer_stateless7_bbpe/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./pruned_transducer_stateless7_bbpe/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by
./pruned_transducer_stateless7_bbpe/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall import smart_byte_decode
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py

View File

@ -21,7 +21,7 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import List
from typing import Any, Dict, List, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
@ -181,7 +181,16 @@ class AishellAsrDataModule:
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
@ -277,6 +286,10 @@ class AishellAsrDataModule:
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
train_dl = DataLoader(
train,
sampler=train_sampler,
@ -325,7 +338,7 @@ class AishellAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats

View File

@ -45,11 +45,18 @@ def get_args():
help="""Input and output directory.
""",
)
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str) -> k2.Fsa:
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
@ -61,15 +68,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
if Path(f"data/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"data/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
logging.info(f"Loading {lm}.fst.txt")
with open(f"data/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
torch.save(G.as_dict(), f"data/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
@ -126,7 +133,7 @@ def main():
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir)
LG = compile_LG(lang_dir, args.lm)
logging.info(f"Saving LG.pt to {lang_dir}")
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")

View File

@ -47,6 +47,8 @@ def fast_beam_search_one_best(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@ -88,6 +90,8 @@ def fast_beam_search_one_best(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale,
)
best_path = one_best_decoding(lattice)
@ -428,6 +432,8 @@ def fast_beam_search(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
@ -498,6 +504,17 @@ def fast_beam_search(
)
logits = logits.squeeze(1).squeeze(1)
log_probs = (logits / temperature).log_softmax(dim=-1)
if subtract_ilme:
ilme_logits = model.joiner(
torch.zeros_like(
current_encoder_out, device=current_encoder_out.device
).unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())

View File

@ -8,6 +8,12 @@ from . import (
utils
)
from .byte_utils import (
byte_decode,
byte_encode,
smart_byte_decode,
)
from .checkpoint import (
average_checkpoints,
find_checkpoints,
@ -49,6 +55,7 @@ from .utils import (
get_alignments,
get_executor,
get_texts,
is_cjk,
is_jit_tracing,
is_module_available,
l1_norm,
@ -64,6 +71,7 @@ from .utils import (
store_transcripts,
str2bool,
subsequent_chunk_mask,
tokenize_by_CJK_char,
write_error_stats,
)

311
icefall/byte_utils.py Normal file
View File

@ -0,0 +1,311 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py
import re
import unicodedata
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
PRINTABLE_BASE_CHARS = [
256,
257,
258,
259,
260,
261,
262,
263,
264,
265,
266,
267,
268,
269,
270,
271,
272,
273,
274,
275,
276,
277,
278,
279,
280,
281,
282,
283,
284,
285,
286,
287,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
66,
67,
68,
69,
70,
71,
72,
73,
74,
75,
76,
77,
78,
79,
80,
81,
82,
83,
84,
85,
86,
87,
88,
89,
90,
91,
92,
93,
94,
95,
96,
97,
98,
99,
100,
101,
102,
103,
104,
105,
106,
107,
108,
109,
110,
111,
112,
113,
114,
115,
116,
117,
118,
119,
120,
121,
122,
123,
124,
125,
126,
288,
289,
290,
291,
292,
293,
294,
295,
296,
297,
298,
299,
300,
301,
302,
303,
304,
305,
308,
309,
310,
311,
312,
313,
314,
315,
316,
317,
318,
321,
322,
323,
324,
325,
326,
327,
328,
330,
331,
332,
333,
334,
335,
336,
337,
338,
339,
340,
341,
342,
343,
344,
345,
346,
347,
348,
349,
350,
351,
352,
353,
354,
355,
356,
357,
358,
359,
360,
361,
362,
363,
364,
365,
366,
367,
368,
369,
370,
371,
372,
373,
374,
375,
376,
377,
378,
379,
380,
381,
382,
384,
385,
386,
387,
388,
389,
390,
391,
392,
393,
394,
395,
396,
397,
398,
399,
400,
401,
402,
403,
404,
405,
406,
407,
408,
409,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
]
for c in PRINTABLE_BASE_CHARS:
assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c
BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
def byte_encode(x: str) -> str:
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
def byte_decode(x: str) -> str:
try:
return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
except ValueError:
return ""
def smart_byte_decode(x: str) -> str:
output = byte_decode(x)
if output == "":
# DP the best recovery (max valid chars) if it's broken
n_bytes = len(x)
f = [0 for _ in range(n_bytes + 1)]
pt = [0 for _ in range(n_bytes + 1)]
for i in range(1, n_bytes + 1):
f[i], pt[i] = f[i - 1], i - 1
for j in range(1, min(4, i) + 1):
if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
f[i], pt[i] = f[i - j] + 1, i - j
cur_pt = n_bytes
while cur_pt > 0:
if f[cur_pt] == f[pt[cur_pt]] + 1:
output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
cur_pt = pt[cur_pt]
return output

View File

@ -87,7 +87,7 @@ class RnnLmModel(torch.nn.Module):
h0: torch.Tensor,
c0: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
'''
"""
Args:
x:
A 2-D tensor of shape (N, L). We won't prepend it with SOS.
@ -103,7 +103,7 @@ class RnnLmModel(torch.nn.Module):
- negative loglike (nll), a 1-D tensor of shape (N,)
- next_h0, a 3-D tensor with the same shape as h0
- next_c0, a 3-D tensor with the same shape as c0
'''
"""
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
assert x.shape == y.shape, (x.shape, y.shape)

View File

@ -1306,6 +1306,31 @@ def tokenize_by_bpe_model(
return txt_with_bpe
def tokenize_by_CJK_char(line: str) -> str:
"""
Tokenize a line of text with CJK char.
Note: All return charaters will be upper case.
Example:
input = "你好世界是 hello world 的中文"
output = "你 好 世 界 是 HELLO WORLD 的 中 文"
Args:
line:
The input text.
Return:
A new string tokenize by CJK char.
"""
# The CJK ranges is from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
pattern = re.compile(
r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
)
chars = pattern.split(line.strip().upper())
return " ".join([w.strip() for w in chars if w.strip()])
def display_and_save_batch(
batch: dict,
params: AttributeDict,
@ -1764,3 +1789,34 @@ def parse_fsa_timestamps_and_texts(
utt_time_pairs.append(list(zip(start, end)))
return utt_time_pairs, utt_words
# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
def is_cjk(character):
"""
Python port of Moses' code to check for CJK character.
>>> is_cjk(u'\u33fe')
True
>>> is_cjk(u'\uFE5F')
False
:param character: The character that needs to be checked.
:type character: char
:return: bool
"""
return any(
[
start <= ord(character) <= end
for start, end in [
(4352, 4607),
(11904, 42191),
(43072, 43135),
(44032, 55215),
(63744, 64255),
(65072, 65103),
(65381, 65500),
(131072, 196607),
]
]
)