mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-11 18:14:19 +00:00
Merge branch 'k2-fsa:master' into tiny
This commit is contained in:
commit
a67b673c98
@ -15,6 +15,8 @@ It uses pruned RNN-T.
|
|||||||
|------------------------|------|------|---------------------------------------|
|
|------------------------|------|------|---------------------------------------|
|
||||||
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
|
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
|
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
|
| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
|
| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
|
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
|
||||||
|
|
||||||
Training command is:
|
Training command is:
|
||||||
@ -73,6 +75,78 @@ for epoch in 29; do
|
|||||||
done
|
done
|
||||||
```
|
```
|
||||||
|
|
||||||
|
We provide the option of shallow fusion with a RNN language model. The pre-trained language model is
|
||||||
|
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
|
||||||
|
please use the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# download pre-trained model
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
|
||||||
|
|
||||||
|
aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/
|
||||||
|
|
||||||
|
pushd ${aishell_exp}/exp
|
||||||
|
ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
# download RNN LM
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm
|
||||||
|
rnnlm_dir=icefall-aishell-rnn-lm
|
||||||
|
|
||||||
|
# RNNLM shallow fusion
|
||||||
|
for lm_scale in $(seq 0.26 0.02 0.34); do
|
||||||
|
python ./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--lang-dir ${aishell_exp}/data/lang_char \
|
||||||
|
--exp-dir ${aishell_exp}/exp \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
|
--use-shallow-fusion 1 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-exp-dir ${rnnlm_dir}/exp \
|
||||||
|
--lm-epoch 99 \
|
||||||
|
--lm-scale $lm_scale \
|
||||||
|
--lm-avg 1 \
|
||||||
|
--rnn-lm-embedding-dim 2048 \
|
||||||
|
--rnn-lm-hidden-dim 2048 \
|
||||||
|
--rnn-lm-num-layers 2 \
|
||||||
|
--lm-vocab-size 4336
|
||||||
|
done
|
||||||
|
|
||||||
|
# RNNLM Low-order density ratio (LODR) with a 2-gram
|
||||||
|
|
||||||
|
cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt
|
||||||
|
|
||||||
|
for lm_scale in 0.48; do
|
||||||
|
for LODR_scale in -0.28; do
|
||||||
|
python ./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--lang-dir ${aishell_exp}/data/lang_char \
|
||||||
|
--exp-dir ${aishell_exp}/exp \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--decoding-method modified_beam_search_LODR \
|
||||||
|
--use-shallow-fusion 1 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-exp-dir ${rnnlm_dir}/exp \
|
||||||
|
--lm-epoch 99 \
|
||||||
|
--lm-scale $lm_scale \
|
||||||
|
--lm-avg 1 \
|
||||||
|
--rnn-lm-embedding-dim 2048 \
|
||||||
|
--rnn-lm-hidden-dim 2048 \
|
||||||
|
--rnn-lm-num-layers 2 \
|
||||||
|
--lm-vocab-size 4336 \
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--backoff-id 4336 \
|
||||||
|
--ngram-lm-scale $LODR_scale
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
Pretrained models, training logs, decoding logs, and decoding results
|
Pretrained models, training logs, decoding logs, and decoding results
|
||||||
are available at
|
are available at
|
||||||
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
|
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
|
||||||
|
0
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file → Executable file
0
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file → Executable file
1
egs/aishell/ASR/local/sort_lm_training_data.py
Symbolic link
1
egs/aishell/ASR/local/sort_lm_training_data.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/local/sort_lm_training_data.py
|
@ -231,11 +231,13 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
|||||||
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
|
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# training words
|
||||||
./local/prepare_char_lm_training_data.py \
|
./local/prepare_char_lm_training_data.py \
|
||||||
--lang-char data/lang_char \
|
--lang-char data/lang_char \
|
||||||
--lm-data $dl_dir/lm/aishell-train-word.txt \
|
--lm-data $dl_dir/lm/aishell-train-word.txt \
|
||||||
--lm-archive $out_dir/lm_data.pt
|
--lm-archive $out_dir/lm_data.pt
|
||||||
|
|
||||||
|
# valid words
|
||||||
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
|
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
|
||||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||||
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
|
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
|
||||||
@ -249,6 +251,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
|||||||
--lm-data $dl_dir/lm/aishell-valid-word.txt \
|
--lm-data $dl_dir/lm/aishell-valid-word.txt \
|
||||||
--lm-archive $out_dir/lm_data_valid.pt
|
--lm-archive $out_dir/lm_data_valid.pt
|
||||||
|
|
||||||
|
# test words
|
||||||
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
|
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
|
||||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||||
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
|
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
|
||||||
@ -303,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
|||||||
--hidden-dim 512 \
|
--hidden-dim 512 \
|
||||||
--num-layers 2 \
|
--num-layers 2 \
|
||||||
--batch-size 400 \
|
--batch-size 400 \
|
||||||
--exp-dir rnnlm_char/exp \
|
--exp-dir rnnlm_char/exp_aishell1_small \
|
||||||
--lm-data data/lm_training_char/sorted_lm_data.pt \
|
--lm-data data/lm_char/sorted_lm_data_aishell1.pt \
|
||||||
--lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
|
--lm-data-valid data/lm_char/sorted_lm_data_valid.pt \
|
||||||
--vocab-size 4336 \
|
--vocab-size 4336 \
|
||||||
--master-port 12345
|
--master-port 12345
|
||||||
fi
|
fi
|
||||||
|
@ -54,6 +54,40 @@ Usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) modified beam search (with LM shallow fusion)
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
|
--beam-size 4 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-scale 0.3 \
|
||||||
|
--lm-exp-dir /path/to/LM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1
|
||||||
|
|
||||||
|
(6) modified beam search with LM shallow fusion + LODR
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--decoding-method modified_beam_search_LODR \
|
||||||
|
--beam-size 4 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-scale 0.48 \
|
||||||
|
--lm-exp-dir /path/to/LM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--ngram-lm-scale -0.28 \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -74,9 +108,12 @@ from beam_search import (
|
|||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
|
modified_beam_search_lm_shallow_fusion,
|
||||||
|
modified_beam_search_LODR,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -212,6 +249,60 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-shallow-fusion",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Use neural network LM for shallow fusion.
|
||||||
|
If you want to use LODR, you will also need to set this to true
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens-ngram",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="""Token Ngram used for rescoring.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--backoff-id",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""ID of the backoff symbol.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -223,6 +314,9 @@ def decode_one_batch(
|
|||||||
token_table: k2.SymbolTable,
|
token_table: k2.SymbolTable,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -287,6 +381,24 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||||
|
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search_LODR":
|
||||||
|
hyp_tokens = modified_beam_search_LODR(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LODR_lm=ngram_lm,
|
||||||
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hyp_tokens = []
|
hyp_tokens = []
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
@ -334,6 +446,9 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
token_table: k2.SymbolTable,
|
token_table: k2.SymbolTable,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -379,6 +494,9 @@ def decode_dataset(
|
|||||||
token_table=token_table,
|
token_table=token_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -445,6 +563,7 @@ def save_results(
|
|||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
AsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
|
LmScorer.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
args.lang_dir = Path(args.lang_dir)
|
||||||
@ -458,6 +577,8 @@ def main():
|
|||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_LODR",
|
||||||
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
@ -479,6 +600,19 @@ def main():
|
|||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
if "ngram" in params.decoding_method:
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
if params.lm_type == "rnn":
|
||||||
|
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
|
||||||
|
elif params.lm_type == "transformer":
|
||||||
|
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -588,6 +722,35 @@ def main():
|
|||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
|
||||||
|
# only load N-gram LM when needed
|
||||||
|
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||||
|
lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"
|
||||||
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
|
ngram_lm = NgramLm(
|
||||||
|
lm_filename,
|
||||||
|
backoff_id=params.backoff_id,
|
||||||
|
is_binary=False,
|
||||||
|
)
|
||||||
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
ngram_lm_scale = params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
ngram_lm = None
|
||||||
|
ngram_lm_scale = None
|
||||||
|
|
||||||
|
# only load the neural network LM if doing shallow fusion
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -610,6 +773,9 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
token_table=lexicon.token_table,
|
token_table=lexicon.token_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -1 +0,0 @@
|
|||||||
/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
|
|
@ -48,6 +48,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -244,6 +245,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
1
egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless3/lstmp.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
107
egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py
Executable file
107
egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py
Executable file
@ -0,0 +1,107 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file computes fbank features of the CommonVoice dataset.
|
||||||
|
It looks for manifests in the directory data/${lang}/manifests.
|
||||||
|
|
||||||
|
The generated fbank features are saved in data/${lang}/fbank.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from filter_cuts import filter_cuts
|
||||||
|
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||||
|
|
||||||
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
# it wastes a lot of CPU and slow things down.
|
||||||
|
# Do this outside of main() in case it needs to take effect
|
||||||
|
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
type=str,
|
||||||
|
help="""Language of Common Voice""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_fbank_commonvoice_dev_test(language: str):
|
||||||
|
src_dir = Path(f"data/{language}/manifests")
|
||||||
|
output_dir = Path(f"data/{language}/fbank")
|
||||||
|
num_workers = 42
|
||||||
|
batch_duration = 600
|
||||||
|
|
||||||
|
subsets = ("dev", "test")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
for partition in subsets:
|
||||||
|
cuts_path = output_dir / f"cv-{language}_cuts_{partition}.jsonl.gz"
|
||||||
|
if cuts_path.is_file():
|
||||||
|
logging.info(f"{partition} already exists - skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_cuts_path = output_dir / f"cv-{language}_cuts_{partition}_raw.jsonl.gz"
|
||||||
|
|
||||||
|
logging.info(f"Loading {raw_cuts_path}")
|
||||||
|
cut_set = CutSet.from_file(raw_cuts_path)
|
||||||
|
|
||||||
|
logging.info("Splitting cuts into smaller chunks")
|
||||||
|
cut_set = cut_set.trim_to_supervisions(
|
||||||
|
keep_overlapping=False, min_duration=None
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Computing features")
|
||||||
|
cut_set = cut_set.compute_and_store_features_batch(
|
||||||
|
extractor=extractor,
|
||||||
|
storage_path=f"{output_dir}/cv-{language}_feats_{partition}",
|
||||||
|
num_workers=num_workers,
|
||||||
|
batch_duration=batch_duration,
|
||||||
|
storage_type=LilcomChunkyWriter,
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"Saving to {cuts_path}")
|
||||||
|
cut_set.to_file(cuts_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
args = get_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
compute_fbank_commonvoice_dev_test(language=args.language)
|
157
egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
Executable file
157
egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
Executable file
@ -0,0 +1,157 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
KaldifeatFbank,
|
||||||
|
KaldifeatFbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
set_audio_duration_mismatch_tolerance,
|
||||||
|
set_caching_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
# it wastes a lot of CPU and slow things down.
|
||||||
|
# Do this outside of main() in case it needs to take effect
|
||||||
|
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
type=str,
|
||||||
|
help="""Language of Common Voice""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="Number of dataloading workers used for reading the audio.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-duration",
|
||||||
|
type=float,
|
||||||
|
default=600.0,
|
||||||
|
help="The maximum number of audio seconds in a batch."
|
||||||
|
"Determines batch size dynamically.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-splits",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="The number of splits of the train subset",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--start",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Process pieces starting from this number (inclusive).",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--stop",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help="Stop processing pieces until this number (exclusive).",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_fbank_commonvoice_splits(args):
|
||||||
|
subset = "train"
|
||||||
|
num_splits = args.num_splits
|
||||||
|
language = args.language
|
||||||
|
output_dir = f"data/{language}/fbank/{subset}_split_{num_splits}"
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
assert output_dir.exists(), f"{output_dir} does not exist!"
|
||||||
|
|
||||||
|
num_digits = len(str(num_splits))
|
||||||
|
|
||||||
|
start = args.start
|
||||||
|
stop = args.stop
|
||||||
|
if stop < start:
|
||||||
|
stop = num_splits
|
||||||
|
|
||||||
|
stop = min(stop, num_splits)
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||||
|
set_caching_enabled(False)
|
||||||
|
for i in range(start, stop):
|
||||||
|
idx = f"{i + 1}".zfill(num_digits)
|
||||||
|
logging.info(f"Processing {idx}/{num_splits}")
|
||||||
|
|
||||||
|
cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz"
|
||||||
|
if cuts_path.is_file():
|
||||||
|
logging.info(f"{cuts_path} exists - skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_cuts_path = output_dir / f"cv-{language}_cuts_{subset}_raw.{idx}.jsonl.gz"
|
||||||
|
|
||||||
|
logging.info(f"Loading {raw_cuts_path}")
|
||||||
|
cut_set = CutSet.from_file(raw_cuts_path)
|
||||||
|
|
||||||
|
logging.info("Splitting cuts into smaller chunks.")
|
||||||
|
cut_set = cut_set.trim_to_supervisions(
|
||||||
|
keep_overlapping=False, min_duration=None
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Computing features")
|
||||||
|
cut_set = cut_set.compute_and_store_features_batch(
|
||||||
|
extractor=extractor,
|
||||||
|
storage_path=f"{output_dir}/cv-{language}_feats_{subset}_{idx}",
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
batch_duration=args.batch_duration,
|
||||||
|
storage_type=LilcomChunkyWriter,
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"Saving to {cuts_path}")
|
||||||
|
cut_set.to_file(cuts_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
args = get_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
compute_fbank_commonvoice_splits(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
egs/commonvoice/ASR/local/compute_fbank_musan.py
Symbolic link
1
egs/commonvoice/ASR/local/compute_fbank_musan.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/local/compute_fbank_musan.py
|
1
egs/commonvoice/ASR/local/filter_cuts.py
Symbolic link
1
egs/commonvoice/ASR/local/filter_cuts.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/local/filter_cuts.py
|
119
egs/commonvoice/ASR/local/preprocess_commonvoice.py
Executable file
119
egs/commonvoice/ASR/local/preprocess_commonvoice.py
Executable file
@ -0,0 +1,119 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from lhotse import CutSet, SupervisionSegment
|
||||||
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
help="""Dataset parts to compute fbank. If None, we will use all""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
type=str,
|
||||||
|
help="""Language of Common Voice""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_commonvoice(
|
||||||
|
language: str,
|
||||||
|
dataset: Optional[str] = None,
|
||||||
|
):
|
||||||
|
src_dir = Path(f"data/{language}/manifests")
|
||||||
|
output_dir = Path(f"data/{language}/fbank")
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
if dataset is None:
|
||||||
|
dataset_parts = (
|
||||||
|
"dev",
|
||||||
|
"test",
|
||||||
|
"train",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dataset_parts = dataset.split(" ", -1)
|
||||||
|
|
||||||
|
logging.info("Loading manifest")
|
||||||
|
prefix = f"cv-{language}"
|
||||||
|
suffix = "jsonl.gz"
|
||||||
|
manifests = read_manifests_if_cached(
|
||||||
|
dataset_parts=dataset_parts,
|
||||||
|
output_dir=src_dir,
|
||||||
|
suffix=suffix,
|
||||||
|
prefix=prefix,
|
||||||
|
)
|
||||||
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
|
for partition, m in manifests.items():
|
||||||
|
logging.info(f"Processing {partition}")
|
||||||
|
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
|
||||||
|
if raw_cuts_path.is_file():
|
||||||
|
logging.info(f"{partition} already exists - skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create long-recording cut manifests.
|
||||||
|
cut_set = CutSet.from_manifests(
|
||||||
|
recordings=m["recordings"],
|
||||||
|
supervisions=m["supervisions"],
|
||||||
|
).resample(16000)
|
||||||
|
|
||||||
|
# Run data augmentation that needs to be done in the
|
||||||
|
# time domain.
|
||||||
|
if "train" in partition:
|
||||||
|
logging.info(
|
||||||
|
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
|
||||||
|
"(Perturbing may take 2 minutes and saving may take 7 minutes)"
|
||||||
|
)
|
||||||
|
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
|
logging.info(f"Saving to {raw_cuts_path}")
|
||||||
|
cut_set.to_file(raw_cuts_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
args = get_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
preprocess_commonvoice(
|
||||||
|
language=args.language,
|
||||||
|
dataset=args.dataset,
|
||||||
|
)
|
||||||
|
logging.info("Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
156
egs/commonvoice/ASR/prepare.sh
Executable file
156
egs/commonvoice/ASR/prepare.sh
Executable file
@ -0,0 +1,156 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -eou pipefail
|
||||||
|
|
||||||
|
nj=16
|
||||||
|
stage=-1
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
# Split data/${lang}set to this number of pieces
|
||||||
|
# This is to avoid OOM during feature extraction.
|
||||||
|
num_splits=1000
|
||||||
|
|
||||||
|
# We assume dl_dir (download dir) contains the following
|
||||||
|
# directories and files. If not, they will be downloaded
|
||||||
|
# by this script automatically.
|
||||||
|
#
|
||||||
|
# - $dl_dir/$release/$lang
|
||||||
|
# This directory contains the following files downloaded from
|
||||||
|
# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz
|
||||||
|
#
|
||||||
|
# - clips
|
||||||
|
# - dev.tsv
|
||||||
|
# - invalidated.tsv
|
||||||
|
# - other.tsv
|
||||||
|
# - reported.tsv
|
||||||
|
# - test.tsv
|
||||||
|
# - train.tsv
|
||||||
|
# - validated.tsv
|
||||||
|
#
|
||||||
|
# - $dl_dir/musan
|
||||||
|
# This directory contains the following directories downloaded from
|
||||||
|
# http://www.openslr.org/17/
|
||||||
|
#
|
||||||
|
# - music
|
||||||
|
# - noise
|
||||||
|
# - speech
|
||||||
|
|
||||||
|
dl_dir=$PWD/download
|
||||||
|
release=cv-corpus-13.0-2023-03-09
|
||||||
|
lang=en
|
||||||
|
|
||||||
|
. shared/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
# vocab size for sentence piece models.
|
||||||
|
# It will generate data/${lang}/lang_bpe_xxx,
|
||||||
|
# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy
|
||||||
|
vocab_sizes=(
|
||||||
|
# 5000
|
||||||
|
# 2000
|
||||||
|
# 1000
|
||||||
|
500
|
||||||
|
)
|
||||||
|
|
||||||
|
# All files generated by this script are saved in "data/${lang}".
|
||||||
|
# You can safely remove "data/${lang}" and rerun this script to regenerate it.
|
||||||
|
mkdir -p data/${lang}
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
log "dl_dir: $dl_dir"
|
||||||
|
|
||||||
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
|
log "Stage 0: Download data"
|
||||||
|
|
||||||
|
# If you have pre-downloaded it to /path/to/$release,
|
||||||
|
# you can create a symlink
|
||||||
|
#
|
||||||
|
# ln -sfv /path/to/$release $dl_dir/$release
|
||||||
|
#
|
||||||
|
if [ ! -d $dl_dir/$release/$lang/clips ]; then
|
||||||
|
lhotse download commonvoice --languages $lang --release $release $dl_dir
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If you have pre-downloaded it to /path/to/musan,
|
||||||
|
# you can create a symlink
|
||||||
|
#
|
||||||
|
# ln -sfv /path/to/musan $dl_dir/
|
||||||
|
#
|
||||||
|
if [ ! -d $dl_dir/musan ]; then
|
||||||
|
lhotse download musan $dl_dir
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
|
log "Stage 1: Prepare CommonVoice manifest"
|
||||||
|
# We assume that you have downloaded the CommonVoice corpus
|
||||||
|
# to $dl_dir/$release
|
||||||
|
mkdir -p data/${lang}/manifests
|
||||||
|
if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then
|
||||||
|
lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests
|
||||||
|
touch data/${lang}/manifests/.cv-${lang}.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
|
log "Stage 2: Prepare musan manifest"
|
||||||
|
# We assume that you have downloaded the musan corpus
|
||||||
|
# to data/musan
|
||||||
|
mkdir -p data/manifests
|
||||||
|
if [ ! -e data/manifests/.musan.done ]; then
|
||||||
|
lhotse prepare musan $dl_dir/musan data/manifests
|
||||||
|
touch data/manifests/.musan.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
|
log "Stage 3: Preprocess CommonVoice manifest"
|
||||||
|
if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then
|
||||||
|
./local/preprocess_commonvoice.py --language $lang
|
||||||
|
touch data/${lang}/fbank/.preprocess_complete
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
|
log "Stage 4: Compute fbank for dev and test subsets of CommonVoice"
|
||||||
|
mkdir -p data/${lang}/fbank
|
||||||
|
if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then
|
||||||
|
./local/compute_fbank_commonvoice_dev_test.py --language $lang
|
||||||
|
touch data/${lang}/fbank/.cv-${lang}_dev_test.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
|
log "Stage 5: Split train subset into ${num_splits} pieces"
|
||||||
|
split_dir=data/${lang}/fbank/train_split_${num_splits}
|
||||||
|
if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then
|
||||||
|
lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir
|
||||||
|
touch $split_dir/.cv-${lang}_train_split.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||||
|
log "Stage 6: Compute features for train subset of CommonVoice"
|
||||||
|
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
|
||||||
|
./local/compute_fbank_commonvoice_splits.py \
|
||||||
|
--num-workers $nj \
|
||||||
|
--batch-duration 600 \
|
||||||
|
--start 0 \
|
||||||
|
--num-splits $num_splits \
|
||||||
|
--language $lang
|
||||||
|
touch data/${lang}/fbank/.cv-${lang}_train.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||||
|
log "Stage 7: Compute fbank for musan"
|
||||||
|
mkdir -p data/fbank
|
||||||
|
if [ ! -e data/fbank/.musan.done ]; then
|
||||||
|
./local/compute_fbank_musan.py
|
||||||
|
touch data/fbank/.musan.done
|
||||||
|
fi
|
||||||
|
fi
|
1
egs/commonvoice/ASR/shared
Symbolic link
1
egs/commonvoice/ASR/shared
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../icefall/shared/
|
1
egs/gigaspeech/ASR/local/generate_unique_lexicon.py
Symbolic link
1
egs/gigaspeech/ASR/local/generate_unique_lexicon.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/local/generate_unique_lexicon.py
|
@ -19,40 +19,40 @@
|
|||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -76,12 +76,17 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from gigaspeech_scoring import asr_text_post_processing
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
from icefall.checkpoint import (
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,9 +99,9 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=29,
|
default=30,
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
Note: Epoch counts from 0.
|
Note: Epoch counts from 1.
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -119,6 +124,17 @@ def get_parser():
|
|||||||
"'--epoch' and '--iter'",
|
"'--epoch' and '--iter'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -464,6 +480,9 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -476,7 +495,7 @@ def main():
|
|||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
@ -486,37 +505,85 @@ def main():
|
|||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
if params.iter > 0:
|
if not params.use_averaged_model:
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
if params.iter > 0:
|
||||||
: params.avg
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
]
|
: params.avg
|
||||||
if len(filenames) == 0:
|
]
|
||||||
raise ValueError(
|
if len(filenames) == 0:
|
||||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
raise ValueError(
|
||||||
)
|
f"No checkpoints found for"
|
||||||
elif len(filenames) < params.avg:
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
raise ValueError(
|
)
|
||||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
elif len(filenames) < params.avg:
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
raise ValueError(
|
||||||
)
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
logging.info(f"averaging {filenames}")
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
model.to(device)
|
)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
logging.info(f"averaging {filenames}")
|
||||||
elif params.avg == 1:
|
model.to(device)
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
if params.iter > 0:
|
||||||
filenames = []
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
for i in range(start, params.epoch + 1):
|
: params.avg + 1
|
||||||
if start >= 0:
|
]
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
if len(filenames) == 0:
|
||||||
logging.info(f"averaging {filenames}")
|
raise ValueError(
|
||||||
model.to(device)
|
f"No checkpoints found for"
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
@ -19,31 +19,30 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
|
||||||
./pruned_transducer_stateless2/train.py \
|
./pruned_transducer_stateless2/train.py \
|
||||||
--world-size 4 \
|
--world-size 8 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--exp-dir pruned_transducer_stateless2/exp \
|
--exp-dir pruned_transducer_stateless2/exp \
|
||||||
--full-libri 1 \
|
--max-duration 120
|
||||||
--max-duration 300
|
|
||||||
|
|
||||||
# For mix precision training:
|
# For mix precision training:
|
||||||
|
|
||||||
./pruned_transducer_stateless2/train.py \
|
./pruned_transducer_stateless2/train.py \
|
||||||
--world-size 4 \
|
--world-size 8 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--use_fp16 1 \
|
--use_fp16 1 \
|
||||||
--exp-dir pruned_transducer_stateless2/exp \
|
--exp-dir pruned_transducer_stateless2/exp \
|
||||||
--full-libri 1 \
|
--max-duration 200
|
||||||
--max-duration 550
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -72,7 +71,10 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
from icefall.checkpoint import (
|
||||||
|
save_checkpoint_with_global_batch_idx,
|
||||||
|
update_averaged_model,
|
||||||
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -116,10 +118,10 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--start-epoch",
|
"--start-epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=1,
|
||||||
help="""Resume training from from this epoch.
|
help="""Resume training from this epoch.
|
||||||
If it is positive, it will load checkpoint from
|
If larger than 1, it will load checkpoint from
|
||||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
exp-dir/epoch-{start_epoch-1}.pt
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -242,7 +244,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--keep-last-k",
|
"--keep-last-k",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=30,
|
||||||
help="""Only keep this number of checkpoints on disk.
|
help="""Only keep this number of checkpoints on disk.
|
||||||
For instance, if it is 3, there are only 3 checkpoints
|
For instance, if it is 3, there are only 3 checkpoints
|
||||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||||
@ -250,6 +252,19 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--average-period",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Update the averaged model, namely `model_avg`, after processing
|
||||||
|
this number of batches. `model_avg` is a separate version of model,
|
||||||
|
in which each floating-point parameter is the average of all the
|
||||||
|
parameters from the start of training. Each time we take the average,
|
||||||
|
we do: `model_avg = model * (average_period / batch_idx_train) +
|
||||||
|
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-fp16",
|
"--use-fp16",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -387,6 +402,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
def load_checkpoint_if_available(
|
def load_checkpoint_if_available(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
model_avg: nn.Module = None,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
scheduler: Optional[LRSchedulerType] = None,
|
scheduler: Optional[LRSchedulerType] = None,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
@ -394,7 +410,7 @@ def load_checkpoint_if_available(
|
|||||||
|
|
||||||
If params.start_batch is positive, it will load the checkpoint from
|
If params.start_batch is positive, it will load the checkpoint from
|
||||||
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
||||||
params.start_epoch is positive, it will load the checkpoint from
|
params.start_epoch is larger than 1, it will load the checkpoint from
|
||||||
`params.start_epoch - 1`.
|
`params.start_epoch - 1`.
|
||||||
|
|
||||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||||
@ -406,6 +422,8 @@ def load_checkpoint_if_available(
|
|||||||
The return value of :func:`get_params`.
|
The return value of :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The training model.
|
The training model.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
optimizer:
|
optimizer:
|
||||||
The optimizer that we are using.
|
The optimizer that we are using.
|
||||||
scheduler:
|
scheduler:
|
||||||
@ -415,7 +433,7 @@ def load_checkpoint_if_available(
|
|||||||
"""
|
"""
|
||||||
if params.start_batch > 0:
|
if params.start_batch > 0:
|
||||||
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
||||||
elif params.start_epoch > 0:
|
elif params.start_epoch > 1:
|
||||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@ -425,6 +443,7 @@ def load_checkpoint_if_available(
|
|||||||
saved_params = load_checkpoint(
|
saved_params = load_checkpoint(
|
||||||
filename,
|
filename,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
@ -451,7 +470,8 @@ def load_checkpoint_if_available(
|
|||||||
|
|
||||||
def save_checkpoint(
|
def save_checkpoint(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
|
model_avg: Optional[nn.Module] = None,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
scheduler: Optional[LRSchedulerType] = None,
|
scheduler: Optional[LRSchedulerType] = None,
|
||||||
sampler: Optional[CutSampler] = None,
|
sampler: Optional[CutSampler] = None,
|
||||||
@ -465,6 +485,8 @@ def save_checkpoint(
|
|||||||
It is returned by :func:`get_params`.
|
It is returned by :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The training model.
|
The training model.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
optimizer:
|
optimizer:
|
||||||
The optimizer used in the training.
|
The optimizer used in the training.
|
||||||
sampler:
|
sampler:
|
||||||
@ -478,6 +500,7 @@ def save_checkpoint(
|
|||||||
save_checkpoint_impl(
|
save_checkpoint_impl(
|
||||||
filename=filename,
|
filename=filename,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
params=params,
|
params=params,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -497,14 +520,14 @@ def save_checkpoint(
|
|||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
) -> Tuple[Tensor, MetricsTracker]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute transducer loss given the model and its inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params:
|
params:
|
||||||
@ -570,7 +593,7 @@ def compute_loss(
|
|||||||
|
|
||||||
def compute_validation_loss(
|
def compute_validation_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
@ -604,13 +627,14 @@ def compute_validation_loss(
|
|||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
scheduler: LRSchedulerType,
|
scheduler: LRSchedulerType,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
scaler: GradScaler,
|
scaler: GradScaler,
|
||||||
|
model_avg: Optional[nn.Module] = None,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
@ -636,6 +660,8 @@ def train_one_epoch(
|
|||||||
Dataloader for the validation dataset.
|
Dataloader for the validation dataset.
|
||||||
scaler:
|
scaler:
|
||||||
The scaler used for mix precision training.
|
The scaler used for mix precision training.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
tb_writer:
|
tb_writer:
|
||||||
Writer to write log messages to tensorboard.
|
Writer to write log messages to tensorboard.
|
||||||
world_size:
|
world_size:
|
||||||
@ -662,6 +688,7 @@ def train_one_epoch(
|
|||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
@ -690,6 +717,7 @@ def train_one_epoch(
|
|||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir,
|
||||||
global_batch_idx=params.batch_idx_train,
|
global_batch_idx=params.batch_idx_train,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
params=params,
|
params=params,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -793,7 +821,16 @@ def run(rank, world_size, args):
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
assert params.save_every_n >= params.average_period
|
||||||
|
model_avg: Optional[nn.Module] = None
|
||||||
|
if rank == 0:
|
||||||
|
# model_avg is only used with rank 0
|
||||||
|
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||||
|
|
||||||
|
assert params.start_epoch > 0, params.start_epoch
|
||||||
|
checkpoints = load_checkpoint_if_available(
|
||||||
|
params=params, model=model, model_avg=model_avg
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -852,10 +889,10 @@ def run(rank, world_size, args):
|
|||||||
logging.info("Loading grad scaler state dict")
|
logging.info("Loading grad scaler state dict")
|
||||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||||
scheduler.step_epoch(epoch)
|
scheduler.step_epoch(epoch - 1)
|
||||||
fix_random_seed(params.seed + epoch)
|
fix_random_seed(params.seed + epoch - 1)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch - 1)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
@ -865,6 +902,7 @@ def run(rank, world_size, args):
|
|||||||
train_one_epoch(
|
train_one_epoch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
@ -883,6 +921,7 @@ def run(rank, world_size, args):
|
|||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
sampler=train_dl.sampler,
|
sampler=train_dl.sampler,
|
||||||
@ -898,7 +937,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
|
|
||||||
def scan_pessimistic_batches_for_oom(
|
def scan_pessimistic_batches_for_oom(
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
|
@ -121,10 +121,10 @@ def compute_fbank_librispeech(
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if bpe_model:
|
|
||||||
cut_set = filter_cuts(cut_set, sp)
|
|
||||||
|
|
||||||
if "train" in partition:
|
if "train" in partition:
|
||||||
|
if bpe_model:
|
||||||
|
cut_set = filter_cuts(cut_set, sp)
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
|
@ -550,7 +550,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -561,7 +560,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
632
egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
Executable file
632
egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
Executable file
@ -0,0 +1,632 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lexicon.txt"
|
||||||
|
git lfs pull --include "data/L.pt"
|
||||||
|
git lfs pull --include "exp/epoch-11.pt"
|
||||||
|
git lfs pull --include "exp/epoch-10.pt"
|
||||||
|
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export-onnx-zh.py \
|
||||||
|
--lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \
|
||||||
|
--use-averaged-model 1 \
|
||||||
|
--epoch 11 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/exp \
|
||||||
|
--num-encoder-layers 12 \
|
||||||
|
--encoder-dim 512 \
|
||||||
|
--rnn-hidden-size 1024
|
||||||
|
|
||||||
|
It will generate the following files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-11-avg-1.onnx
|
||||||
|
- decoder-epoch-11-avg-1.onnx
|
||||||
|
- joiner-epoch-11-avg-1.onnx
|
||||||
|
- encoder-epoch-11-avg-1.int8.onnx
|
||||||
|
- decoder-epoch-11-avg-1.int8.onnx
|
||||||
|
- joiner-epoch-11-avg-1.int8.onnx
|
||||||
|
|
||||||
|
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||||
|
use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from decoder import Decoder
|
||||||
|
from lstm import RNN
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless5/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for RNN and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, encoder: RNN, encoder_proj: nn.Linear):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
An RNN encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of RNN.forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
states:
|
||||||
|
A tuple of 2 tensors (optional). It is for streaming inference.
|
||||||
|
states[0] is the hidden states of all layers,
|
||||||
|
with shape of (num_layers, N, d_model);
|
||||||
|
states[1] is the cell states of all layers,
|
||||||
|
with shape of (num_layers, N, rnn_hidden_size).
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||||
|
- updated states, whose shape is the same as the input states.
|
||||||
|
"""
|
||||||
|
N = x.size(0)
|
||||||
|
T = x.size(1)
|
||||||
|
x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device)
|
||||||
|
encoder_out, _, next_states = self.encoder(x, x_lens, states)
|
||||||
|
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return encoder_out, next_states
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has the following inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- state0, a tensor of shape (num_encoder_layers, batch_size, d_model)
|
||||||
|
- state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size)
|
||||||
|
|
||||||
|
and it has 3 outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||||
|
- new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model)
|
||||||
|
- new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
num_encoder_layers = encoder_model.encoder.num_encoder_layers
|
||||||
|
d_model = encoder_model.encoder.d_model
|
||||||
|
rnn_hidden_size = encoder_model.encoder.rnn_hidden_size
|
||||||
|
|
||||||
|
decode_chunk_len = 4
|
||||||
|
T = 9
|
||||||
|
|
||||||
|
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||||
|
states = encoder_model.encoder.get_init_states()
|
||||||
|
# state0: (num_encoder_layers, batch_size, d_model)
|
||||||
|
# state1: (num_encoder_layers, batch_size, rnn_hidden_size)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, states),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["x", "state0", "state1"],
|
||||||
|
output_names=["encoder_out", "new_state0", "new_state1"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N", 1: "T"},
|
||||||
|
"state0": {1: "N"},
|
||||||
|
"state1": {1: "N"},
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"new_state0": {1: "N"},
|
||||||
|
"new_state1": {1: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "lstm",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||||
|
"T": str(T), # 39
|
||||||
|
"num_encoder_layers": str(num_encoder_layers),
|
||||||
|
"d_model": str(d_model),
|
||||||
|
"rnn_hidden_size": str(rnn_hidden_size),
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params, enable_giga=False)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
@ -34,11 +34,14 @@ popd
|
|||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp
|
--exp-dir $repo/exp
|
||||||
|
|
||||||
It will generate the following 3 files inside $repo/exp:
|
It will generate the following files inside $repo/exp:
|
||||||
|
|
||||||
- encoder-epoch-99-avg-1.onnx
|
- encoder-epoch-99-avg-1.onnx
|
||||||
- decoder-epoch-99-avg-1.onnx
|
- decoder-epoch-99-avg-1.onnx
|
||||||
- joiner-epoch-99-avg-1.onnx
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
- encoder-epoch-99-avg-1.int8.onnx
|
||||||
|
- decoder-epoch-99-avg-1.int8.onnx
|
||||||
|
- joiner-epoch-99-avg-1.int8.onnx
|
||||||
|
|
||||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||||
use the exported ONNX models.
|
use the exported ONNX models.
|
||||||
@ -55,6 +58,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from lstm import RNN
|
from lstm import RNN
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -586,6 +590,35 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"Exported joiner to {joiner_filename}")
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
@ -107,6 +107,7 @@ Usage:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@ -138,6 +139,8 @@ from icefall.utils import (
|
|||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -288,7 +291,7 @@ def get_parser():
|
|||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--left-context",
|
"--left-context",
|
||||||
@ -370,6 +373,14 @@ def decode_one_batch(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
|
if params.decode_chunk_size > 0:
|
||||||
|
# except the case of using full attention
|
||||||
|
feature_lens += params.left_context
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
@ -829,11 +829,22 @@ class HypothesisList(object):
|
|||||||
ans.add(hyp) # shallow copy
|
ans.add(hyp) # shallow copy
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def topk(self, k: int) -> "HypothesisList":
|
def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
|
||||||
"""Return the top-k hypothesis."""
|
"""Return the top-k hypothesis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length_norm:
|
||||||
|
If True, the `log_prob` of a hypothesis is normalized by the
|
||||||
|
number of tokens in it.
|
||||||
|
"""
|
||||||
hyps = list(self._data.items())
|
hyps = list(self._data.items())
|
||||||
|
|
||||||
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
|
if length_norm:
|
||||||
|
hyps = sorted(
|
||||||
|
hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
|
||||||
|
)[:k]
|
||||||
|
else:
|
||||||
|
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
|
||||||
|
|
||||||
ans = HypothesisList(dict(hyps))
|
ans = HypothesisList(dict(hyps))
|
||||||
return ans
|
return ans
|
||||||
@ -1852,7 +1863,6 @@ def modified_beam_search_LODR(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
sp: spm.SentencePieceProcessor,
|
|
||||||
LODR_lm: NgramLm,
|
LODR_lm: NgramLm,
|
||||||
LODR_lm_scale: float,
|
LODR_lm_scale: float,
|
||||||
LM: LmScorer,
|
LM: LmScorer,
|
||||||
@ -1872,8 +1882,6 @@ def modified_beam_search_LODR(
|
|||||||
encoder_out_lens (torch.Tensor):
|
encoder_out_lens (torch.Tensor):
|
||||||
A 1-D tensor of shape (N,), containing the number of
|
A 1-D tensor of shape (N,), containing the number of
|
||||||
valid frames in encoder_out before padding.
|
valid frames in encoder_out before padding.
|
||||||
sp:
|
|
||||||
Sentence piece generator.
|
|
||||||
LODR_lm:
|
LODR_lm:
|
||||||
A low order n-gram LM, whose score will be subtracted during shallow fusion
|
A low order n-gram LM, whose score will be subtracted during shallow fusion
|
||||||
LODR_lm_scale:
|
LODR_lm_scale:
|
||||||
@ -1901,7 +1909,7 @@ def modified_beam_search_LODR(
|
|||||||
)
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
sos_id = sp.piece_to_id("<sos/eos>")
|
sos_id = getattr(LM, "sos_id", 1)
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
@ -2126,7 +2134,6 @@ def modified_beam_search_lm_shallow_fusion(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
sp: spm.SentencePieceProcessor,
|
|
||||||
LM: LmScorer,
|
LM: LmScorer,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
return_timestamps: bool = False,
|
return_timestamps: bool = False,
|
||||||
@ -2165,7 +2172,7 @@ def modified_beam_search_lm_shallow_fusion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
sos_id = sp.piece_to_id("<sos/eos>")
|
sos_id = getattr(LM, "sos_id", 1)
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
|
@ -375,6 +375,11 @@ class Conformer(EncoderInterface):
|
|||||||
|
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
|
if chunk_size < 0:
|
||||||
|
# use full attention
|
||||||
|
chunk_size = x.size(0)
|
||||||
|
left_context = -1
|
||||||
|
|
||||||
num_left_chunks = -1
|
num_left_chunks = -1
|
||||||
if left_context >= 0:
|
if left_context >= 0:
|
||||||
assert left_context % chunk_size == 0
|
assert left_context % chunk_size == 0
|
||||||
|
@ -295,7 +295,7 @@ def get_parser():
|
|||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -378,12 +378,14 @@ def decode_one_batch(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
if params.decode_chunk_size > 0:
|
||||||
feature = torch.nn.functional.pad(
|
# except the case of using full attention
|
||||||
feature,
|
feature_lens += params.left_context
|
||||||
pad=(0, 0, 0, params.left_context),
|
feature = torch.nn.functional.pad(
|
||||||
value=LOG_EPS,
|
feature,
|
||||||
)
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
@ -344,7 +344,7 @@ def get_parser():
|
|||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -508,12 +508,14 @@ def decode_one_batch(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
if params.decode_chunk_size > 0:
|
||||||
feature = torch.nn.functional.pad(
|
# except the case of using full attention
|
||||||
feature,
|
feature_lens += params.left_context
|
||||||
pad=(0, 0, 0, params.left_context),
|
feature = torch.nn.functional.pad(
|
||||||
value=LOG_EPS,
|
feature,
|
||||||
)
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
@ -673,7 +675,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -684,7 +685,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
@ -54,6 +54,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -273,6 +274,16 @@ def export_encoder_model_onnx(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "conformer",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "stateless3",
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
def export_decoder_model_onnx(
|
def export_decoder_model_onnx(
|
||||||
decoder_model: OnnxDecoder,
|
decoder_model: OnnxDecoder,
|
||||||
@ -490,6 +501,35 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"Exported joiner to {joiner_filename}")
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
@ -403,9 +403,8 @@ def main():
|
|||||||
text += symbol_table[i]
|
text += symbol_table[i]
|
||||||
return text.replace("▁", " ").strip()
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
context_size = model.context_size
|
|
||||||
for filename, hyp in zip(args.sound_files, hyps):
|
for filename, hyp in zip(args.sound_files, hyps):
|
||||||
words = token_ids_to_words(hyp[context_size:])
|
words = token_ids_to_words(hyp)
|
||||||
s += f"{filename}:\n{words}\n"
|
s += f"{filename}:\n{words}\n"
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
|
@ -326,14 +326,14 @@ def get_parser():
|
|||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--left-context",
|
"--left-context",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="left context can be seen during decoding (in frames after subsampling)", # noqa
|
help="""Left context can be seen during decoding (in frames after subsampling). """, # noqa
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -409,12 +409,14 @@ def decode_one_batch(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
if params.decode_chunk_size > 0:
|
||||||
feature = torch.nn.functional.pad(
|
# except the case of using full attention
|
||||||
feature,
|
feature_lens += params.left_context
|
||||||
pad=(0, 0, 0, params.left_context),
|
feature = torch.nn.functional.pad(
|
||||||
value=LOG_EPS,
|
feature,
|
||||||
)
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
@ -291,7 +291,7 @@ def get_parser():
|
|||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -470,12 +470,14 @@ def decode_one_batch(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
if params.decode_chunk_size > 0:
|
||||||
feature = torch.nn.functional.pad(
|
# except the case of using full attention
|
||||||
feature,
|
feature_lens += params.left_context
|
||||||
pad=(0, 0, 0, params.left_context),
|
feature = torch.nn.functional.pad(
|
||||||
value=LOG_EPS,
|
feature,
|
||||||
)
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
@ -584,7 +586,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -595,7 +596,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
@ -296,6 +296,16 @@ def export_encoder_model_onnx(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "conformer",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "stateless5",
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
def export_decoder_model_onnx(
|
def export_decoder_model_onnx(
|
||||||
decoder_model: OnnxDecoder,
|
decoder_model: OnnxDecoder,
|
||||||
|
206
egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py
Normal file
206
egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||||
|
|
||||||
|
# The force alignment problem can be formulated as finding
|
||||||
|
# a path in a rectangular lattice, where the path starts
|
||||||
|
# from the lower left corner and ends at the upper right
|
||||||
|
# corner. The horizontal axis of the lattice is `t` (representing
|
||||||
|
# acoustic frame indexes) and the vertical axis is `u` (representing
|
||||||
|
# BPE tokens of the transcript).
|
||||||
|
#
|
||||||
|
# The notations `t` and `u` are from the paper
|
||||||
|
# https://arxiv.org/pdf/1211.3711.pdf
|
||||||
|
#
|
||||||
|
# Beam search is used to find the path with the highest log probabilities.
|
||||||
|
#
|
||||||
|
# It assumes the maximum number of symbols that can be
|
||||||
|
# emitted per frame is 1.
|
||||||
|
|
||||||
|
|
||||||
|
def batch_force_alignment(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
ys_list: List[List[int]],
|
||||||
|
beam_size: int = 4,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Compute the force alignment of a batch of utterances given their transcripts
|
||||||
|
in BPE tokens and the corresponding acoustic output from the encoder.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
This function is modified from `modified_beam_search` in beam_search.py.
|
||||||
|
We assume that the maximum number of sybmols per frame is 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
|
ys_list:
|
||||||
|
A list of BPE token IDs list. We require that for each utterance i,
|
||||||
|
len(ys_list[i]) <= encoder_out_lens[i].
|
||||||
|
beam_size:
|
||||||
|
Size of the beam used in beam search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list of frame indexes list for each utterance i,
|
||||||
|
where len(ans[i]) == len(ys_list[i]).
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.ndim
|
||||||
|
assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list))
|
||||||
|
assert encoder_out.size(0) > 0, encoder_out.size(0)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
sorted_indices = packed_encoder_out.sorted_indices.tolist()
|
||||||
|
encoder_out_lens = encoder_out_lens.tolist()
|
||||||
|
ys_lens = [len(ys) for ys in ys_list]
|
||||||
|
sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices]
|
||||||
|
sorted_ys_lens = [ys_lens[i] for i in sorted_indices]
|
||||||
|
sorted_ys_list = [ys_list[i] for i in sorted_indices]
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
timestamp=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for (t, batch_size) in enumerate(batch_size_list):
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size]
|
||||||
|
sorted_ys_lens = sorted_ys_lens[:batch_size]
|
||||||
|
|
||||||
|
hyps_shape = get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
) # (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||||
|
|
||||||
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
|
# as index, so we use `to(torch.int64)` below.
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out, decoder_out, project_input=False
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(
|
||||||
|
shape=log_probs_shape, value=log_probs.reshape(-1)
|
||||||
|
) # [batch][num_hyps*vocab_size]
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
for h, hyp in enumerate(A[i]):
|
||||||
|
pos_u = len(hyp.timestamp)
|
||||||
|
idx_offset = h * vocab_size
|
||||||
|
if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u):
|
||||||
|
# emit blank token
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
log_prob=ragged_log_probs[i][idx_offset + blank_id],
|
||||||
|
ys=hyp.ys[:],
|
||||||
|
timestamp=hyp.timestamp[:],
|
||||||
|
)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
if pos_u < sorted_ys_lens[i]:
|
||||||
|
# emit non-blank token
|
||||||
|
new_token = sorted_ys_list[i][pos_u]
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
log_prob=ragged_log_probs[i][idx_offset + new_token],
|
||||||
|
ys=hyp.ys + [new_token],
|
||||||
|
timestamp=hyp.timestamp + [t],
|
||||||
|
)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
if len(B[i]) > beam_size:
|
||||||
|
B[i] = B[i].topk(beam_size, length_norm=True)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
|
sorted_hyps = [b.get_most_probable() for b in B]
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
hyps = [sorted_hyps[i] for i in unsorted_indices]
|
||||||
|
ans = []
|
||||||
|
for i, hyp in enumerate(hyps):
|
||||||
|
assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i])
|
||||||
|
ans.append(hyp.timestamp)
|
||||||
|
|
||||||
|
return ans
|
345
egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
Executable file
345
egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
Executable file
@ -0,0 +1,345 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Xiaoyu Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
The script gets forced-alignments based on the modified_beam_search decoding method.
|
||||||
|
Both token-level alignments and word-level alignments are saved to the new cuts manifests.
|
||||||
|
|
||||||
|
It loads a checkpoint and uses it to get the forced-alignments.
|
||||||
|
You can generate the checkpoint with the following command:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 9
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/compute_ali.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--dataset test-clean \
|
||||||
|
--max-duration 300 \
|
||||||
|
--beam-size 4 \
|
||||||
|
--cuts-out-dir data/fbank_ali_beam_search
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from alignment import batch_force_alignment
|
||||||
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
|
||||||
|
from lhotse import CutSet
|
||||||
|
from lhotse.serialization import SequentialJsonlWriter
|
||||||
|
from lhotse.supervision import AlignmentItem
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="""The name of the dataset to compute alignments for.
|
||||||
|
Possible values are:
|
||||||
|
- test-clean
|
||||||
|
- test-other
|
||||||
|
- train-clean-100
|
||||||
|
- train-clean-360
|
||||||
|
- train-other-500
|
||||||
|
- dev-clean
|
||||||
|
- dev-other
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--cuts-out-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/fbank_ali_beam_search",
|
||||||
|
help="The dir to save the new cuts manifests with alignments",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def align_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
batch: dict,
|
||||||
|
) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]:
|
||||||
|
"""Get forced-alignments for one batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
token_list:
|
||||||
|
A list of token list.
|
||||||
|
word_list:
|
||||||
|
A list of word list.
|
||||||
|
token_time_list:
|
||||||
|
A list of timestamps list for tokens.
|
||||||
|
word_time_list.
|
||||||
|
A list of timestamps list for words.
|
||||||
|
|
||||||
|
where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list),
|
||||||
|
len(token_list[i]) == len(token_time_list[i]),
|
||||||
|
and len(word_list[i]) == len(word_time_list[i])
|
||||||
|
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
|
texts = supervisions["text"]
|
||||||
|
ys_list: List[List[int]] = sp.encode(texts, out_type=int)
|
||||||
|
|
||||||
|
frame_indexes = batch_force_alignment(
|
||||||
|
model, encoder_out, encoder_out_lens, ys_list, params.beam_size
|
||||||
|
)
|
||||||
|
|
||||||
|
token_list = []
|
||||||
|
word_list = []
|
||||||
|
token_time_list = []
|
||||||
|
word_time_list = []
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
tokens = sp.id_to_piece(ys_list[i])
|
||||||
|
words = texts[i].split()
|
||||||
|
token_time = convert_timestamp(
|
||||||
|
frame_indexes[i], params.subsampling_factor, params.frame_shift_ms
|
||||||
|
)
|
||||||
|
word_time = parse_timestamp(tokens, token_time)
|
||||||
|
assert len(word_time) == len(words), (len(word_time), len(words))
|
||||||
|
|
||||||
|
token_list.append(tokens)
|
||||||
|
word_list.append(words)
|
||||||
|
token_time_list.append(token_time)
|
||||||
|
word_time_list.append(word_time)
|
||||||
|
|
||||||
|
return token_list, word_list, token_time_list, word_time_list
|
||||||
|
|
||||||
|
|
||||||
|
def align_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
writer: SequentialJsonlWriter,
|
||||||
|
) -> None:
|
||||||
|
"""Get forced-alignments for the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
writer:
|
||||||
|
Writer to save the cuts with alignments.
|
||||||
|
"""
|
||||||
|
log_interval = 20
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
token_list, word_list, token_time_list, word_time_list = align_one_batch(
|
||||||
|
params=params, model=model, sp=sp, batch=batch
|
||||||
|
)
|
||||||
|
|
||||||
|
cut_list = batch["supervisions"]["cut"]
|
||||||
|
for cut, token, word, token_time, word_time in zip(
|
||||||
|
cut_list, token_list, word_list, token_time_list, word_time_list
|
||||||
|
):
|
||||||
|
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
|
||||||
|
token_ali = [
|
||||||
|
AlignmentItem(
|
||||||
|
symbol=token[i],
|
||||||
|
start=round(token_time[i], ndigits=3),
|
||||||
|
duration=None,
|
||||||
|
)
|
||||||
|
for i in range(len(token))
|
||||||
|
]
|
||||||
|
word_ali = [
|
||||||
|
AlignmentItem(
|
||||||
|
symbol=word[i], start=round(word_time[i], ndigits=3), duration=None
|
||||||
|
)
|
||||||
|
for i in range(len(word))
|
||||||
|
]
|
||||||
|
cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali}
|
||||||
|
writer.write(cut, flush=True)
|
||||||
|
|
||||||
|
num_cuts += len(cut_list)
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"], strict=False)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
|
if params.dataset == "test-clean":
|
||||||
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
|
elif params.dataset == "test-other":
|
||||||
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
|
dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
elif params.dataset == "train-clean-100":
|
||||||
|
train_clean_100_cuts = librispeech.train_clean_100_cuts()
|
||||||
|
dl = librispeech.train_dataloaders(train_clean_100_cuts)
|
||||||
|
elif params.dataset == "train-clean-360":
|
||||||
|
train_clean_360_cuts = librispeech.train_clean_360_cuts()
|
||||||
|
dl = librispeech.train_dataloaders(train_clean_360_cuts)
|
||||||
|
elif params.dataset == "train-other-500":
|
||||||
|
train_other_500_cuts = librispeech.train_other_500_cuts()
|
||||||
|
dl = librispeech.train_dataloaders(train_other_500_cuts)
|
||||||
|
elif params.dataset == "dev-clean":
|
||||||
|
dev_clean_cuts = librispeech.dev_clean_cuts()
|
||||||
|
dl = librispeech.valid_dataloaders(dev_clean_cuts)
|
||||||
|
else:
|
||||||
|
assert params.dataset == "dev-other", f"{params.dataset}"
|
||||||
|
dev_other_cuts = librispeech.dev_other_cuts()
|
||||||
|
dl = librispeech.valid_dataloaders(dev_other_cuts)
|
||||||
|
|
||||||
|
cuts_out_dir = Path(params.cuts_out_dir)
|
||||||
|
cuts_out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz"
|
||||||
|
|
||||||
|
with CutSet.open_writer(cuts_out_path) as writer:
|
||||||
|
align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"For dataset {params.dataset}, the cut manifest with framewise token alignments "
|
||||||
|
f"and word alignments are saved to {cuts_out_path}"
|
||||||
|
)
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -343,29 +343,6 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--simulate-streaming",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
|
||||||
test a streaming model.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode-chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--left-context",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-shallow-fusion",
|
"--use-shallow-fusion",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -474,22 +451,7 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
feature_lens += params.left_context
|
|
||||||
feature = torch.nn.functional.pad(
|
|
||||||
feature,
|
|
||||||
pad=(0, 0, 0, params.left_context),
|
|
||||||
value=LOG_EPS,
|
|
||||||
)
|
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
|
||||||
x=feature,
|
|
||||||
x_lens=feature_lens,
|
|
||||||
chunk_size=params.decode_chunk_size,
|
|
||||||
left_context=params.left_context,
|
|
||||||
simulate_streaming=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
@ -571,7 +533,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LM=LM,
|
LM=LM,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -582,7 +543,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
sp=sp,
|
|
||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
@ -782,10 +742,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
@ -834,11 +790,6 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
assert (
|
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -55,6 +55,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
from zipformer import Zipformer
|
from zipformer import Zipformer
|
||||||
@ -291,6 +292,16 @@ def export_encoder_model_onnx(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "zipformer",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "stateless7",
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
def export_decoder_model_onnx(
|
def export_decoder_model_onnx(
|
||||||
decoder_model: OnnxDecoder,
|
decoder_model: OnnxDecoder,
|
||||||
@ -553,6 +564,35 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"Exported joiner to {joiner_filename}")
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
130
egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py
Executable file
130
egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py
Executable file
@ -0,0 +1,130 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script compares the word-level alignments generated based on modified_beam_search decoding
|
||||||
|
(in ./pruned_transducer_stateless7/compute_ali.py) to the reference alignments generated
|
||||||
|
by torchaudio framework (in ./add_alignments.sh).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/compute_ali.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--dataset test-clean \
|
||||||
|
--max-duration 300 \
|
||||||
|
--beam-size 4 \
|
||||||
|
--cuts-out-dir data/fbank_ali_beam_search
|
||||||
|
|
||||||
|
And the you can run:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/test_compute_ali.py \
|
||||||
|
--cuts-out-dir ./data/fbank_ali_test \
|
||||||
|
--cuts-ref-dir ./data/fbank_ali_torch \
|
||||||
|
--dataset train-clean-100
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import load_manifest
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--cuts-out-dir",
|
||||||
|
type=Path,
|
||||||
|
default="./data/fbank_ali",
|
||||||
|
help="The dir that saves the generated cuts manifests with alignments",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--cuts-ref-dir",
|
||||||
|
type=Path,
|
||||||
|
default="./data/fbank_ali_torch",
|
||||||
|
help="The dir that saves the reference cuts manifests with alignments",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="""The name of the dataset:
|
||||||
|
Possible values are:
|
||||||
|
- test-clean
|
||||||
|
- test-other
|
||||||
|
- train-clean-100
|
||||||
|
- train-clean-360
|
||||||
|
- train-other-500
|
||||||
|
- dev-clean
|
||||||
|
- dev-other
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
|
||||||
|
cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"
|
||||||
|
cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"
|
||||||
|
|
||||||
|
logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}")
|
||||||
|
cuts_out = load_manifest(cuts_out_jsonl)
|
||||||
|
cuts_ref = load_manifest(cuts_ref_jsonl)
|
||||||
|
cuts_ref = cuts_ref.sort_like(cuts_out)
|
||||||
|
|
||||||
|
all_time_diffs = []
|
||||||
|
for cut_out, cut_ref in zip(cuts_out, cuts_ref):
|
||||||
|
time_out = [
|
||||||
|
ali.start
|
||||||
|
for ali in cut_out.supervisions[0].alignment["word"]
|
||||||
|
if ali.symbol != ""
|
||||||
|
]
|
||||||
|
time_ref = [
|
||||||
|
ali.start
|
||||||
|
for ali in cut_ref.supervisions[0].alignment["word"]
|
||||||
|
if ali.symbol != ""
|
||||||
|
]
|
||||||
|
assert len(time_out) == len(time_ref), (len(time_out), len(time_ref))
|
||||||
|
diff = [
|
||||||
|
round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref)
|
||||||
|
]
|
||||||
|
all_time_diffs += diff
|
||||||
|
|
||||||
|
all_time_diffs = torch.tensor(all_time_diffs)
|
||||||
|
logging.info(
|
||||||
|
f"For the word-level alignments abs difference on dataset {args.dataset}, "
|
||||||
|
f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s"
|
||||||
|
)
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
678
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
Executable file
678
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
Executable file
@ -0,0 +1,678 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_char_bpe/L.pt"
|
||||||
|
git lfs pull --include "data/lang_char_bpe/Linv.pt"
|
||||||
|
git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
|
||||||
|
--lang-dir $repo/data/lang_char_bpe \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp/ \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--num-encoder-layers "2,4,3,2,4" \
|
||||||
|
--feedforward-dims "1024,1024,1536,1536,1024" \
|
||||||
|
--nhead "8,8,8,8,8" \
|
||||||
|
--encoder-dims "384,384,384,384,384" \
|
||||||
|
--attention-dims "192,192,192,192,192" \
|
||||||
|
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||||
|
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||||
|
--cnn-module-kernels "31,31,31,31,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512
|
||||||
|
|
||||||
|
It will generate the following 3 files in $repo/exp
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
See ./onnx_pretrained.py for how to use the exported models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from torch import Tensor
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
from zipformer import Zipformer
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=9,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7_streaming/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
A Zipformer encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
|
||||||
|
"""Please see the help information of Zipformer.streaming_forward"""
|
||||||
|
N = x.size(0)
|
||||||
|
T = x.size(1)
|
||||||
|
x_lens = torch.tensor([T] * N, device=x.device)
|
||||||
|
|
||||||
|
output, _, new_states = self.encoder.streaming_forward(
|
||||||
|
x=x,
|
||||||
|
x_lens=x_lens,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.encoder_proj(output)
|
||||||
|
# Now output is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return output, new_states
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Onnx model inputs:
|
||||||
|
- 0: src
|
||||||
|
- many state tensors (the exact number depending on the actual model)
|
||||||
|
|
||||||
|
Onnx model outputs:
|
||||||
|
- 0: output, its shape is (N, T, joiner_dim)
|
||||||
|
- many state tensors (the exact number depending on the actual model)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The model to be exported
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
encoder_model.encoder.__class__.forward = (
|
||||||
|
encoder_model.encoder.__class__.streaming_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2
|
||||||
|
pad_length = 7
|
||||||
|
T = decode_chunk_len + pad_length
|
||||||
|
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||||
|
logging.info(f"pad_length: {pad_length}")
|
||||||
|
logging.info(f"T: {T}")
|
||||||
|
|
||||||
|
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||||
|
|
||||||
|
init_state = encoder_model.encoder.get_init_state()
|
||||||
|
|
||||||
|
num_encoders = encoder_model.encoder.num_encoders
|
||||||
|
logging.info(f"num_encoders: {num_encoders}")
|
||||||
|
logging.info(f"len(init_state): {len(init_state)}")
|
||||||
|
|
||||||
|
inputs = {}
|
||||||
|
input_names = ["x"]
|
||||||
|
|
||||||
|
outputs = {}
|
||||||
|
output_names = ["encoder_out"]
|
||||||
|
|
||||||
|
def build_inputs_outputs(tensors, name, N):
|
||||||
|
for i, s in enumerate(tensors):
|
||||||
|
logging.info(f"{name}_{i}.shape: {s.shape}")
|
||||||
|
inputs[f"{name}_{i}"] = {N: "N"}
|
||||||
|
outputs[f"new_{name}_{i}"] = {N: "N"}
|
||||||
|
input_names.append(f"{name}_{i}")
|
||||||
|
output_names.append(f"new_{name}_{i}")
|
||||||
|
|
||||||
|
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||||
|
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims))
|
||||||
|
attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims))
|
||||||
|
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels))
|
||||||
|
ds = encoder_model.encoder.zipformer_downsampling_factors
|
||||||
|
left_context_len = encoder_model.encoder.left_context_len
|
||||||
|
left_context_len = [left_context_len // k for k in ds]
|
||||||
|
left_context_len = ",".join(map(str, left_context_len))
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "zipformer",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||||
|
"T": str(T), # 39
|
||||||
|
"num_encoder_layers": num_encoder_layers,
|
||||||
|
"encoder_dims": encoder_dims,
|
||||||
|
"attention_dims": attention_dims,
|
||||||
|
"cnn_module_kernels": cnn_module_kernels,
|
||||||
|
"left_context_len": left_context_len,
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
# (num_encoder_layers, 1)
|
||||||
|
cached_len = init_state[num_encoders * 0 : num_encoders * 1]
|
||||||
|
|
||||||
|
# (num_encoder_layers, 1, encoder_dim)
|
||||||
|
cached_avg = init_state[num_encoders * 1 : num_encoders * 2]
|
||||||
|
|
||||||
|
# (num_encoder_layers, left_context_len, 1, attention_dim)
|
||||||
|
cached_key = init_state[num_encoders * 2 : num_encoders * 3]
|
||||||
|
|
||||||
|
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||||
|
cached_val = init_state[num_encoders * 3 : num_encoders * 4]
|
||||||
|
|
||||||
|
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
||||||
|
cached_val2 = init_state[num_encoders * 4 : num_encoders * 5]
|
||||||
|
|
||||||
|
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||||
|
cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6]
|
||||||
|
|
||||||
|
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||||
|
cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7]
|
||||||
|
|
||||||
|
build_inputs_outputs(cached_len, "cached_len", 1)
|
||||||
|
build_inputs_outputs(cached_avg, "cached_avg", 1)
|
||||||
|
build_inputs_outputs(cached_key, "cached_key", 2)
|
||||||
|
build_inputs_outputs(cached_val, "cached_val", 2)
|
||||||
|
build_inputs_outputs(cached_val2, "cached_val2", 2)
|
||||||
|
build_inputs_outputs(cached_conv1, "cached_conv1", 1)
|
||||||
|
build_inputs_outputs(cached_conv2, "cached_conv2", 1)
|
||||||
|
|
||||||
|
logging.info(inputs)
|
||||||
|
logging.info(outputs)
|
||||||
|
logging.info(input_names)
|
||||||
|
logging.info(output_names)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, init_state),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=input_names,
|
||||||
|
output_names=output_names,
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N"},
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
**inputs,
|
||||||
|
**outputs,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
if params.use_averaged_model:
|
||||||
|
suffix += "-with-averaged-model"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -53,6 +53,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -634,6 +635,35 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"Exported joiner to {joiner_filename}")
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -301,29 +301,6 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--simulate-streaming",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
|
||||||
test a streaming model.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode-chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--left-context",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -378,22 +355,7 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
feature_lens += params.left_context
|
|
||||||
feature = torch.nn.functional.pad(
|
|
||||||
feature,
|
|
||||||
pad=(0, 0, 0, params.left_context),
|
|
||||||
value=LOG_EPS,
|
|
||||||
)
|
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
|
||||||
x=feature,
|
|
||||||
x_lens=feature_lens,
|
|
||||||
chunk_size=params.decode_chunk_size,
|
|
||||||
left_context=params.left_context,
|
|
||||||
simulate_streaming=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
@ -651,10 +613,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
@ -690,11 +648,6 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
assert (
|
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -358,6 +358,11 @@ class Conformer(Transformer):
|
|||||||
|
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
|
if chunk_size < 0:
|
||||||
|
# use full attention
|
||||||
|
chunk_size = x.size(0)
|
||||||
|
left_context = -1
|
||||||
|
|
||||||
num_left_chunks = -1
|
num_left_chunks = -1
|
||||||
if left_context >= 0:
|
if left_context >= 0:
|
||||||
assert left_context % chunk_size == 0
|
assert left_context % chunk_size == 0
|
||||||
|
82
egs/wenetspeech/ASR/finetune.sh
Executable file
82
egs/wenetspeech/ASR/finetune.sh
Executable file
@ -0,0 +1,82 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||||
|
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||||
|
|
||||||
|
set -eou pipefail
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
# This is an example script for fine-tuning. Here, we fine-tune a model trained
|
||||||
|
# on WenetSpeech on Aishell. The model used for fine-tuning is
|
||||||
|
# pruned_transducer_stateless2 (zipformer). If you want to fine-tune model
|
||||||
|
# from another recipe, you can adapt ./pruned_transducer_stateless2/finetune.py
|
||||||
|
# for that recipe. If you have any problem, please open up an issue in https://github.com/k2-fsa/icefall/issues.
|
||||||
|
|
||||||
|
# We assume that you have already prepared the Aishell manfiest&features under ./data.
|
||||||
|
# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/prepare.sh.
|
||||||
|
|
||||||
|
. shared/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||||
|
log "Stage -1: Download Pre-trained model"
|
||||||
|
|
||||||
|
# clone from huggingface
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
|
||||||
|
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
|
log "Stage 0: Start fine-tuning"
|
||||||
|
|
||||||
|
# The following configuration of lr schedule should work well
|
||||||
|
# You may also tune the following parameters to adjust learning rate schedule
|
||||||
|
initial_lr=0.0001
|
||||||
|
lr_epochs=100
|
||||||
|
lr_batches=100000
|
||||||
|
|
||||||
|
# We recommend to start from an averaged model
|
||||||
|
finetune_ckpt=icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/pretrained_epoch_10_avg_2.pt
|
||||||
|
lang_dir=icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless2/finetune.py \
|
||||||
|
--world-size 2 \
|
||||||
|
--master-port 18180 \
|
||||||
|
--num-epochs 15 \
|
||||||
|
--context-size 2 \
|
||||||
|
--exp-dir pruned_transducer_stateless2/exp_aishell_finetune \
|
||||||
|
--initial-lr $initial_lr \
|
||||||
|
--lr-epochs $lr_epochs \
|
||||||
|
--lr-batches $lr_batches \
|
||||||
|
--lang-dir $lang_dir \
|
||||||
|
--do-finetune True \
|
||||||
|
--finetune-ckpt $finetune_ckpt \
|
||||||
|
--max-duration 200
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
|
log "Stage 1: Decoding"
|
||||||
|
|
||||||
|
epoch=4
|
||||||
|
avg=4
|
||||||
|
|
||||||
|
for m in greedy_search modified_beam_search; do
|
||||||
|
python pruned_transducer_stateless2/decode_aishell.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--context-size 2 \
|
||||||
|
--beam-size 4 \
|
||||||
|
--exp-dir pruned_transducer_stateless2/exp_aishell_finetune \
|
||||||
|
--max-duration 400 \
|
||||||
|
--decoding-method $m
|
||||||
|
done
|
||||||
|
fi
|
@ -40,8 +40,8 @@ from tqdm import tqdm
|
|||||||
# and 'data()' is only supported in static graph mode. So if you
|
# and 'data()' is only supported in static graph mode. So if you
|
||||||
# want to use this api, should call 'paddle.enable_static()' before
|
# want to use this api, should call 'paddle.enable_static()' before
|
||||||
# this api to enter static graph mode.
|
# this api to enter static graph mode.
|
||||||
paddle.enable_static()
|
# paddle.enable_static()
|
||||||
paddle.disable_signal_handler()
|
# paddle.disable_signal_handler()
|
||||||
jieba.enable_paddle()
|
jieba.enable_paddle()
|
||||||
|
|
||||||
|
|
||||||
|
@ -261,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
|||||||
log "Stage 18: Compile LG"
|
log "Stage 18: Compile LG"
|
||||||
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# prepare RNNLM data
|
||||||
|
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||||
|
log "Stage 19: Prepare LM training data"
|
||||||
|
|
||||||
|
log "Processing char based data"
|
||||||
|
text_out_dir=data/lm_char
|
||||||
|
|
||||||
|
mkdir -p $text_out_dir
|
||||||
|
|
||||||
|
log "Genearating training text data"
|
||||||
|
|
||||||
|
if [ ! -f $text_out_dir/lm_data.pt ]; then
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $lang_char_dir/text_words_segmentation \
|
||||||
|
--lm-archive $text_out_dir/lm_data.pt
|
||||||
|
fi
|
||||||
|
|
||||||
|
log "Generating DEV text data"
|
||||||
|
# prepare validation text data
|
||||||
|
if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
|
||||||
|
valid_text=${text_out_dir}/
|
||||||
|
|
||||||
|
gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
|
||||||
|
| jq '.text' | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $text_out_dir/valid_text
|
||||||
|
|
||||||
|
python3 ./local/text2segments.py \
|
||||||
|
--num-process $nj \
|
||||||
|
--input-file $text_out_dir/valid_text \
|
||||||
|
--output-file $text_out_dir/valid_text_words_segmentation
|
||||||
|
fi
|
||||||
|
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $text_out_dir/valid_text_words_segmentation \
|
||||||
|
--lm-archive $text_out_dir/lm_data_valid.pt
|
||||||
|
|
||||||
|
# prepare TEST text data
|
||||||
|
if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
|
||||||
|
log "Prepare text for test set."
|
||||||
|
for test_set in TEST_MEETING TEST_NET; do
|
||||||
|
gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \
|
||||||
|
| jq '.text' | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text
|
||||||
|
|
||||||
|
python3 ./local/text2segments.py \
|
||||||
|
--num-process $nj \
|
||||||
|
--input-file $text_out_dir/${test_set}_text \
|
||||||
|
--output-file $text_out_dir/${test_set}_text_words_segmentation
|
||||||
|
done
|
||||||
|
|
||||||
|
cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
|
||||||
|
fi
|
||||||
|
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $text_out_dir/test_text_words_segmentation \
|
||||||
|
--lm-archive $text_out_dir/lm_data_test.pt
|
||||||
|
|
||||||
|
fi
|
||||||
|
|
||||||
|
# sort RNNLM data
|
||||||
|
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||||
|
text_out_dir=data/lm_char
|
||||||
|
|
||||||
|
log "Sort lm data"
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $text_out_dir/lm_data.pt \
|
||||||
|
--out-lm-data $text_out_dir/sorted_lm_data.pt \
|
||||||
|
--out-statistics $text_out_dir/statistics.txt
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $text_out_dir/lm_data_valid.pt \
|
||||||
|
--out-lm-data $text_out_dir/sorted_lm_data-valid.pt \
|
||||||
|
--out-statistics $text_out_dir/statistics-valid.txt
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $text_out_dir/lm_data_test.pt \
|
||||||
|
--out-lm-data $text_out_dir/sorted_lm_data-test.pt \
|
||||||
|
--out-statistics $text_out_dir/statistics-test.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
|
||||||
|
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
|
||||||
|
log "Stage 21: Train RNN LM model"
|
||||||
|
python ../../../icefall/rnn_lm/train.py \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 20 \
|
||||||
|
--use-fp16 0 \
|
||||||
|
--embedding-dim 2048 \
|
||||||
|
--hidden-dim 2048 \
|
||||||
|
--num-layers 2 \
|
||||||
|
--batch-size 400 \
|
||||||
|
--exp-dir rnnlm_char/exp \
|
||||||
|
--lm-data data/lm_char/sorted_lm_data.pt \
|
||||||
|
--lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
|
||||||
|
--vocab-size 4336 \
|
||||||
|
--master-port 12340
|
||||||
|
fi
|
1
egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py
Symbolic link
1
egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
|
547
egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py
Executable file
547
egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py
Executable file
@ -0,0 +1,547 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) greedy search
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 84 \
|
||||||
|
--avg 25 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
(2) beam search (not recommended)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 84 \
|
||||||
|
--avg 25 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 84 \
|
||||||
|
--avg 25 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 84 \
|
||||||
|
--avg 25 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from aishell import AishellAsrDataModule
|
||||||
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from finetune import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless2/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding_method is greedy_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
token_table: k2.SymbolTable,
|
||||||
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if greedy_search is used, it would be "greedy_search"
|
||||||
|
If beam search with a beam size of 7 is used, it would be
|
||||||
|
"beam_7"
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
It maps token ID to a string.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hyp_tokens = []
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
for i in range(batch_size):
|
||||||
|
# fmt: off
|
||||||
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
|
# fmt: on
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "beam_search":
|
||||||
|
hyp = beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
hyp_tokens.append(hyp)
|
||||||
|
|
||||||
|
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
token_table: k2.SymbolTable,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
It maps a token ID to a string.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
log_interval = 50
|
||||||
|
else:
|
||||||
|
log_interval = 20
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
token_table=token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||||
|
# we compute CER for aishell dataset.
|
||||||
|
results_char = []
|
||||||
|
for res in results:
|
||||||
|
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
assert params.decoding_method in (
|
||||||
|
"greedy_search",
|
||||||
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
|
"modified_beam_search",
|
||||||
|
)
|
||||||
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
else:
|
||||||
|
params.suffix += f"-context-{params.context_size}"
|
||||||
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
aishell = AishellAsrDataModule(args)
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
dev_cuts = aishell.valid_cuts()
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
dev_dl = aishell.test_dataloaders(dev_cuts)
|
||||||
|
|
||||||
|
test_sets = ["test", "dev"]
|
||||||
|
test_dls = [test_dl, dev_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
token_table=lexicon.token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params,
|
||||||
|
test_set_name=test_set,
|
||||||
|
results_dict=results_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1050
egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
Executable file
1050
egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -2,6 +2,7 @@
|
|||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -91,6 +92,22 @@ When training with the L subset, the streaming usage:
|
|||||||
--causal-convolution 1 \
|
--causal-convolution 1 \
|
||||||
--decode-chunk-size 16 \
|
--decode-chunk-size 16 \
|
||||||
--left-context 64
|
--left-context 64
|
||||||
|
|
||||||
|
(4) modified beam search with RNNLM shallow fusion
|
||||||
|
./pruned_transducer_stateless5/decode.py \
|
||||||
|
--epoch 35 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
|
--beam-size 4 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-scale 0.3 \
|
||||||
|
--lm-exp-dir /path/to/LM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -111,9 +128,12 @@ from beam_search import (
|
|||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
|
modified_beam_search_lm_shallow_fusion,
|
||||||
|
modified_beam_search_LODR,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -224,6 +244,16 @@ def get_parser():
|
|||||||
Used only when --decoding-method is fast_beam_search""",
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
@ -277,6 +307,50 @@ def get_parser():
|
|||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-shallow-fusion",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Use neural network LM for shallow fusion.
|
||||||
|
If you want to use LODR, you will also need to set this to true
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens-ngram",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="""Token Ngram used for rescoring.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring, or LODR
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--backoff-id",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""ID of the backoff symbol.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -288,6 +362,9 @@ def decode_one_batch(
|
|||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -374,6 +451,28 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for i in range(encoder_out.size(0)):
|
for i in range(encoder_out.size(0)):
|
||||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||||
|
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "modified_beam_search_LODR":
|
||||||
|
hyp_tokens = modified_beam_search_LODR(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LODR_lm=ngram_lm,
|
||||||
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -419,6 +518,9 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -432,6 +534,8 @@ def decode_dataset(
|
|||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when --decoding_method is fast_beam_search.
|
||||||
|
LM:
|
||||||
|
A neural network LM, used during shallow fusion
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -449,7 +553,7 @@ def decode_dataset(
|
|||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 100
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 20
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -463,6 +567,9 @@ def decode_dataset(
|
|||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -524,6 +631,7 @@ def save_results(
|
|||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
WenetSpeechAsrDataModule.add_arguments(parser)
|
WenetSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
LmScorer.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
@ -535,6 +643,8 @@ def main():
|
|||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
|
"modified_beam_search_LODR",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
@ -549,6 +659,22 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if "ngram" in params.decoding_method:
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
if params.lm_type == "rnn":
|
||||||
|
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
|
||||||
|
elif params.lm_type == "transformer":
|
||||||
|
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -558,6 +684,7 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
@ -652,6 +779,37 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
# only load N-gram LM when needed
|
||||||
|
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||||
|
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||||
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
|
ngram_lm = NgramLm(
|
||||||
|
str(params.lang_dir / lm_filename),
|
||||||
|
backoff_id=params.backoff_id,
|
||||||
|
is_binary=False,
|
||||||
|
)
|
||||||
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
ngram_lm_scale = params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
ngram_lm = None
|
||||||
|
ngram_lm_scale = None
|
||||||
|
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
|
# only load the neural network LM if doing shallow fusion
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in LM.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
@ -684,6 +842,9 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
save_results(
|
save_results(
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -50,7 +50,7 @@ class LmScorer(torch.nn.Module):
|
|||||||
def add_arguments(cls, parser):
|
def add_arguments(cls, parser):
|
||||||
# LM general arguments
|
# LM general arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-size",
|
"--lm-vocab-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=500,
|
default=500,
|
||||||
)
|
)
|
||||||
|
0
icefall/rnn_lm/__init__.py
Normal file
0
icefall/rnn_lm/__init__.py
Normal file
@ -33,7 +33,7 @@ import torch
|
|||||||
from dataset import get_dataloader
|
from dataset import get_dataloader
|
||||||
from model import RnnLmModel
|
from model import RnnLmModel
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
@ -49,6 +49,7 @@ def get_parser():
|
|||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
@ -58,6 +59,16 @@ def get_parser():
|
|||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -154,7 +165,14 @@ def main():
|
|||||||
|
|
||||||
params = AttributeDict(vars(args))
|
params = AttributeDict(vars(args))
|
||||||
|
|
||||||
setup_logger(f"{params.exp_dir}/log-ppl/")
|
if params.iter > 0:
|
||||||
|
setup_logger(
|
||||||
|
f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
setup_logger(
|
||||||
|
f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
)
|
||||||
logging.info("Computing perplexity started")
|
logging.info("Computing perplexity started")
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -173,19 +191,39 @@ def main():
|
|||||||
tie_weights=params.tie_weights,
|
tie_weights=params.tie_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.avg == 1:
|
if params.iter > 0:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
start = params.epoch - params.avg + 1
|
||||||
filenames = []
|
filenames = []
|
||||||
for i in range(start, params.epoch + 1):
|
for i in range(start, params.epoch + 1):
|
||||||
if start >= 0:
|
if i >= 0:
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
num_param_requires_grad = sum(
|
num_param_requires_grad = sum(
|
||||||
|
@ -25,7 +25,7 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from model import RnnLmModel
|
from model import RnnLmModel
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import AttributeDict, load_averaged_model, str2bool
|
from icefall.utils import AttributeDict, load_averaged_model, str2bool
|
||||||
|
|
||||||
|
|
||||||
@ -51,6 +51,16 @@ def get_parser():
|
|||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-size",
|
"--vocab-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -133,11 +143,36 @@ def main():
|
|||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if params.avg == 1:
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
elif params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
model = load_averaged_model(
|
start = params.epoch - params.avg + 1
|
||||||
params.exp_dir, model, params.epoch, params.avg, device
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
)
|
)
|
||||||
|
|
||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
|
@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
|
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -178,6 +179,33 @@ def get_parser():
|
|||||||
help="The seed for random generators intended for reproducibility",
|
help="The seed for random generators intended for reproducibility",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr",
|
||||||
|
type=float,
|
||||||
|
default=1e-3,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sent-len",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Maximum number of tokens in a sentence. This is used
|
||||||
|
to adjust batch-size dynamically""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-every-n",
|
||||||
|
type=int,
|
||||||
|
default=2000,
|
||||||
|
help="""Save checkpoint after processing this number of batches"
|
||||||
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||||
|
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||||
|
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||||
|
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -190,16 +218,15 @@ def get_params() -> AttributeDict:
|
|||||||
"sos_id": 1,
|
"sos_id": 1,
|
||||||
"eos_id": 1,
|
"eos_id": 1,
|
||||||
"blank_id": 0,
|
"blank_id": 0,
|
||||||
"lr": 1e-3,
|
|
||||||
"weight_decay": 1e-6,
|
"weight_decay": 1e-6,
|
||||||
"best_train_loss": float("inf"),
|
"best_train_loss": float("inf"),
|
||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
"best_valid_epoch": -1,
|
"best_valid_epoch": -1,
|
||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 200,
|
"log_interval": 100,
|
||||||
"reset_interval": 2000,
|
"reset_interval": 2000,
|
||||||
"valid_interval": 5000,
|
"valid_interval": 200,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -382,6 +409,7 @@ def train_one_epoch(
|
|||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
|
rank: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Train the model for one epoch.
|
"""Train the model for one epoch.
|
||||||
|
|
||||||
@ -430,6 +458,19 @@ def train_one_epoch(
|
|||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
if (
|
||||||
|
params.batch_idx_train > 0
|
||||||
|
and params.batch_idx_train % params.save_every_n == 0
|
||||||
|
):
|
||||||
|
save_checkpoint_with_global_batch_idx(
|
||||||
|
out_dir=params.exp_dir,
|
||||||
|
global_batch_idx=params.batch_idx_train,
|
||||||
|
model=model,
|
||||||
|
params=params,
|
||||||
|
optimizer=optimizer,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
# Note: "frames" here means "num_tokens"
|
# Note: "frames" here means "num_tokens"
|
||||||
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
||||||
@ -580,6 +621,7 @@ def run(rank, world_size, args):
|
|||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
tb_writer=tb_writer,
|
tb_writer=tb_writer,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
|
0
icefall/transformer_lm/__init__.py
Normal file
0
icefall/transformer_lm/__init__.py
Normal file
@ -1378,7 +1378,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
|||||||
List of timestamp of each word.
|
List of timestamp of each word.
|
||||||
"""
|
"""
|
||||||
start_token = b"\xe2\x96\x81".decode() # '_'
|
start_token = b"\xe2\x96\x81".decode() # '_'
|
||||||
assert len(tokens) == len(timestamp)
|
assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
|
||||||
ans = []
|
ans = []
|
||||||
for i in range(len(tokens)):
|
for i in range(len(tokens)):
|
||||||
flag = False
|
flag = False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user