Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Commit c33ebefaf8: Merge remote-tracking branch 'k2-fsa/master' into new-zipformer-add-ctc
@@ -58,6 +58,7 @@ Usage:
import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -76,6 +77,8 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,

@@ -211,6 +214,26 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--context-score",
        type=float,
        default=2,
        help="""
        The bonus score of each token for the context biasing words/phrases.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    parser.add_argument(
        "--context-file",
        type=str,
        default="",
        help="""
        The path of the context biasing lists, one word/phrase each line
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    add_model_arguments(parser)

    return parser

@@ -222,6 +245,7 @@ def decode_one_batch(
    token_table: k2.SymbolTable,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

@@ -285,6 +309,7 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            context_graph=context_graph,
        )
    else:
        hyp_tokens = []

@@ -324,7 +349,12 @@ def decode_one_batch(
            ): hyps
        }
    else:
        return {f"beam_size_{params.beam_size}": hyps}
        key = f"beam_size_{params.beam_size}"
        if params.has_contexts:
            key += f"-context-score-{params.context_score}"
        else:
            key += "-no-context-words"
        return {key: hyps}


def decode_dataset(

@@ -333,6 +363,7 @@ def decode_dataset(
    model: nn.Module,
    token_table: k2.SymbolTable,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

@@ -377,6 +408,7 @@ def decode_dataset(
            model=model,
            token_table=token_table,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
            batch=batch,
        )

@@ -407,16 +439,17 @@ def save_results(
    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        # we compute CER for aishell dataset.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))

        store_transcripts(filename=recog_path, texts=results_char)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
        # we compute CER for aishell dataset.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True

@@ -457,6 +490,12 @@ def main():
        "fast_beam_search",
        "modified_beam_search",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:

@@ -470,6 +509,10 @@ def main():
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-contexts-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -490,6 +533,11 @@ def main():
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    logging.info(params)

    logging.info("About to create model")

@@ -586,6 +634,19 @@ def main():
    else:
        decoding_graph = None

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts_text = []
            for line in open(params.context_file).readlines():
                contexts_text.append(line.strip())
            contexts = graph_compiler.texts_to_ids(contexts_text)
            context_graph = ContextGraph(params.context_score)
            context_graph.build(contexts)
        else:
            context_graph = None
    else:
        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -608,6 +669,7 @@ def main():
            model=model,
            token_table=lexicon.token_table,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
        )

        save_results(
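The hunks above are the whole biasing path for this recipe: the words/phrases listed in --context-file are mapped to token ids with the char graph compiler, compiled into a ContextGraph weighted by --context-score, and handed down through decode_dataset/decode_one_batch into modified_beam_search. A minimal sketch of the same wiring in isolation (the hotwords path and the already-constructed graph_compiler are placeholders, not values from this commit):

# Sketch only: builds a biasing graph the way main() above does.
from icefall import ContextGraph

context_score = 2.0  # per-token bonus, the default of --context-score

with open("hotwords.txt") as f:  # placeholder path
    contexts_text = [line.strip() for line in f if line.strip()]

# graph_compiler is a CharCtcTrainingGraphCompiler built from the lexicon,
# exactly as in main() above.
contexts = graph_compiler.texts_to_ids(contexts_text)

context_graph = ContextGraph(context_score)
context_graph.build(contexts)

# The graph is then passed through decode_dataset()/decode_one_batch()
# to modified_beam_search(..., context_graph=context_graph).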
@@ -63,6 +63,14 @@ log() {
log "dl_dir: $dl_dir"

if ! command -v ffmpeg &> /dev/null; then
  echo "This dataset requires ffmpeg"
  echo "Please install ffmpeg first"
  echo ""
  echo " sudo apt-get install ffmpeg"
  exit 1
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"
@@ -107,7 +107,7 @@ fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare musan manifest"
  # We assume that you have downloaded the musan corpus
  # to data/musan
  # to $dl_dir/musan
  mkdir -p data/manifests
  if [ ! -e data/manifests/.musan.done ]; then
    lhotse prepare musan $dl_dir/musan data/manifests
@@ -23,7 +23,7 @@ import k2
import sentencepiece as spm
import torch

from icefall import NgramLm, NgramLmStateCost
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel

@@ -765,6 +765,9 @@ class Hypothesis:
    # N-gram LM state
    state_cost: Optional[NgramLmStateCost] = None

    # Context graph state
    context_state: Optional[ContextState] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""

@@ -917,6 +920,7 @@ def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_graph: Optional[ContextGraph] = None,
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,

@@ -968,6 +972,7 @@ def modified_beam_search(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                context_state=None if context_graph is None else context_graph.root,
                timestamp=[],
            )
        )

@@ -990,6 +995,7 @@ def modified_beam_search(
        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]

        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(

@@ -1047,21 +1053,51 @@ def modified_beam_search(
            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                context_score = 0
                new_context_state = None if context_graph is None else hyp.context_state
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)
                    if context_graph is not None:
                        (
                            context_score,
                            new_context_state,
                        ) = context_graph.forward_one_step(hyp.context_state, new_token)

                new_log_prob = topk_log_probs[k] + context_score

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(
                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
                    ys=new_ys,
                    log_prob=new_log_prob,
                    timestamp=new_timestamp,
                    context_state=new_context_state,
                )
                B[i].add(new_hyp)

    B = B + finalized_B

    # finalize context_state, if the matched contexts do not reach final state
    # we need to add the score on the corresponding backoff arc
    if context_graph is not None:
        finalized_B = [HypothesisList() for _ in range(len(B))]
        for i, hyps in enumerate(B):
            for hyp in list(hyps):
                context_score, new_context_state = context_graph.finalize(
                    hyp.context_state
                )
                finalized_B[i].add(
                    Hypothesis(
                        ys=hyp.ys,
                        log_prob=hyp.log_prob + context_score,
                        timestamp=hyp.timestamp,
                        context_state=new_context_state,
                    )
                )
        B = finalized_B

    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
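The search loop above only relies on a small contract: context_graph.forward_one_step(state, token) returns a score delta plus the next state, and context_graph.finalize(state) returns the correction to apply when a partially matched phrase never completes, so the bonus handed out along the way is taken back. The following toy trie is not the icefall ContextGraph (it ignores overlapping matches, for example); it is only a sketch of that accounting under those assumptions:

class ToyContextTrie:
    """Toy stand-in for ContextGraph, illustrating forward_one_step/finalize."""

    def __init__(self, context_score: float):
        self.context_score = context_score
        self.root = {"next": {}, "level": 0, "is_end": False}

    def build(self, token_id_lists):
        for ids in token_id_lists:
            node = self.root
            for level, tok in enumerate(ids, start=1):
                node = node["next"].setdefault(
                    tok, {"next": {}, "level": level, "is_end": False}
                )
            node["is_end"] = True

    def forward_one_step(self, state, token):
        if token in state["next"]:
            # Extend the match: grant one more per-token bonus.
            return self.context_score, state["next"][token]
        # Mismatch: back off to the root and give back the partial bonus.
        return -self.context_score * state["level"], self.root

    def finalize(self, state):
        if state["is_end"]:
            return 0.0, self.root
        # Utterance ended in the middle of a phrase: give the bonus back.
        return -self.context_score * state["level"], self.root

Because modified_beam_search only touches this interface, the real graph can implement smarter backoff without changing the search loop.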
@@ -99,7 +99,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless3/exp",
        help="The experiment dir",
    )
@@ -125,6 +125,7 @@ For example:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -146,6 +147,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall import ContextGraph
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,

@@ -353,6 +355,27 @@ def get_parser():
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-score",
        type=float,
        default=2,
        help="""
        The bonus score of each token for the context biasing words/phrases.
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    parser.add_argument(
        "--context-file",
        type=str,
        default="",
        help="""
        The path of the context biasing lists, one word/phrase each line
        Used only when --decoding_method is modified_beam_search.
        """,
    )

    add_model_arguments(parser)

    return parser

@@ -365,6 +388,7 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

@@ -494,6 +518,7 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            context_graph=context_graph,
            return_timestamps=True,
        )
    else:

@@ -548,7 +573,12 @@ def decode_one_batch(

        return {key: (hyps, timestamps)}
    else:
        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
        key = f"beam_size_{params.beam_size}"
        if params.has_contexts:
            key += f"-context-score-{params.context_score}"
        else:
            key += "-no-context-words"
        return {key: (hyps, timestamps)}


def decode_dataset(

@@ -558,6 +588,7 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
    """Decode dataset.

@@ -622,6 +653,7 @@ def decode_dataset(
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
            context_graph=context_graph,
        )

        for name, (hyps, timestamps_hyp) in hyps_dict.items():

@@ -728,6 +760,12 @@ def main():
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:

@@ -750,6 +788,10 @@ def main():
        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
        if params.has_contexts:
            params.suffix += f"-context-score-{params.context_score}"
        else:
            params.suffix += "-no-context-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -881,6 +923,18 @@ def main():
        decoding_graph = None
        word_table = None

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts = []
            for line in open(params.context_file).readlines():
                contexts.append(line.strip())
            context_graph = ContextGraph(params.context_score)
            context_graph.build(sp.encode(contexts))
        else:
            context_graph = None
    else:
        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -905,6 +959,7 @@ def main():
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
        )

        save_results(
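For this BPE-based recipe the only difference from the character-based one above is how the biasing phrases are tokenized: main() feeds sp.encode(contexts) (one list of BPE ids per phrase) to ContextGraph.build. A short sketch under the same assumptions (the model path and word list below are placeholders, not values from the commit):

# Sketch only: BPE variant of the biasing-graph construction in main() above.
import sentencepiece as spm

from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # placeholder path

contexts = [line.strip() for line in open("hotwords.txt") if line.strip()]

context_graph = ContextGraph(2.0)         # value of --context-score
context_graph.build(sp.encode(contexts))  # one token-id list per phrase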
@@ -78,7 +78,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -115,7 +115,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless4/exp",
        help="The experiment dir",
    )
@@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless2/exp \
  --exp-dir pruned_transducer_stateless4/exp \
  --full-libri 1 \
  --max-duration 300

@@ -37,7 +37,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir pruned_transducer_stateless2/exp \
  --exp-dir pruned_transducer_stateless4/exp \
  --full-libri 1 \
  --max-duration 550

@@ -195,7 +195,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless4/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved

@@ -296,7 +296,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )
@@ -87,7 +87,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -84,7 +84,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -78,7 +78,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -115,7 +115,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless5/exp",
        help="The experiment dir",
    )
@@ -20,7 +20,7 @@
To run this file, do:

    cd icefall/egs/librispeech/ASR
    python ./pruned_transducer_stateless4/test_model.py
    python ./pruned_transducer_stateless5/test_model.py
"""

from train import get_params, get_transducer_model

@@ -328,7 +328,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )
@@ -20,23 +20,23 @@
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless2/export.py \
  --exp-dir ./pruned_transducer_stateless2/exp \
./pruned_transducer_stateless6/export.py \
  --exp-dir ./pruned_transducer_stateless6/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10

It will generate a file exp_dir/pretrained.pt

To use the generated file with `pruned_transducer_stateless2/decode.py`,
To use the generated file with `pruned_transducer_stateless6/decode.py`,
you can do:

    cd /path/to/exp_dir
    ln -s pretrained.pt epoch-9999.pt

    cd /path/to/egs/librispeech/ASR
    ./pruned_transducer_stateless2/decode.py \
        --exp-dir ./pruned_transducer_stateless2/exp \
    ./pruned_transducer_stateless6/decode.py \
        --exp-dir ./pruned_transducer_stateless6/exp \
        --epoch 9999 \
        --avg 1 \
        --max-duration 100 \

@@ -65,7 +65,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -91,7 +91,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless6/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
@@ -267,7 +267,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )
@@ -94,10 +94,10 @@ Usage:
    --max-states 64

(8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
./pruned_transducer_stateless7/decode.py \
    --epoch 35 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless5/exp \
    --exp-dir ./pruned_transducer_stateless7/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search_lm_shallow_fusion \
    --beam-size 4 \

@@ -110,11 +110,11 @@ Usage:
    --rnn-lm-tie-weights 1

(9) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless5/decode.py \
./pruned_transducer_stateless7/decode.py \
    --epoch 28 \
    --avg 15 \
    --max-duration 600 \
    --exp-dir ./pruned_transducer_stateless5/exp \
    --exp-dir ./pruned_transducer_stateless7/exp \
    --decoding-method modified_beam_search_LODR \
    --beam-size 4 \
    --lm-type rnn \
@@ -79,7 +79,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -116,7 +116,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless5/exp",
        default="pruned_transducer_stateless7/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
@@ -389,7 +389,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )
@@ -50,7 +50,6 @@ import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2

@@ -89,6 +88,7 @@ from icefall.utils import (
    filter_uneven_sized_batch,
    setup_logger,
    str2bool,
    symlink_or_copy,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -340,7 +340,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

@@ -601,7 +601,8 @@ def save_checkpoint(
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    epoch_basename = f"epoch-{params.cur_epoch}.pt"
    filename = params.exp_dir / epoch_basename
    save_checkpoint_impl(
        filename=filename,
        model=model,

@@ -615,12 +616,14 @@ def save_checkpoint(
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)
        symlink_or_copy(
            exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
        )

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)
        symlink_or_copy(
            exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
        )


def compute_loss(
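The two shutil.copyfile calls above are replaced by symlink_or_copy, imported from icefall.utils in the first hunk of this file, so best-train-loss.pt / best-valid-loss.pt can refer to the epoch checkpoint instead of duplicating it on disk. Its implementation is not part of this diff; judging only from the name and the call signature, a plausible behaviour is sketched below (an assumption, not the actual icefall helper):

import os
import shutil
from pathlib import Path


def symlink_or_copy_sketch(exp_dir: Path, src: str, dst: str) -> None:
    """Make exp_dir/dst point at exp_dir/src via a relative symlink,
    falling back to a plain copy where symlinks are unavailable.
    (Hypothetical sketch; the real helper lives in icefall.utils.)"""
    dst_path = exp_dir / dst
    if dst_path.is_symlink() or dst_path.exists():
        dst_path.unlink()
    try:
        os.symlink(src, dst_path)  # relative target inside exp_dir
    except OSError:
        shutil.copyfile(exp_dir / src, dst_path)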
@@ -346,7 +346,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

@@ -342,7 +342,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )
@@ -90,7 +90,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -88,7 +88,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -39,7 +39,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )
@@ -65,7 +65,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless7_streaming/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
@@ -77,7 +77,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -114,7 +114,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless7_streaming/exp",
        help="The experiment dir",
    )
@@ -355,7 +355,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

@@ -355,7 +355,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )
egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py (0 lines changed, Normal file → Executable file)
@@ -88,7 +88,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )
egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py (4 lines changed, Normal file → Executable file)
@@ -78,7 +78,7 @@ def get_parser():
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 0.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

@@ -115,7 +115,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        default="pruned_transducer_stateless7_streaming_multi/exp",
        help="The experiment dir",
    )
@@ -366,7 +366,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

@@ -348,7 +348,7 @@ def get_parser():
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 0.
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )
@@ -1127,7 +1127,16 @@ def run(rank, world_size, args):
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
    parameters_names = []
    parameters_names.append(
        [name_param_pair[0] for name_param_pair in model.named_parameters()]
    )
    optimizer = ScaledAdam(
        model.parameters(),
        lr=params.base_lr,
        clipping_scale=2.0,
        parameters_names=parameters_names,
    )

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
egs/librispeech/ASR/zipformer/export-onnx-streaming.py (new executable file, 775 lines)
@@ -0,0 +1,775 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script exports a transducer model from PyTorch to ONNX.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx-streaming.py \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal True \
  --chunk-size 16 \
  --left-context-frames 64

The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context`,
whose value is "64,128,256,-1".

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple

import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    add_model_arguments(parser)

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)

class OnnxEncoder(nn.Module):
    """A wrapper for Zipformer and the encoder_proj from the joiner"""

    def __init__(
        self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
    ):
        """
        Args:
          encoder:
            A Zipformer encoder.
          encoder_proj:
            The projection layer for encoder from the joiner.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.encoder_proj = encoder_proj
        self.chunk_size = encoder.chunk_size[0]
        self.left_context_len = encoder.left_context_frames[0]
        self.pad_length = 7 + 2 * 3

    def forward(
        self,
        x: torch.Tensor,
        states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        N = x.size(0)
        T = self.chunk_size * 2 + self.pad_length
        x_lens = torch.tensor([T] * N, device=x.device)
        left_context_len = self.left_context_len

        cached_embed_left_pad = states[-2]
        x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
            x=x,
            x_lens=x_lens,
            cached_left_pad=cached_embed_left_pad,
        )
        assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

        src_key_padding_mask = make_pad_mask(x_lens)

        # processed_mask is used to mask out initial states
        processed_mask = torch.arange(left_context_len, device=x.device).expand(
            x.size(0), left_context_len
        )
        processed_lens = states[-1]  # (batch,)
        # (batch, left_context_size)
        processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
        # Update processed lengths
        new_processed_lens = processed_lens + x_lens
        # (batch, left_context_size + chunk_size)
        src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

        x = x.permute(1, 0, 2)
        encoder_states = states[:-2]
        logging.info(f"len_encoder_states={len(encoder_states)}")
        (
            encoder_out,
            encoder_out_lens,
            new_encoder_states,
        ) = self.encoder.streaming_forward(
            x=x,
            x_lens=x_lens,
            states=encoder_states,
            src_key_padding_mask=src_key_padding_mask,
        )
        encoder_out = encoder_out.permute(1, 0, 2)
        encoder_out = self.encoder_proj(encoder_out)
        # Now encoder_out is of shape (N, T, joiner_dim)

        new_states = new_encoder_states + [
            new_cached_embed_left_pad,
            new_processed_lens,
        ]

        return encoder_out, new_states

    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[torch.Tensor]:
        """
        Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
        is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
        states[-2] is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        states[-1] is processed_lens of shape (batch,), which records the number
        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
        """
        states = self.encoder.get_init_states(batch_size, device)

        embed_states = self.encoder_embed.get_init_states(batch_size, device)
        states.append(embed_states)

        processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
        states.append(processed_lens)

        return states


class OnnxDecoder(nn.Module):
    """A wrapper for Decoder and the decoder_proj from the joiner"""

    def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
        super().__init__()
        self.decoder = decoder
        self.decoder_proj = decoder_proj

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, context_size).
        Returns
          Return a 2-D tensor of shape (N, joiner_dim)
        """
        need_pad = False
        decoder_output = self.decoder(y, need_pad=need_pad)
        decoder_output = decoder_output.squeeze(1)
        output = self.decoder_proj(decoder_output)

        return output


class OnnxJoiner(nn.Module):
    """A wrapper for the joiner"""

    def __init__(self, output_linear: nn.Linear):
        super().__init__()
        self.output_linear = output_linear

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            A 2-D tensor of shape (N, joiner_dim)
          decoder_out:
            A 2-D tensor of shape (N, joiner_dim)
        Returns:
          Return a 2-D tensor of shape (N, vocab_size)
        """
        logit = encoder_out + decoder_out
        logit = self.output_linear(torch.tanh(logit))
        return logit

def export_encoder_model_onnx(
    encoder_model: OnnxEncoder,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    encoder_model.encoder.__class__.forward = (
        encoder_model.encoder.__class__.streaming_forward
    )

    decode_chunk_len = encoder_model.chunk_size * 2
    # The encoder_embed subsample features (T - 7) // 2
    # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
    T = decode_chunk_len + encoder_model.pad_length

    x = torch.rand(1, T, 80, dtype=torch.float32)
    init_state = encoder_model.get_init_states()
    num_encoders = len(encoder_model.encoder.encoder_dim)
    logging.info(f"num_encoders: {num_encoders}")
    logging.info(f"len(init_state): {len(init_state)}")

    inputs = {}
    input_names = ["x"]

    outputs = {}
    output_names = ["encoder_out"]

    def build_inputs_outputs(tensors, i):
        assert len(tensors) == 6, len(tensors)

        # (downsample_left, batch_size, key_dim)
        name = f"cached_key_{i}"
        logging.info(f"{name}.shape: {tensors[0].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (1, batch_size, downsample_left, nonlin_attn_head_dim)
        name = f"cached_nonlin_attn_{i}"
        logging.info(f"{name}.shape: {tensors[1].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (downsample_left, batch_size, value_dim)
        name = f"cached_val1_{i}"
        logging.info(f"{name}.shape: {tensors[2].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (downsample_left, batch_size, value_dim)
        name = f"cached_val2_{i}"
        logging.info(f"{name}.shape: {tensors[3].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (batch_size, embed_dim, conv_left_pad)
        name = f"cached_conv1_{i}"
        logging.info(f"{name}.shape: {tensors[4].shape}")
        inputs[name] = {0: "N"}
        outputs[f"new_{name}"] = {0: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (batch_size, embed_dim, conv_left_pad)
        name = f"cached_conv2_{i}"
        logging.info(f"{name}.shape: {tensors[5].shape}")
        inputs[name] = {0: "N"}
        outputs[f"new_{name}"] = {0: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

    num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
    encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
    cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
    ds = encoder_model.encoder.downsampling_factor
    left_context_len = encoder_model.left_context_len
    left_context_len = [left_context_len // k for k in ds]
    left_context_len = ",".join(map(str, left_context_len))
    query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
    value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
    num_heads = ",".join(map(str, encoder_model.encoder.num_heads))

    meta_data = {
        "model_type": "zipformer2",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "streaming zipformer2",
        "decode_chunk_len": str(decode_chunk_len),  # 32
        "T": str(T),  # 32+7+2*3=45
        "num_encoder_layers": num_encoder_layers,
        "encoder_dims": encoder_dims,
        "cnn_module_kernels": cnn_module_kernels,
        "left_context_len": left_context_len,
        "query_head_dims": query_head_dims,
        "value_head_dims": value_head_dims,
        "num_heads": num_heads,
    }
    logging.info(f"meta_data: {meta_data}")

    for i in range(len(init_state[:-2]) // 6):
        build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)

    # (batch_size, channels, left_pad, freq)
    embed_states = init_state[-2]
    name = "embed_states"
    logging.info(f"{name}.shape: {embed_states.shape}")
    inputs[name] = {0: "N"}
    outputs[f"new_{name}"] = {0: "N"}
    input_names.append(name)
    output_names.append(f"new_{name}")

    # (batch_size,)
    processed_lens = init_state[-1]
    name = "processed_lens"
    logging.info(f"{name}.shape: {processed_lens.shape}")
    inputs[name] = {0: "N"}
    outputs[f"new_{name}"] = {0: "N"}
    input_names.append(name)
    output_names.append(f"new_{name}")

    logging.info(inputs)
    logging.info(outputs)
    logging.info(input_names)
    logging.info(output_names)

    torch.onnx.export(
        encoder_model,
        (x, init_state),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "x": {0: "N"},
            "encoder_out": {0: "N"},
            **inputs,
            **outputs,
        },
    )

    add_meta_data(filename=encoder_filename, meta_data=meta_data)

def export_decoder_model_onnx(
    decoder_model: OnnxDecoder,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has one input:

        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    and has one output:

        - decoder_out: a torch.float32 tensor of shape (N, joiner_dim)

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    context_size = decoder_model.decoder.context_size
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    torch.onnx.export(
        decoder_model,
        y,
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )

    meta_data = {
        "context_size": str(context_size),
        "vocab_size": str(vocab_size),
    }
    add_meta_data(filename=decoder_filename, meta_data=meta_data)


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.
    The exported joiner model has two inputs:

        - encoder_out: a tensor of shape (N, joiner_dim)
        - decoder_out: a tensor of shape (N, joiner_dim)

    and produces one output:

        - logit: a tensor of shape (N, vocab_size)
    """
    joiner_dim = joiner_model.output_linear.weight.shape[1]
    logging.info(f"joiner dim: {joiner_dim}")

    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)

    torch.onnx.export(
        joiner_model,
        (projected_encoder_out, projected_decoder_out),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "encoder_out",
            "decoder_out",
        ],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    meta_data = {
        "joiner_dim": str(joiner_dim),
    }
    add_meta_data(filename=joiner_filename, meta_data=meta_data)

@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True)

    encoder = OnnxEncoder(
        encoder=model.encoder,
        encoder_embed=model.encoder_embed,
        encoder_proj=model.joiner.encoder_proj,
    )

    decoder = OnnxDecoder(
        decoder=model.decoder,
        decoder_proj=model.joiner.decoder_proj,
    )

    joiner = OnnxJoiner(output_linear=model.joiner.output_linear)

    encoder_num_param = sum([p.numel() for p in encoder.parameters()])
    decoder_num_param = sum([p.numel() for p in decoder.parameters()])
    joiner_num_param = sum([p.numel() for p in joiner.parameters()])
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"

    opset_version = 13

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
    export_encoder_model_onnx(
        encoder,
        encoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported encoder to {encoder_filename}")

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
    export_decoder_model_onnx(
        decoder,
        decoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported decoder to {decoder_filename}")

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
    export_joiner_model_onnx(
        joiner,
        joiner_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported joiner to {joiner_filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=encoder_filename,
        model_output=encoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=decoder_filename,
        model_output=decoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=joiner_filename,
        model_output=joiner_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
624
egs/librispeech/ASR/zipformer/export-onnx.py
Executable file
@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
\
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal False \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1"
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||
use the exported ONNX models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool, make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
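A minimal sketch, assuming onnxruntime is installed, of how the metadata written by add_meta_data() can be read back at load time (the filename is just an example taken from the docstring above):

import onnxruntime as ort

session = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
meta = session.get_modelmeta().custom_metadata_map  # plain dict of str -> str
print(meta.get("model_type"), meta.get("comment"))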
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(
|
||||
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Zipformer encoder.
|
||||
encoder_embed:
The embedding module that subsamples the input features before the encoder.
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Please see the help information of Zipformer.forward
|
||||
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2)
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2)
|
||||
encoder_out = self.encoder_proj(encoder_out)
|
||||
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer2",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "non-streaming zipformer2",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
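As a rough sanity check of the export above, the encoder can be loaded and run with onnxruntime. This sketch assumes the file was saved under the default suffix used later in this script, and feeds zero features of shape (1, 100, 80), matching the dummy input used for tracing:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
encoder_out, encoder_out_lens = sess.run(
    ["encoder_out", "encoder_out_lens"],
    {
        "x": np.zeros((1, 100, 80), dtype=np.float32),
        "x_lens": np.array([100], dtype=np.int64),
    },
)
print(encoder_out.shape, encoder_out_lens)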
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
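A similar sketch for the decoder just exported; context_size is recovered from the metadata added above, and the filename is again only an example:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("decoder-epoch-99-avg-1.onnx")
context_size = int(sess.get_modelmeta().custom_metadata_map["context_size"])
y = np.zeros((1, context_size), dtype=np.int64)
decoder_out = sess.run(["decoder_out"], {"y": y})[0]
print(decoder_out.shape)  # (1, joiner_dim)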
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_embed=model.encoder_embed,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=encoder_filename,
|
||||
model_output=encoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=decoder_filename,
|
||||
model_output=decoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=joiner_filename,
|
||||
model_output=joiner_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
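An optional sketch for comparing the float32 and int8 file sizes produced above; dynamic quantization of the MatMul weights usually shrinks the files noticeably, though the exact ratio depends on the model:

import os

for p, p_int8 in [
    (encoder_filename, encoder_filename_int8),
    (decoder_filename, decoder_filename_int8),
    (joiner_filename, joiner_filename_int8),
]:
    logging.info(f"{p}: {os.path.getsize(p)} -> {os.path.getsize(p_int8)} bytes")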
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -54,7 +54,7 @@ class AsrModel(nn.Module):
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
|
544
egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
Executable file
@ -0,0 +1,544 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script loads ONNX models exported by ./export-onnx-streaming.py
|
||||
and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx-streaming.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal True \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 64
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file with the exported ONNX models
|
||||
|
||||
./zipformer/onnx_pretrained-streaming.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
Note: Even though this script only supports decoding a single file,
|
||||
the exported ONNX models do support batch processing.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
def init_encoder_states(self, batch_size: int = 1):
|
||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
logging.info(f"encoder_meta={encoder_meta}")
|
||||
|
||||
model_type = encoder_meta["model_type"]
|
||||
assert model_type == "zipformer2", model_type
|
||||
|
||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||
T = int(encoder_meta["T"])
|
||||
|
||||
num_encoder_layers = encoder_meta["num_encoder_layers"]
|
||||
encoder_dims = encoder_meta["encoder_dims"]
|
||||
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
|
||||
left_context_len = encoder_meta["left_context_len"]
|
||||
query_head_dims = encoder_meta["query_head_dims"]
|
||||
value_head_dims = encoder_meta["value_head_dims"]
|
||||
num_heads = encoder_meta["num_heads"]
|
||||
|
||||
def to_int_list(s):
|
||||
return list(map(int, s.split(",")))
|
||||
|
||||
num_encoder_layers = to_int_list(num_encoder_layers)
|
||||
encoder_dims = to_int_list(encoder_dims)
|
||||
cnn_module_kernels = to_int_list(cnn_module_kernels)
|
||||
left_context_len = to_int_list(left_context_len)
|
||||
query_head_dims = to_int_list(query_head_dims)
|
||||
value_head_dims = to_int_list(value_head_dims)
|
||||
num_heads = to_int_list(num_heads)
|
||||
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"T: {T}")
|
||||
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||
logging.info(f"encoder_dims: {encoder_dims}")
|
||||
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
|
||||
logging.info(f"left_context_len: {left_context_len}")
|
||||
logging.info(f"query_head_dims: {query_head_dims}")
|
||||
logging.info(f"value_head_dims: {value_head_dims}")
|
||||
logging.info(f"num_heads: {num_heads}")
|
||||
|
||||
num_encoders = len(num_encoder_layers)
|
||||
|
||||
self.states = []
|
||||
for i in range(num_encoders):
|
||||
num_layers = num_encoder_layers[i]
|
||||
key_dim = query_head_dims[i] * num_heads[i]
|
||||
embed_dim = encoder_dims[i]
|
||||
nonlin_attn_head_dim = 3 * embed_dim // 4
|
||||
value_dim = value_head_dims[i] * num_heads[i]
|
||||
conv_left_pad = cnn_module_kernels[i] // 2
|
||||
|
||||
for layer in range(num_layers):
|
||||
cached_key = torch.zeros(
|
||||
left_context_len[i], batch_size, key_dim
|
||||
).numpy()
|
||||
cached_nonlin_attn = torch.zeros(
|
||||
1, batch_size, left_context_len[i], nonlin_attn_head_dim
|
||||
).numpy()
|
||||
cached_val1 = torch.zeros(
|
||||
left_context_len[i], batch_size, value_dim
|
||||
).numpy()
|
||||
cached_val2 = torch.zeros(
|
||||
left_context_len[i], batch_size, value_dim
|
||||
).numpy()
|
||||
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
|
||||
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
|
||||
self.states += [
|
||||
cached_key,
|
||||
cached_nonlin_attn,
|
||||
cached_val1,
|
||||
cached_val2,
|
||||
cached_conv1,
|
||||
cached_conv2,
|
||||
]
|
||||
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
|
||||
self.states.append(embed_states)
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
|
||||
self.states.append(processed_lens)
|
||||
|
||||
self.num_encoders = num_encoders
|
||||
|
||||
self.segment = T
|
||||
self.offset = decode_chunk_len
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def _build_encoder_input_output(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||
encoder_input = {"x": x.numpy()}
|
||||
encoder_output = ["encoder_out"]
|
||||
|
||||
def build_inputs_outputs(tensors, i):
|
||||
assert len(tensors) == 6, len(tensors)
|
||||
|
||||
# (downsample_left, batch_size, key_dim)
|
||||
name = f"cached_key_{i}"
|
||||
encoder_input[name] = tensors[0]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
|
||||
name = f"cached_nonlin_attn_{i}"
|
||||
encoder_input[name] = tensors[1]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f"cached_val1_{i}"
|
||||
encoder_input[name] = tensors[2]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f"cached_val2_{i}"
|
||||
encoder_input[name] = tensors[3]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f"cached_conv1_{i}"
|
||||
encoder_input[name] = tensors[4]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f"cached_conv2_{i}"
|
||||
encoder_input[name] = tensors[5]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
for i in range(len(self.states[:-2]) // 6):
|
||||
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
|
||||
|
||||
# (batch_size, channels, left_pad, freq)
|
||||
name = "embed_states"
|
||||
embed_states = self.states[-2]
|
||||
encoder_input[name] = embed_states
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size,)
|
||||
name = "processed_lens"
|
||||
processed_lens = self.states[-1]
|
||||
encoder_input[name] = processed_lens
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
return encoder_input, encoder_output
|
||||
|
||||
def _update_states(self, states: List[np.ndarray]):
|
||||
self.states = states
|
||||
|
||||
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||
T' is usually equal to ((T-7)//2+1)//2
|
||||
"""
|
||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||
|
||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||
|
||||
self._update_states(out[1:])
|
||||
|
||||
return torch.from_numpy(out[0])
|
||||
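A worked instance of the frame-count relation quoted above, using T = 45 purely as an example chunk length:

T = 45  # example number of input feature frames
T_prime = ((T - 7) // 2 + 1) // 2
assert T_prime == 10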
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
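A tiny usage sketch; the path is a placeholder for any 16 kHz wav file:

waves = read_sound_files(filenames=["/path/to/foo.wav"], expected_sample_rate=16000)
print(waves[0].shape, waves[0].dtype)  # 1-D tensor, torch.float32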
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
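A minimal sketch of driving the extractor outside of main(), assuming kaldifeat is installed; one second of zeros stands in for real audio:

import torch

fbank = create_streaming_feature_extractor()
fbank.accept_waveform(sampling_rate=16000, waveform=torch.zeros(16000, dtype=torch.float32))
num_ready = fbank.num_frames_ready
frames = [fbank.get_frame(i) for i in range(num_ready)]
print(num_ready, frames[0].shape if frames else None)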
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
context_size: int,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
) -> Tuple[List[int], torch.Tensor]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (1, T, joiner_dim)
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
decoder_out:
|
||||
Optional. Decoder output of the previous chunk.
|
||||
hyp:
|
||||
Decoding results for previous chunks.
|
||||
Returns:
|
||||
Return the decoded results so far and the decoder output of the last token.
|
||||
"""
|
||||
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
else:
|
||||
assert hyp is not None, hyp
|
||||
|
||||
encoder_out = encoder_out.squeeze(0)
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t : t + 1]
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
waves = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
wave_samples = torch.cat([waves, tail_padding])
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.segment
|
||||
offset = model.offset
|
||||
|
||||
context_size = model.context_size
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
chunk = int(1 * sample_rate) # 1 second
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
frames = frames.unsqueeze(0)
|
||||
encoder_out = model.run_encoder(frames)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model,
|
||||
encoder_out,
|
||||
context_size,
|
||||
decoder_out,
|
||||
hyp,
|
||||
)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(args.sound_file)
|
||||
logging.info(text)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/librispeech/ASR/zipformer/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/onnx_pretrained.py
|
@ -26,6 +26,18 @@ import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
|
||||
# 14 is not supported. Please feel free to request support or submit
|
||||
# a pull request on PyTorch GitHub.
|
||||
#
|
||||
# The following function is to solve the above error when exporting
|
||||
# models to ONNX via torch.jit.trace()
|
||||
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
return torch.logaddexp(x, y)
|
||||
else:
|
||||
return (x.exp() + y.exp()).log()
|
||||
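A quick check, for moderate inputs, that the traced fallback agrees with torch.logaddexp; the exp/log form can overflow for large magnitudes, but it is only used while tracing for export:

import torch

x, y = torch.randn(4), torch.randn(4)
assert torch.allclose(torch.logaddexp(x, y), (x.exp() + y.exp()).log(), atol=1e-6)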
|
||||
class PiecewiseLinear(object):
|
||||
"""
|
||||
Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
|
||||
@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
|
||||
|
||||
def __float__(self):
|
||||
batch_count = self.batch_count
|
||||
if batch_count is None or not self.training or torch.jit.is_scripting():
|
||||
if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return float(self.default)
|
||||
else:
|
||||
ans = self.schedule(self.batch_count)
|
||||
@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
def softmax(x: Tensor, dim: int):
|
||||
if not x.requires_grad or torch.jit.is_scripting():
|
||||
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x.softmax(dim=dim)
|
||||
|
||||
return SoftmaxFunction.apply(x, dim)
|
||||
@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
|
||||
self.alpha = alpha
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return x
|
||||
return scale_grad(x, self.alpha)
|
||||
|
||||
@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
|
||||
|
||||
|
||||
def _no_op(x: Tensor) -> Tensor:
|
||||
if (torch.jit.is_scripting()):
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
# a no-op function that will have a node in the autograd graph,
|
||||
@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swoosh-L activation.
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||
if not x.requires_grad:
|
||||
return k2.swoosh_l_forward(x)
|
||||
else:
|
||||
@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swoosh-R activation.
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||
return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||
if not x.requires_grad:
|
||||
return k2.swoosh_r_forward(x)
|
||||
else:
|
||||
|
@ -27,6 +27,7 @@ from typing import List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
|
||||
from zipformer import CompactRelPositionalEncoding
|
||||
|
||||
|
||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
is_pnnx: bool = False,
|
||||
is_onnx: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
is_pnnx:
|
||||
True if we are going to export the model for PNNX.
|
||||
is_onnx:
|
||||
True if we are going to export the model for ONNX.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
|
||||
d[name] = nn.Identity()
|
||||
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
|
||||
# We want to recreate the positional encoding vector when
|
||||
# the input changes, so we have to use torch.jit.script()
|
||||
# to replace torch.jit.trace()
|
||||
d[name] = torch.jit.script(m)
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
|
@ -81,7 +81,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
|
@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return self.forward_internal(x)
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
|
||||
@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
x = self.out_norm(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
x_lens = (x_lens - 7) // 2
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
|
@ -433,7 +433,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
|
@ -133,6 +133,7 @@ class Zipformer2(EncoderInterface):
|
||||
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
|
||||
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
|
||||
num_encoder_layers = _to_tuple(num_encoder_layers)
|
||||
self.num_encoder_layers = num_encoder_layers
|
||||
self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
|
||||
self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
|
||||
pos_head_dim = _to_tuple(pos_head_dim)
|
||||
@ -258,7 +259,7 @@ class Zipformer2(EncoderInterface):
|
||||
if not self.causal:
|
||||
return -1, -1
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
assert len(self.chunk_size) == 1, self.chunk_size
|
||||
chunk_size = self.chunk_size[0]
|
||||
else:
|
||||
@ -267,7 +268,7 @@ class Zipformer2(EncoderInterface):
|
||||
if chunk_size == -1:
|
||||
left_context_chunks = -1
|
||||
else:
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
assert len(self.left_context_frames) == 1, self.left_context_frames
|
||||
left_context_frames = self.left_context_frames[0]
|
||||
else:
|
||||
@ -301,14 +302,14 @@ class Zipformer2(EncoderInterface):
|
||||
of frames in `embeddings` before padding.
|
||||
"""
|
||||
outputs = []
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
feature_masks = [1.0] * len(self.encoder_dim)
|
||||
else:
|
||||
feature_masks = self.get_feature_masks(x)
|
||||
|
||||
chunk_size, left_context_chunks = self.get_chunk_info()
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
# Exporting a model for simulating streaming decoding is not supported
|
||||
attn_mask = None
|
||||
else:
|
||||
@ -334,7 +335,7 @@ class Zipformer2(EncoderInterface):
|
||||
x = self.downsample_output(x)
|
||||
# class Downsample has this rounding behavior..
|
||||
assert self.output_downsampling_factor == 2
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
lengths = (x_lens + 1) // 2
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
@ -372,7 +373,7 @@ class Zipformer2(EncoderInterface):
|
||||
# t is frame index, shape (seq_len,)
|
||||
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
|
||||
# c is chunk index for each frame, shape (seq_len,)
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
c = t // chunk_size
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
@ -650,7 +651,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
|
||||
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
|
||||
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return None
|
||||
batch_size = x.shape[1]
|
||||
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
|
||||
@ -695,7 +696,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
src_orig = src
|
||||
|
||||
# dropout rate for non-feedforward submodules
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
attention_skip_rate = 0.0
|
||||
else:
|
||||
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
|
||||
@ -713,7 +714,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
|
||||
|
||||
selected_attn_weights = attn_weights[0:1]
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
pass
|
||||
elif not self.training and random.random() < float(self.const_attention_rate):
|
||||
# Make attention weights constant. The intention is to
|
||||
@ -732,7 +733,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
|
||||
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
conv_skip_rate = 0.0
|
||||
else:
|
||||
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
|
||||
@ -740,7 +741,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask),
|
||||
conv_skip_rate)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
ff2_skip_rate = 0.0
|
||||
else:
|
||||
ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
|
||||
@ -754,7 +755,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
|
||||
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
conv_skip_rate = 0.0
|
||||
else:
|
||||
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
|
||||
@ -762,7 +763,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask),
|
||||
conv_skip_rate)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
ff3_skip_rate = 0.0
|
||||
else:
|
||||
ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
|
||||
@ -968,7 +969,7 @@ class Zipformer2Encoder(nn.Module):
|
||||
pos_emb = self.encoder_pos(src)
|
||||
output = src
|
||||
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
output = output * feature_mask
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
@ -980,7 +981,7 @@ class Zipformer2Encoder(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
output = output * feature_mask
|
||||
|
||||
return output
|
||||
@ -1073,7 +1074,7 @@ class BypassModule(nn.Module):
|
||||
# or (batch_size, num_channels,). This is actually the
|
||||
# scale on the non-residual term, so 0 corresponds to bypassing
|
||||
# this module.
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return self.bypass_scale
|
||||
else:
|
||||
ans = limit_param_value(self.bypass_scale,
|
||||
@ -1229,12 +1230,11 @@ class SimpleDownsample(torch.nn.Module):
|
||||
d_seq_len = (seq_len + ds - 1) // ds
|
||||
|
||||
# Pad to an exact multiple of self.downsample
|
||||
if seq_len != d_seq_len * ds:
|
||||
# right-pad src, repeating the last element.
|
||||
pad = d_seq_len * ds - seq_len
|
||||
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||
src = torch.cat((src, src_extra), dim=0)
|
||||
assert src.shape[0] == d_seq_len * ds
|
||||
# right-pad src, repeating the last element.
|
||||
pad = d_seq_len * ds - seq_len
|
||||
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||
src = torch.cat((src, src_extra), dim=0)
|
||||
assert src.shape[0] == d_seq_len * ds
|
||||
|
||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||
|
||||
@ -1322,11 +1322,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(0) >= T * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
|
||||
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
||||
@ -1524,7 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
attn_scores = torch.matmul(q, k)
|
||||
|
||||
use_pos_scores = False
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
# We can't put random.random() in the same line
|
||||
use_pos_scores = True
|
||||
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
||||
@ -1542,16 +1538,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||
# not, but let this code define which way round it is supposed to be.
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
if torch.jit.is_tracing():
|
||||
(num_heads, batch_size, time1, n) = pos_scores.shape
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(seq_len)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
pos_scores = pos_scores.reshape(-1, n)
|
||||
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
|
||||
else:
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
|
||||
attn_scores = attn_scores + pos_scores
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
pass
|
||||
elif self.training and random.random() < 0.1:
|
||||
# This is a harder way of limiting the attention scores to not be
|
||||
@ -1594,7 +1600,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        # half-precision output for backprop purposes.
        attn_weights = softmax(attn_scores, dim=-1)

        if torch.jit.is_scripting():
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif random.random() < 0.001 and not self.training:
            self._print_attn_entropy(attn_weights)
@ -1672,15 +1678,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
            # [where seq_len2 represents relative position.]
            pos_scores = torch.matmul(p, pos_emb)

            if torch.jit.is_tracing():
                (num_heads, batch_size, time1, n) = pos_scores.shape
                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
                cols = torch.arange(k_len)
                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
                indexes = rows + cols
                pos_scores = pos_scores.reshape(-1, n)
                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
            # the following .as_strided() expression converts the last axis of pos_scores from relative
            # to absolute position. I don't know whether I might have got the time-offsets backwards or
            # not, but let this code define which way round it is supposed to be.
            pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
                                               (pos_scores.stride(0),
                                                pos_scores.stride(1),
                                                pos_scores.stride(2)-pos_scores.stride(3),
                                                pos_scores.stride(3)),
                                               storage_offset=pos_scores.stride(3) * (seq_len - 1))
            else:
                pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
                                                   (pos_scores.stride(0),
                                                    pos_scores.stride(1),
                                                    pos_scores.stride(2)-pos_scores.stride(3),
                                                    pos_scores.stride(3)),
                                                   storage_offset=pos_scores.stride(3) * (seq_len - 1))

            attn_scores = attn_scores + pos_scores

@ -2136,7 +2153,7 @@ class ConvolutionModule(nn.Module):
        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        if not torch.jit.is_scripting() and chunk_size >= 0:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
            # Not support exporting a model for simulated streaming decoding
            assert self.causal, "Must initialize model with causal=True if you use chunk_size"
            x = self.depthwise_conv(x, chunk_size=chunk_size)
1
egs/must_c/ST/local/compute_fbank_musan.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py
155
egs/must_c/ST/local/compute_fbank_must_c.py
Executable file
@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This file computes fbank features of the MuST-C dataset.
|
||||
It looks for manifests in the directory "in_dir" and write
|
||||
generated features to "out_dir".
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
FeatureSet,
|
||||
LilcomChunkyWriter,
|
||||
load_manifest,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--in-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Input manifest directory",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Output directory where generated fbank features are saved.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tgt-lang",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Target language, e.g., zh, de, fr.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-jobs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of jobs for computing features",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--perturb-speed",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to enable speed perturb with factors 0.9 and 1.1 on
|
||||
the train subset. False (by default) to disable speed perturb.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_fbank_must_c(
|
||||
in_dir: Path,
|
||||
out_dir: Path,
|
||||
tgt_lang: str,
|
||||
num_jobs: int,
|
||||
perturb_speed: bool,
|
||||
):
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||
|
||||
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
|
||||
|
||||
prefix = "must_c"
|
||||
suffix = "jsonl.gz"
|
||||
for p in parts:
|
||||
logging.info(f"Processing {p}")
|
||||
|
||||
cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
|
||||
if perturb_speed and p == "train":
|
||||
cuts_path += "_sp"
|
||||
|
||||
cuts_path += ".jsonl.gz"
|
||||
|
||||
if Path(cuts_path).is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
continue
|
||||
|
||||
recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.jsonl.gz"
|
||||
supervisions_filename = (
|
||||
in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.jsonl.gz"
|
||||
)
|
||||
assert recordings_filename.is_file(), recordings_filename
|
||||
assert supervisions_filename.is_file(), supervisions_filename
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=load_manifest(recordings_filename),
|
||||
supervisions=load_manifest(supervisions_filename),
|
||||
)
|
||||
if perturb_speed and p == "train":
|
||||
logging.info("Speed perturbing for the train dataset")
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp"
|
||||
else:
|
||||
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
|
||||
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=storage_path,
|
||||
num_jobs=num_jobs,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
logging.info("About to split cuts into smaller chunks.")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
logging.info(f"Saved to {cuts_path}")
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
logging.info(vars(args))
|
||||
assert args.in_dir.is_dir(), args.in_dir
|
||||
|
||||
compute_fbank_must_c(
|
||||
in_dir=args.in_dir,
|
||||
out_dir=args.out_dir,
|
||||
tgt_lang=args.tgt_lang,
|
||||
num_jobs=args.num_jobs,
|
||||
perturb_speed=args.perturb_speed,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
34
egs/must_c/ST/local/get_text.py
Executable file
@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file prints the text field of supervisions from cutset to the console
"""

import argparse

from lhotse import load_manifest_lazy
from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "manifest",
        type=Path,
        help="Input manifest",
    )
    return parser.parse_args()


def main():
    args = get_args()
    assert args.manifest.is_file(), args.manifest

    cutset = load_manifest_lazy(args.manifest)
    for c in cutset:
        for sup in c.supervisions:
            print(sup.text)


if __name__ == "__main__":
    main()
48
egs/must_c/ST/local/get_words.py
Executable file
@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file generates words.txt from the given transcript file.
"""

import argparse

from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "transcript",
        type=Path,
        help="Input transcript file",
    )
    return parser.parse_args()


def main():
    args = get_args()
    assert args.transcript.is_file(), args.transcript

    word_set = set()
    with open(args.transcript) as f:
        for line in f:
            words = line.strip().split()
            for w in words:
                word_set.add(w)

    # Note: reserved* should be kept in sync with ./local/prepare_lang_bpe.py
    reserved1 = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
    reserved2 = ["#0", "<s>", "</s>"]

    for w in reserved1 + reserved2:
        assert w not in word_set, w

    words = sorted(list(word_set))
    words = reserved1 + words + reserved2

    for i, w in enumerate(words):
        print(w, i)


if __name__ == "__main__":
    main()
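For illustration only (not part of the recipe): if the transcript contained just the two words "hello" and "world", the words.txt printed by the script above would look like the following sketch, with the reserved symbols placed before and after the sorted words:

    <eps> 0
    !SIL 1
    <SPOKEN_NOISE> 2
    <UNK> 3
    hello 4
    world 5
    #0 6
    <s> 7
    </s> 8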
169
egs/must_c/ST/local/normalize_punctuation.py
Normal file
@ -0,0 +1,169 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
import re
|
||||
|
||||
|
||||
def normalize_punctuation(s: str, lang: str) -> str:
|
||||
"""
|
||||
This function implements
|
||||
https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/normalize-punctuation.perl
|
||||
|
||||
Args:
|
||||
s:
|
||||
A string to be normalized.
|
||||
lang:
|
||||
The language to which `s` belongs
|
||||
Returns:
|
||||
Return a normalized string.
|
||||
"""
|
||||
# s/\r//g;
|
||||
s = re.sub("\r", "", s)
|
||||
|
||||
# remove extra spaces
|
||||
# s/\(/ \(/g;
|
||||
s = re.sub("\(", " (", s) # add a space before (
|
||||
|
||||
# s/\)/\) /g; s/ +/ /g;
|
||||
s = re.sub("\)", ") ", s) # add a space after )
|
||||
s = re.sub(" +", " ", s) # convert multiple spaces to one
|
||||
|
||||
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
|
||||
s = re.sub("\) ([\.\!\:\?\;\,])", r")\1", s)
|
||||
|
||||
# s/\( /\(/g;
|
||||
s = re.sub("\( ", "(", s) # remove space after (
|
||||
|
||||
# s/ \)/\)/g;
|
||||
s = re.sub(" \)", ")", s) # remove space before )
|
||||
|
||||
# s/(\d) \%/$1\%/g;
|
||||
s = re.sub("(\d) \%", r"\1%", s) # remove space between a digit and %
|
||||
|
||||
# s/ :/:/g;
|
||||
s = re.sub(" :", ":", s) # remove space before :
|
||||
|
||||
# s/ ;/;/g;
|
||||
s = re.sub(" ;", ";", s) # remove space before ;
|
||||
|
||||
# normalize unicode punctuation
|
||||
# s/\`/\'/g;
|
||||
s = re.sub("`", "'", s) # replace ` with '
|
||||
|
||||
# s/\'\'/ \" /g;
|
||||
s = re.sub("''", '"', s) # replace '' with "
|
||||
|
||||
# s/„/\"/g;
|
||||
s = re.sub("„", '"', s) # replace „ with "
|
||||
|
||||
# s/“/\"/g;
|
||||
s = re.sub("“", '"', s) # replace “ with "
|
||||
|
||||
# s/”/\"/g;
|
||||
s = re.sub("”", '"', s) # replace ” with "
|
||||
|
||||
# s/–/-/g;
|
||||
s = re.sub("–", "-", s) # replace – with -
|
||||
|
||||
# s/—/ - /g; s/ +/ /g;
|
||||
s = re.sub("—", " - ", s)
|
||||
s = re.sub(" +", " ", s) # convert multiple spaces to one
|
||||
|
||||
# s/´/\'/g;
|
||||
s = re.sub("´", "'", s)
|
||||
|
||||
# s/([a-z])‘([a-z])/$1\'$2/gi;
|
||||
s = re.sub("([a-z])‘([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
|
||||
|
||||
# s/([a-z])’([a-z])/$1\'$2/gi;
|
||||
s = re.sub("([a-z])’([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
|
||||
|
||||
# s/‘/\'/g;
|
||||
s = re.sub("‘", "'", s)
|
||||
|
||||
# s/‚/\'/g;
|
||||
s = re.sub("‚", "'", s)
|
||||
|
||||
# s/’/\"/g;
|
||||
s = re.sub("’", '"', s)
|
||||
|
||||
# s/''/\"/g;
|
||||
s = re.sub("''", '"', s)
|
||||
|
||||
# s/´´/\"/g;
|
||||
s = re.sub("´´", '"', s)
|
||||
|
||||
# s/…/.../g;
|
||||
s = re.sub("…", "...", s)
|
||||
|
||||
# French quotes
|
||||
|
||||
# s/ « / \"/g;
|
||||
s = re.sub(" « ", ' "', s)
|
||||
|
||||
# s/« /\"/g;
|
||||
s = re.sub("« ", '"', s)
|
||||
|
||||
# s/«/\"/g;
|
||||
s = re.sub("«", '"', s)
|
||||
|
||||
# s/ » /\" /g;
|
||||
s = re.sub(" » ", '" ', s)
|
||||
|
||||
# s/ »/\"/g;
|
||||
s = re.sub(" »", '"', s)
|
||||
|
||||
# s/»/\"/g;
|
||||
s = re.sub("»", '"', s)
|
||||
|
||||
# handle pseudo-spaces
|
||||
|
||||
# s/ \%/\%/g;
|
||||
s = re.sub(" %", r"%", s)
|
||||
|
||||
# s/nº /nº /g;
|
||||
s = re.sub("nº ", "nº ", s)
|
||||
|
||||
# s/ :/:/g;
|
||||
s = re.sub(" :", ":", s)
|
||||
|
||||
# s/ ºC/ ºC/g;
|
||||
s = re.sub(" ºC", " ºC", s)
|
||||
|
||||
# s/ cm/ cm/g;
|
||||
s = re.sub(" cm", " cm", s)
|
||||
|
||||
# s/ \?/\?/g;
|
||||
s = re.sub(" \?", "\?", s)
|
||||
|
||||
# s/ \!/\!/g;
|
||||
s = re.sub(" \!", "\!", s)
|
||||
|
||||
# s/ ;/;/g;
|
||||
s = re.sub(" ;", ";", s)
|
||||
|
||||
# s/, /, /g; s/ +/ /g;
|
||||
s = re.sub(", ", ", ", s)
|
||||
s = re.sub(" +", " ", s)
|
||||
|
||||
if lang == "en":
|
||||
# English "quotation," followed by comma, style
|
||||
# s/\"([,\.]+)/$1\"/g;
|
||||
s = re.sub('"([,\.]+)', r'\1"', s)
|
||||
elif lang in ("cs", "cz"):
|
||||
# Czech is confused
|
||||
pass
|
||||
else:
|
||||
# German/Spanish/French "quotation", followed by comma, style
|
||||
# s/,\"/\",/g;
|
||||
s = re.sub(',"', '",', s)
|
||||
|
||||
# s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence
|
||||
s = re.sub('(\.+)"(\s*[^<])', r'"\1\2', s)
|
||||
|
||||
if lang in ("de", "es", "cz", "cs", "fr"):
|
||||
# s/(\d) (\d)/$1,$2/g;
|
||||
s = re.sub("(\d) (\d)", r"\1,\2", s)
|
||||
else:
|
||||
# s/(\d) (\d)/$1.$2/g;
|
||||
s = re.sub("(\d) (\d)", r"\1.\2", s)
|
||||
|
||||
return s
|
1
egs/must_c/ST/local/prepare_lang.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py
1
egs/must_c/ST/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py
96
egs/must_c/ST/local/preprocess_must_c.py
Executable file
@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
"""
|
||||
This script normalizes transcripts from supervisions.
|
||||
|
||||
Usage:
|
||||
./local/preprocess_must_c.py \
|
||||
--manifest-dir ./data/manifests/v1.0/ \
|
||||
--tgt-lang de
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from normalize_punctuation import normalize_punctuation
|
||||
from remove_non_native_characters import remove_non_native_characters
|
||||
from remove_punctuation import remove_punctuation
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Manifest directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tgt-lang",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Target language, e.g., zh, de, fr.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
|
||||
normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang)
|
||||
remove_non_native_characters_lang = partial(
|
||||
remove_non_native_characters, lang=tgt_lang
|
||||
)
|
||||
|
||||
prefix = "must_c"
|
||||
suffix = "jsonl.gz"
|
||||
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
|
||||
for p in parts:
|
||||
logging.info(f"Processing {p}")
|
||||
name = f"en-{tgt_lang}_{p}"
|
||||
|
||||
# norm: normalization
|
||||
# rm: remove punctuation
|
||||
dst_name = manifest_dir / f"must_c_supervisions_{name}_norm_rm.jsonl.gz"
|
||||
if dst_name.is_file():
|
||||
logging.info(f"{dst_name} exists - skipping")
|
||||
continue
|
||||
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=name,
|
||||
output_dir=manifest_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
types=("supervisions",),
|
||||
)
|
||||
if name not in manifests:
|
||||
raise RuntimeError(f"Processing {p} failed.")
|
||||
|
||||
supervisions = manifests[name]["supervisions"]
|
||||
supervisions = supervisions.transform_text(normalize_punctuation_lang)
|
||||
supervisions = supervisions.transform_text(remove_punctuation)
|
||||
supervisions = supervisions.transform_text(lambda x: x.lower())
|
||||
supervisions = supervisions.transform_text(remove_non_native_characters_lang)
|
||||
supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x))
|
||||
|
||||
supervisions.to_file(dst_name)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
logging.info(vars(args))
|
||||
assert args.manifest_dir.is_dir(), args.manifest_dir
|
||||
|
||||
preprocess_must_c(
|
||||
manifest_dir=args.manifest_dir,
|
||||
tgt_lang=args.tgt_lang,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
21
egs/must_c/ST/local/remove_non_native_characters.py
Executable file
@ -0,0 +1,21 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import re


def remove_non_native_characters(s: str, lang: str):
    if lang == "de":
        # ä -> ae
        # ö -> oe
        # ü -> ue
        # ß -> ss

        s = re.sub("ä", "ae", s)
        s = re.sub("ö", "oe", s)
        s = re.sub("ü", "ue", s)
        s = re.sub("ß", "ss", s)
        # keep only a-z and spaces
        # note: ' is removed
        s = re.sub(r"[^a-z\s]", "", s)

    return s
41
egs/must_c/ST/local/remove_punctuation.py
Normal file
@ -0,0 +1,41 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import re
import string


def remove_punctuation(s: str) -> str:
    """
    It implements https://github.com/espnet/espnet/blob/master/utils/remove_punctuation.pl
    """

    # Remove punctuation except apostrophe
    # s/<space>/spacemark/g; # for scoring
    s = re.sub("<space>", "spacemark", s)

    # s/'/apostrophe/g;
    s = re.sub("'", "apostrophe", s)

    # s/[[:punct:]]//g;
    s = s.translate(str.maketrans("", "", string.punctuation))
    # string punctuation returns the following string
    # !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
    # See
    # https://stackoverflow.com/questions/265960/best-way-to-strip-punctuation-from-a-string

    # s/apostrophe/'/g;
    s = re.sub("apostrophe", "'", s)

    # s/spacemark/<space>/g; # for scoring
    s = re.sub("spacemark", "<space>", s)

    # remove whitespace
    # s/\s+/ /g;
    s = re.sub("\s+", " ", s)

    # s/^\s+//;
    s = re.sub("^\s+", "", s)

    # s/\s+$//;
    s = re.sub("\s+$", "", s)

    return s
197
egs/must_c/ST/local/test_normalize_punctuation.py
Executable file
@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
from normalize_punctuation import normalize_punctuation
|
||||
|
||||
|
||||
def test_normalize_punctuation():
|
||||
# s/\r//g;
|
||||
s = "a\r\nb\r\n"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert "\r" not in n
|
||||
assert len(s) - 2 == len(n), (len(s), len(n))
|
||||
|
||||
# s/\(/ \(/g;
|
||||
s = "(ab (c"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == " (ab (c", n
|
||||
|
||||
# s/\)/\) /g;
|
||||
s = "a)b c)"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a) b c) "
|
||||
|
||||
# s/ +/ /g;
|
||||
s = " a b c d "
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == " a b c d "
|
||||
|
||||
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
|
||||
for i in ".!:?;,":
|
||||
s = f"a) {i}"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == f"a){i}"
|
||||
|
||||
# s/\( /\(/g;
|
||||
s = "a( b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a (b", n
|
||||
|
||||
# s/ \)/\)/g;
|
||||
s = "ab ) a"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "ab) a", n
|
||||
|
||||
# s/(\d) \%/$1\%/g;
|
||||
s = "1 %a"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "1%a", n
|
||||
|
||||
# s/ :/:/g;
|
||||
s = "a :"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a:", n
|
||||
|
||||
# s/ ;/;/g;
|
||||
s = "a ;"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a;", n
|
||||
|
||||
# s/\`/\'/g;
|
||||
s = "`a`"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "'a'", n
|
||||
|
||||
# s/\'\'/ \" /g;
|
||||
s = "''a''"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/„/\"/g;
|
||||
s = '„a"'
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/“/\"/g;
|
||||
s = "“a„"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/”/\"/g;
|
||||
s = "“a”"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/–/-/g;
|
||||
s = "a–b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a-b", n
|
||||
|
||||
# s/—/ - /g; s/ +/ /g;
|
||||
s = "a—b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a - b", n
|
||||
|
||||
# s/´/\'/g;
|
||||
s = "a´b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a'b", n
|
||||
|
||||
# s/([a-z])‘([a-z])/$1\'$2/gi;
|
||||
for i in "‘’":
|
||||
s = f"a{i}B"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a'B", n
|
||||
|
||||
s = f"A{i}B"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "A'B", n
|
||||
|
||||
s = f"A{i}b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "A'b", n
|
||||
|
||||
# s/‘/\'/g;
|
||||
# s/‚/\'/g;
|
||||
for i in "‘‚":
|
||||
s = f"a{i}b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a'b", n
|
||||
|
||||
# s/’/\"/g;
|
||||
s = "’"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/''/\"/g;
|
||||
s = "''"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/´´/\"/g;
|
||||
s = "´´"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/…/.../g;
|
||||
s = "…"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "...", n
|
||||
|
||||
# s/ « / \"/g;
|
||||
s = "a « b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == 'a "b', n
|
||||
|
||||
# s/« /\"/g;
|
||||
s = "a « b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == 'a "b', n
|
||||
|
||||
# s/«/\"/g;
|
||||
s = "a«b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == 'a"b', n
|
||||
|
||||
# s/ » /\" /g;
|
||||
s = " » "
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '" ', n
|
||||
|
||||
# s/ »/\"/g;
|
||||
s = " »"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/»/\"/g;
|
||||
s = "»"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/ \%/\%/g;
|
||||
s = " %"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "%", n
|
||||
|
||||
# s/ :/:/g;
|
||||
s = " :"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == ":", n
|
||||
|
||||
# s/(\d) (\d)/$1.$2/g;
|
||||
s = "2 3"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "2.3", n
|
||||
|
||||
# s/(\d) (\d)/$1,$2/g;
|
||||
s = "2 3"
|
||||
n = normalize_punctuation(s, lang="de")
|
||||
assert n == "2,3", n
|
||||
|
||||
|
||||
def main():
|
||||
test_normalize_punctuation()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
26
egs/must_c/ST/local/test_remove_non_native_characters.py
Executable file
@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

from remove_non_native_characters import remove_non_native_characters


def test_remove_non_native_characters():
    s = "Ich heiße xxx好的01 fangjun".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "ich heisse xxx fangjun", n

    s = "äÄ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "aeae", n

    s = "öÖ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "oeoe", n

    s = "üÜ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "ueue", n


if __name__ == "__main__":
    test_remove_non_native_characters()
17
egs/must_c/ST/local/test_remove_punctuation.py
Executable file
@ -0,0 +1,17 @@
#!/usr/bin/env python3

from remove_punctuation import remove_punctuation


def test_remove_punctuation():
    s = "a,b'c!#"
    n = remove_punctuation(s)
    assert n == "ab'c", n

    s = " ab "  # remove leading and trailing spaces
    n = remove_punctuation(s)
    assert n == "ab", n


if __name__ == "__main__":
    test_remove_punctuation()
1
egs/must_c/ST/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py
1
egs/must_c/ST/local/validate_bpe_lexicon.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py
173
egs/must_c/ST/prepare.sh
Executable file
@ -0,0 +1,173 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=10
|
||||
stage=0
|
||||
stop_stage=100
|
||||
|
||||
version=v1.0
|
||||
tgt_lang=de
|
||||
dl_dir=$PWD/download
|
||||
|
||||
must_c_dir=$dl_dir/must-c/$version/en-$tgt_lang/data
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files.
|
||||
# - $dl_dir/must-c/$version/en-$tgt_lang/data/{dev,train,tst-COMMON,tst-HE}
|
||||
#
|
||||
# Please go to https://ict.fbk.eu/must-c-releases/
|
||||
# to download and untar the dataset if you have not already done this.
|
||||
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate
|
||||
# data/lang_bpe_${tgt_lang}_xxx
|
||||
# data/lang_bpe_${tgt_lang}_yyy
|
||||
# if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
# 5000
|
||||
# 2000
|
||||
# 1000
|
||||
500
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ ! -d $must_c_dir ]; then
|
||||
log "$must_c_dir does not exist"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for d in dev train tst-COMMON tst-HE; do
|
||||
if [ ! -d $must_c_dir/$d ]; then
|
||||
log "$must_c_dir/$d does not exist!"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download musan"
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to $dl_dir/musan
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.musan.done ]; then
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
touch data/manifests/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare must-c $version manifest for target language $tgt_lang"
|
||||
mkdir -p data/manifests/$version
|
||||
if [ ! -e data/manifests/$version/.${tgt_lang}.manifests.done ]; then
|
||||
lhotse prepare must-c \
|
||||
-j $nj \
|
||||
--tgt-lang $tgt_lang \
|
||||
$dl_dir/must-c/$version/ \
|
||||
data/manifests/$version/
|
||||
|
||||
touch data/manifests/$version/.${tgt_lang}.manifests.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Text normalization for $version with target language $tgt_lang"
|
||||
if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then
|
||||
./local/preprocess_must_c.py \
|
||||
--manifest-dir ./data/manifests/$version/ \
|
||||
--tgt-lang $tgt_lang
|
||||
touch ./data/manifests/$version/.$tgt_lang.norm.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.musan.done ]; then
|
||||
./local/compute_fbank_musan.py
|
||||
touch data/fbank/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute fbank for $version with target language $tgt_lang"
|
||||
mkdir -p data/fbank/$version/
|
||||
if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then
|
||||
./local/compute_fbank_must_c.py \
|
||||
--in-dir ./data/manifests/$version/ \
|
||||
--out-dir ./data/fbank/$version/ \
|
||||
--tgt-lang $tgt_lang \
|
||||
--num-jobs $nj
|
||||
|
||||
./local/compute_fbank_must_c.py \
|
||||
--in-dir ./data/manifests/$version/ \
|
||||
--out-dir ./data/fbank/$version/ \
|
||||
--tgt-lang $tgt_lang \
|
||||
--num-jobs $nj \
|
||||
--perturb-speed 1
|
||||
|
||||
touch data/fbank/$version/.$tgt_lang.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare BPE based lang for $version with target language $tgt_lang"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/
|
||||
mkdir -p $lang_dir
|
||||
if [ ! -f $lang_dir/transcript_words.txt ]; then
|
||||
./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/words.txt ]; then
|
||||
./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
|
||||
log "Validating $lang_dir/lexicon.txt"
|
||||
./local/validate_bpe_lexicon.py \
|
||||
--lexicon $lang_dir/lexicon.txt \
|
||||
--bpe-model $lang_dir/bpe.model
|
||||
fi
|
||||
done
|
||||
fi
|
1
egs/must_c/ST/shared
Symbolic link
@ -0,0 +1 @@
../../../icefall/shared
@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
        group.add_argument(
            "--num-buckets",
            type=int,
            default=300,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler"
            "(you might want to increase it for larger datasets).",
        )
@ -364,7 +364,7 @@ class WenetSpeechAsrDataModule:
        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        logging.info("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
@ -92,7 +92,7 @@ When training with the L subset, the streaming usage:
|
||||
--causal-convolution 1 \
|
||||
--decode-chunk-size 16 \
|
||||
--left-context 64
|
||||
|
||||
|
||||
(4) modified beam search with RNNLM shallow fusion
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 35 \
|
||||
@ -112,8 +112,10 @@ When training with the L subset, the streaming usage:
|
||||
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -133,7 +135,8 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -307,6 +310,26 @@ def get_parser():
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-shallow-fusion",
|
||||
type=str2bool,
|
||||
@ -362,6 +385,7 @@ def decode_one_batch(
|
||||
lexicon: Lexicon,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
@ -402,14 +426,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
@ -448,6 +471,7 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
@ -509,7 +533,12 @@ def decode_one_batch(
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
key += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
key += "-no-context-words"
|
||||
return {key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -518,6 +547,7 @@ def decode_dataset(
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
@ -567,6 +597,7 @@ def decode_dataset(
|
||||
lexicon=lexicon,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
context_graph=context_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
@ -646,6 +677,12 @@ def main():
|
||||
"modified_beam_search_lm_shallow_fusion",
|
||||
"modified_beam_search_LODR",
|
||||
)
|
||||
|
||||
if os.path.exists(params.context_file):
|
||||
params.has_contexts = True
|
||||
else:
|
||||
params.has_contexts = False
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
@ -655,6 +692,10 @@ def main():
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += "-no-contexts-words"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -684,11 +725,15 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
@ -816,6 +861,19 @@ def main():
    else:
        decoding_graph = None

    if params.decoding_method == "modified_beam_search":
        if os.path.exists(params.context_file):
            contexts_text = []
            for line in open(params.context_file).readlines():
                contexts_text.append(line.strip())
            contexts = graph_compiler.texts_to_ids(contexts_text)
            context_graph = ContextGraph(params.context_score)
            context_graph.build(contexts)
        else:
            context_graph = None
    else:
        context_graph = None

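For readers new to the context-biasing flags, here is a minimal sketch (illustrative only; the phrases and the ord()-based token ids are stand-ins for what graph_compiler.texts_to_ids() would return) of the data layout the block above feeds into ContextGraph.build():

    from icefall import ContextGraph

    contexts_text = ["HELLO", "WORLD"]  # one word/phrase per line of --context-file
    contexts = [[ord(c) for c in s] for s in contexts_text]  # token ids per phrase
    context_graph = ContextGraph(context_score=2.0)  # +2 bonus per matched token
    context_graph.build(contexts)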
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -833,15 +891,16 @@ def main():
|
||||
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
|
||||
|
||||
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
|
||||
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
|
||||
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
|
@ -23,6 +23,8 @@ from .checkpoint import (
    save_checkpoint_with_global_batch_idx,
)

from .context_graph import ContextGraph, ContextState

from .decode import (
    get_lattice,
    nbest_decoding,
412
icefall/context_graph.py
Normal file
@ -0,0 +1,412 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from collections import deque
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class ContextState:
|
||||
"""The state in ContextGraph"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
token: int,
|
||||
token_score: float,
|
||||
node_score: float,
|
||||
local_node_score: float,
|
||||
is_end: bool,
|
||||
):
|
||||
"""Create a ContextState.
|
||||
|
||||
Args:
|
||||
id:
|
||||
The node id, only for visualization now. A node is in [0, graph.num_nodes).
|
||||
The id of the root node is always 0.
|
||||
token:
|
||||
The token id.
|
||||
token_score:
|
||||
The bonus for each token during decoding, which will hopefully
|
||||
boost the token up to survive beam search.
|
||||
node_score:
|
||||
The accumulated bonus from root of graph to current node, it will be
|
||||
used to calculate the score for fail arc.
|
||||
local_node_score:
|
||||
The accumulated bonus from last ``end_node``(node with is_end true)
|
||||
to current_node, it will be used to calculate the score for fail arc.
|
||||
Note: The local_node_score of an ``end_node`` is 0.
|
||||
is_end:
|
||||
True if current token is the end of a context.
|
||||
"""
|
||||
self.id = id
|
||||
self.token = token
|
||||
self.token_score = token_score
|
||||
self.node_score = node_score
|
||||
self.local_node_score = local_node_score
|
||||
self.is_end = is_end
|
||||
self.next = {}
|
||||
self.fail = None
|
||||
self.output = None
|
||||
|
||||
|
||||
class ContextGraph:
|
||||
"""The ContextGraph is modified from Aho-Corasick which is mainly
|
||||
a Trie with a fail arc for each node.
|
||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details
|
||||
of Aho-Corasick algorithm.
|
||||
|
||||
A ContextGraph contains some words / phrases that we expect to boost their
|
||||
scores during decoding. If the substring of a decoded sequence matches the word / phrase
|
||||
in the ContextGraph, we will give the decoded sequence a bonus to make it survive
|
||||
beam search.
|
||||
"""
|
||||
|
||||
def __init__(self, context_score: float):
|
||||
"""Initialize a ContextGraph with the given ``context_score``.
|
||||
|
||||
A root node will be created (**NOTE:** the token of root is hardcoded to -1).
|
||||
|
||||
Args:
|
||||
context_score:
|
||||
The bonus score for each token(note: NOT for each word/phrase, it means longer
|
||||
word/phrase will have larger bonus score, they have to be matched though).
|
||||
"""
|
||||
self.context_score = context_score
|
||||
self.num_nodes = 0
|
||||
self.root = ContextState(
|
||||
id=self.num_nodes,
|
||||
token=-1,
|
||||
token_score=0,
|
||||
node_score=0,
|
||||
local_node_score=0,
|
||||
is_end=False,
|
||||
)
|
||||
self.root.fail = self.root
|
||||
|
||||
def _fill_fail_output(self):
|
||||
"""This function fills the fail arc for each trie node, it can be computed
|
||||
in linear time by performing a breadth-first search starting from the root.
|
||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
||||
details of the algorithm.
|
||||
"""
|
||||
queue = deque()
|
||||
for token, node in self.root.next.items():
|
||||
node.fail = self.root
|
||||
queue.append(node)
|
||||
while queue:
|
||||
current_node = queue.popleft()
|
||||
for token, node in current_node.next.items():
|
||||
fail = current_node.fail
|
||||
if token in fail.next:
|
||||
fail = fail.next[token]
|
||||
else:
|
||||
fail = fail.fail
|
||||
while token not in fail.next:
|
||||
fail = fail.fail
|
||||
if fail.token == -1: # root
|
||||
break
|
||||
if token in fail.next:
|
||||
fail = fail.next[token]
|
||||
node.fail = fail
|
||||
# fill the output arc
|
||||
output = node.fail
|
||||
while not output.is_end:
|
||||
output = output.fail
|
||||
if output.token == -1: # root
|
||||
output = None
|
||||
break
|
||||
node.output = output
|
||||
queue.append(node)
|
||||
|
||||
def build(self, token_ids: List[List[int]]):
|
||||
"""Build the ContextGraph from a list of token list.
|
||||
It first build a trie from the given token lists, then fill the fail arc
|
||||
for each trie node.
|
||||
|
||||
See https://en.wikipedia.org/wiki/Trie for how to build a trie.
|
||||
|
||||
Args:
|
||||
token_ids:
|
||||
The given token lists to build the ContextGraph, it is a list of token list,
|
||||
each token list contains the token ids for a word/phrase. The token id
|
||||
could be an id of a char (modeling with single Chinese char) or an id
|
||||
of a BPE (modeling with BPEs).
|
||||
"""
|
||||
for tokens in token_ids:
|
||||
node = self.root
|
||||
for i, token in enumerate(tokens):
|
||||
if token not in node.next:
|
||||
self.num_nodes += 1
|
||||
is_end = i == len(tokens) - 1
|
||||
node.next[token] = ContextState(
|
||||
id=self.num_nodes,
|
||||
token=token,
|
||||
token_score=self.context_score,
|
||||
node_score=node.node_score + self.context_score,
|
||||
local_node_score=0
|
||||
if is_end
|
||||
else (node.local_node_score + self.context_score),
|
||||
is_end=is_end,
|
||||
)
|
||||
node = node.next[token]
|
||||
self._fill_fail_output()
|
||||
|
||||
def forward_one_step(
|
||||
self, state: ContextState, token: int
|
||||
) -> Tuple[float, ContextState]:
|
||||
"""Search the graph with given state and token.
|
||||
|
||||
Args:
|
||||
state:
|
||||
The current graph state (trie node) from which to start the search.
|
||||
token:
|
||||
The given token.
|
||||
|
||||
Returns:
|
||||
Return a tuple of score and next state.
|
||||
"""
|
||||
node = None
|
||||
score = 0
|
||||
# token matched
|
||||
if token in state.next:
|
||||
node = state.next[token]
|
||||
score = node.token_score
|
||||
if state.is_end:
|
||||
score += state.node_score
|
||||
else:
|
||||
# token not matched
|
||||
# We will trace along the fail arc until it matches the token or reaching
|
||||
# root of the graph.
|
||||
node = state.fail
|
||||
while token not in node.next:
|
||||
node = node.fail
|
||||
if node.token == -1: # root
|
||||
break
|
||||
|
||||
if token in node.next:
|
||||
node = node.next[token]
|
||||
|
||||
# The score of the fail path
|
||||
score = node.node_score - state.local_node_score
|
||||
assert node is not None
|
||||
matched_score = 0
|
||||
output = node.output
|
||||
while output is not None:
|
||||
matched_score += output.node_score
|
||||
output = output.output
|
||||
return (score + matched_score, node)
|
||||
|
||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||
"""When reaching the end of the decoded sequence, we need to finalize
|
||||
the matching, the purpose is to subtract the added bonus score for the
|
||||
state that is not the end of a word/phrase.
|
||||
|
||||
Args:
|
||||
state:
|
||||
The given state(trie node).
|
||||
|
||||
Returns:
|
||||
Return a tuple of score and next state. If state is the end of a word/phrase
|
||||
the score is zero, otherwise the score is the score of an implicit fail arc
|
||||
to root. The next state is always root.
|
||||
"""
|
||||
# The score of the fail arc
|
||||
score = -state.node_score
|
||||
if state.is_end:
|
||||
score = 0
|
||||
return (score, self.root)
|
||||
|
||||
def draw(
|
||||
self,
|
||||
title: Optional[str] = None,
|
||||
filename: Optional[str] = "",
|
||||
symbol_table: Optional[Dict[int, str]] = None,
|
||||
) -> "Digraph": # noqa
|
||||
|
||||
"""Visualize a ContextGraph via graphviz.
|
||||
|
||||
Render ContextGraph as an image via graphviz, and return the Digraph object;
|
||||
and optionally save to file `filename`.
|
||||
`filename` must have a suffix that graphviz understands, such as
|
||||
`pdf`, `svg` or `png`.
|
||||
|
||||
Note:
|
||||
You need to install graphviz to use this function::
|
||||
|
||||
pip install graphviz
|
||||
|
||||
Args:
|
||||
title:
|
||||
Title to be displayed in image, e.g. 'A simple FSA example'
|
||||
filename:
|
||||
Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg',
|
||||
'foo.png' (must have a suffix that graphviz understands).
|
||||
symbol_table:
|
||||
Map the token ids to symbols.
|
||||
Returns:
|
||||
A Digraph from graphviz.
|
||||
"""
|
||||
|
||||
try:
|
||||
import graphviz
|
||||
except Exception:
|
||||
print("You cannot use `to_dot` unless the graphviz package is installed.")
|
||||
raise
|
||||
|
||||
graph_attr = {
|
||||
"rankdir": "LR",
|
||||
"size": "8.5,11",
|
||||
"center": "1",
|
||||
"orientation": "Portrait",
|
||||
"ranksep": "0.4",
|
||||
"nodesep": "0.25",
|
||||
}
|
||||
if title is not None:
|
||||
graph_attr["label"] = title
|
||||
|
||||
default_node_attr = {
|
||||
"shape": "circle",
|
||||
"style": "bold",
|
||||
"fontsize": "14",
|
||||
}
|
||||
|
||||
final_state_attr = {
|
||||
"shape": "doublecircle",
|
||||
"style": "bold",
|
||||
"fontsize": "14",
|
||||
}
|
||||
|
||||
final_state = -1
|
||||
dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)
|
||||
|
||||
seen = set()
|
||||
queue = deque()
|
||||
queue.append(self.root)
|
||||
# root id is always 0
|
||||
dot.node("0", label="0", **default_node_attr)
|
||||
dot.edge("0", "0", color="red")
|
||||
seen.add(0)
|
||||
|
||||
while len(queue):
|
||||
current_node = queue.popleft()
|
||||
for token, node in current_node.next.items():
|
||||
if node.id not in seen:
|
||||
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
|
||||
local_node_score = f"{node.local_node_score:.2f}".rstrip(
|
||||
"0"
|
||||
).rstrip(".")
|
||||
label = f"{node.id}/({node_score},{local_node_score})"
|
||||
if node.is_end:
|
||||
dot.node(str(node.id), label=label, **final_state_attr)
|
||||
else:
|
||||
dot.node(str(node.id), label=label, **default_node_attr)
|
||||
seen.add(node.id)
|
||||
weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
|
||||
label = str(token) if symbol_table is None else symbol_table[token]
|
||||
dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
|
||||
dot.edge(
|
||||
str(node.id),
|
||||
str(node.fail.id),
|
||||
color="red",
|
||||
)
|
||||
if node.output is not None:
|
||||
dot.edge(
|
||||
str(node.id),
|
||||
str(node.output.id),
|
||||
color="green",
|
||||
)
|
||||
queue.append(node)
|
||||
|
||||
if filename:
|
||||
_, extension = os.path.splitext(filename)
|
||||
if extension == "" or extension[0] != ".":
|
||||
raise ValueError(
|
||||
"Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
|
||||
filename
|
||||
)
|
||||
)
|
||||
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
temp_fn = dot.render(
|
||||
filename="temp",
|
||||
directory=tmp_dir,
|
||||
format=extension[1:],
|
||||
cleanup=True,
|
||||
)
|
||||
|
||||
shutil.move(temp_fn, filename)
|
||||
|
||||
return dot
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
contexts_str = [
|
||||
"S",
|
||||
"HE",
|
||||
"SHE",
|
||||
"SHELL",
|
||||
"HIS",
|
||||
"HERS",
|
||||
"HELLO",
|
||||
"THIS",
|
||||
"THEM",
|
||||
]
|
||||
contexts = []
|
||||
for s in contexts_str:
|
||||
contexts.append([ord(x) for x in s])
|
||||
|
||||
context_graph = ContextGraph(context_score=1)
|
||||
context_graph.build(contexts)
|
||||
|
||||
symbol_table = {}
|
||||
for contexts in contexts_str:
|
||||
for s in contexts:
|
||||
symbol_table[ord(s)] = s
|
||||
|
||||
context_graph.draw(
|
||||
title="Graph for: " + " / ".join(contexts_str),
|
||||
filename="context_graph.pdf",
|
||||
symbol_table=symbol_table,
|
||||
)
|
||||
|
||||
queries = {
|
||||
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
||||
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
||||
"HISHE": 9, # "HIS", "S", "SHE", "HE"
|
||||
"SHED": 6, # "S", "SHE", "HE"
|
||||
"HELL": 2, # "HE"
|
||||
"HELLO": 7, # "HE", "HELLO"
|
||||
"DHRHISQ": 4, # "HIS", "S"
|
||||
"THEN": 2, # "HE"
|
||||
}
|
||||
for query, expected_score in queries.items():
|
||||
total_scores = 0
|
||||
state = context_graph.root
|
||||
for q in query:
|
||||
score, state = context_graph.forward_one_step(state, ord(q))
|
||||
total_scores += score
|
||||
score, state = context_graph.finalize(state)
|
||||
assert state.token == -1, state.token
|
||||
total_scores += score
|
||||
assert total_scores == expected_score, (
|
||||
total_scores,
|
||||
expected_score,
|
||||
query,
|
||||
)
|
@ -28,6 +28,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import k2
@ -1881,3 +1882,20 @@ def is_cjk(character):
            ]
        ]
    )


def symlink_or_copy(exp_dir: Path, src: str, dst: str):
    """
    In the experiment directory, create a symlink pointing to src named dst.
    If symlink creation fails (Windows?), fall back to copyfile."""

    dir_fd = os.open(exp_dir, os.O_RDONLY)
    try:
        os.remove(dst, dir_fd=dir_fd)
    except FileNotFoundError:
        pass
    try:
        os.symlink(src=src, dst=dst, dir_fd=dir_fd)
    except OSError:
        copyfile(src=exp_dir / src, dst=exp_dir / dst)
    os.close(dir_fd)
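A hypothetical usage sketch (the checkpoint file names and the experiment directory are placeholders, not from the recipe): because the call passes dir_fd, both src and dst are interpreted relative to exp_dir, so this points exp_dir/best-valid-loss.pt at exp_dir/epoch-30.pt and copies the file instead if symlinks are unavailable:

    from pathlib import Path
    from icefall.utils import symlink_or_copy

    exp_dir = Path("pruned_transducer_stateless5/exp")  # assumed experiment directory
    symlink_or_copy(exp_dir=exp_dir, src="epoch-30.pt", dst="best-valid-loss.pt")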