Merge remote-tracking branch 'k2-fsa/master' into new-zipformer-add-ctc

This commit is contained in:
yaozengwei 2023-06-14 11:36:20 +08:00
commit c33ebefaf8
65 changed files with 3764 additions and 135 deletions

View File

@ -58,6 +58,7 @@ Usage:
import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -76,6 +77,8 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -211,6 +214,26 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score given to each token of the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing list, with one word/phrase per line.
Used only when --decoding_method is modified_beam_search.
""",
)
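# For reference, the context biasing file is plain text with one word/phrase
# per line. A hypothetical example (file name and contents are illustrative,
# not part of this commit):
#
#     $ cat hotwords.txt
#     深度学习
#     语音识别
#
# Each line is converted to token IDs before building the context graph:
# graph_compiler.texts_to_ids(contexts_text) in this char-based recipe, or
# sp.encode(contexts) in the BPE-based librispeech recipe further below.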
add_model_arguments(parser)
return parser
@ -222,6 +245,7 @@ def decode_one_batch(
token_table: k2.SymbolTable,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -285,6 +309,7 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
)
else:
hyp_tokens = []
@ -324,7 +349,12 @@ def decode_one_batch(
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
key = f"beam_size_{params.beam_size}"
if params.has_contexts:
key += f"-context-score-{params.context_score}"
else:
key += "-no-context-words"
return {key: hyps}
def decode_dataset(
@ -333,6 +363,7 @@ def decode_dataset(
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -377,6 +408,7 @@ def decode_dataset(
model=model,
token_table=token_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
batch=batch,
)
@ -407,16 +439,17 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
store_transcripts(filename=recog_path, texts=results_char)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
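# A quick illustration of the char-splitting above, with a made-up result
# tuple (cut id, reference words, hypothesis words):
res = ("BAC009S0002W0122", ["甚至", "出现"], ["甚至", "初选"])
ref_chars = list("".join(res[1]))  # ['甚', '至', '出', '现']
hyp_chars = list("".join(res[2]))  # ['甚', '至', '初', '选']
# write_error_stats then counts errors per character, i.e. CER.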
@ -457,6 +490,12 @@ def main():
"fast_beam_search",
"modified_beam_search",
)
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
@ -470,6 +509,10 @@ def main():
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += "-no-contexts-words"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -490,6 +533,11 @@ def main():
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
logging.info(params)
logging.info("About to create model")
@ -586,6 +634,19 @@ def main():
else:
decoding_graph = None
if params.decoding_method == "modified_beam_search":
if os.path.exists(params.context_file):
contexts_text = []
for line in open(params.context_file).readlines():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -608,6 +669,7 @@ def main():
model=model,
token_table=lexicon.token_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
)
save_results(

View File

@ -63,6 +63,14 @@ log() {
log "dl_dir: $dl_dir"
if ! command -v ffmpeg &> /dev/null; then
echo "This dataset requires ffmpeg"
echo "Please install ffmpeg first"
echo ""
echo " sudo apt-get install ffmpeg"
exit 1
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"

View File

@ -107,7 +107,7 @@ fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
# to $dl_dir/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests

View File

@ -23,7 +23,7 @@ import k2
import sentencepiece as spm
import torch
from icefall import NgramLm, NgramLmStateCost
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel
@ -765,6 +765,9 @@ class Hypothesis:
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None
# Context graph state
context_state: Optional[ContextState] = None
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
@ -917,6 +920,7 @@ def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_graph: Optional[ContextGraph] = None,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
@ -968,6 +972,7 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=None if context_graph is None else context_graph.root,
timestamp=[],
)
)
@ -990,6 +995,7 @@ def modified_beam_search(
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
@ -1047,21 +1053,51 @@ def modified_beam_search(
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
context_score = 0
new_context_state = None if context_graph is None else hyp.context_state
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
if context_graph is not None:
(
context_score,
new_context_state,
) = context_graph.forward_one_step(hyp.context_state, new_token)
new_log_prob = topk_log_probs[k] + context_score
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
context_state=new_context_state,
)
B[i].add(new_hyp)
B = B + finalized_B
# Finalize context_state: if a matched context does not reach a final
# state, we need to add the score of the corresponding backoff arc.
if context_graph is not None:
finalized_B = [HypothesisList() for _ in range(len(B))]
for i, hyps in enumerate(B):
for hyp in list(hyps):
context_score, new_context_state = context_graph.finalize(
hyp.context_state
)
finalized_B[i].add(
Hypothesis(
ys=hyp.ys,
log_prob=hyp.log_prob + context_score,
timestamp=hyp.timestamp,
context_state=new_context_state,
)
)
B = finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
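# To make forward_one_step()/finalize() above concrete, here is a minimal,
# self-contained sketch of a trie-style context graph with a fixed per-token
# bonus and backoff on mismatch. It only illustrates the scoring idea; the
# real icefall ContextGraph is more sophisticated (e.g. handling overlaps).
from typing import Dict, List, Tuple


class SketchContextState:
    def __init__(self, score: float = 0.0):
        self.next: Dict[int, "SketchContextState"] = {}  # arcs keyed by token id
        self.score = score  # bonus accumulated along this partial match
        self.is_final = False  # True if a full phrase ends here


class SketchContextGraph:
    def __init__(self, context_score: float):
        self.context_score = context_score
        self.root = SketchContextState()

    def build(self, token_ids: List[List[int]]) -> None:
        # Insert every biasing phrase (a list of token ids) into the trie.
        for phrase in token_ids:
            node = self.root
            for token in phrase:
                if token not in node.next:
                    node.next[token] = SketchContextState(
                        node.score + self.context_score
                    )
                node = node.next[token]
            node.is_final = True

    def forward_one_step(
        self, state: SketchContextState, token: int
    ) -> Tuple[float, SketchContextState]:
        if token in state.next:
            nxt = state.next[token]
            # A completed phrase banks its bonus and resets to the root.
            return self.context_score, (self.root if nxt.is_final else nxt)
        # Mismatch: back off to the root, cancelling the partial-match bonus.
        return -state.score, self.root

    def finalize(
        self, state: SketchContextState
    ) -> Tuple[float, SketchContextState]:
        # A hypothesis that ends mid-phrase gives back its partial bonus;
        # this is the "score of the corresponding backoff arc" above.
        return -state.score, self.root


# Usage sketch mirroring modified_beam_search:
graph = SketchContextGraph(context_score=2.0)
graph.build([[5, 7], [5, 9, 3]])
score, state = graph.forward_one_step(graph.root, 5)  # +2.0, partial match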

View File

@ -99,7 +99,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless3/exp",
help="The experiment dir",
)

View File

@ -125,6 +125,7 @@ For example:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -146,6 +147,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import ContextGraph
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -353,6 +355,27 @@ def get_parser():
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score given to each token of the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing list, with one word/phrase per line.
Used only when --decoding_method is modified_beam_search.
""",
)
add_model_arguments(parser)
return parser
@ -365,6 +388,7 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -494,6 +518,7 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
return_timestamps=True,
)
else:
@ -548,7 +573,12 @@ def decode_one_batch(
return {key: (hyps, timestamps)}
else:
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
key = f"beam_size_{params.beam_size}"
if params.has_contexts:
key += f"-context-score-{params.context_score}"
else:
key += "-no-context-words"
return {key: (hyps, timestamps)}
def decode_dataset(
@ -558,6 +588,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
"""Decode dataset.
@ -622,6 +653,7 @@ def decode_dataset(
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
context_graph=context_graph,
)
for name, (hyps, timestamps_hyp) in hyps_dict.items():
@ -728,6 +760,12 @@ def main():
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
@ -750,6 +788,10 @@ def main():
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += "-no-context-words"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -881,6 +923,18 @@ def main():
decoding_graph = None
word_table = None
if params.decoding_method == "modified_beam_search":
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -905,6 +959,7 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
)
save_results(

View File

@ -78,7 +78,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -115,7 +115,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless4/exp",
help="The experiment dir",
)

View File

@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \
--max-duration 300
@ -37,7 +37,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \
--max-duration 550
@ -195,7 +195,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless4/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@ -296,7 +296,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -87,7 +87,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)

View File

@ -84,7 +84,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)

View File

@ -78,7 +78,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -115,7 +115,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)

View File

@ -20,7 +20,7 @@
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py
python ./pruned_transducer_stateless5/test_model.py
"""
from train import get_params, get_transducer_model

View File

@ -328,7 +328,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -20,23 +20,23 @@
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
./pruned_transducer_stateless6/export.py \
--exp-dir ./pruned_transducer_stateless6/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless2/decode.py`,
To use the generated file with `pruned_transducer_stateless6/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless2/decode.py \
--exp-dir ./pruned_transducer_stateless2/exp \
./pruned_transducer_stateless6/decode.py \
--exp-dir ./pruned_transducer_stateless6/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
@ -65,7 +65,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -91,7 +91,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless6/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",

View File

@ -267,7 +267,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -94,10 +94,10 @@ Usage:
--max-states 64
(8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
./pruned_transducer_stateless7/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam-size 4 \
@ -110,11 +110,11 @@ Usage:
--rnn-lm-tie-weights 1
(9) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless5/decode.py \
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--max-duration 600 \
--exp-dir ./pruned_transducer_stateless5/exp \
--exp-dir ./pruned_transducer_stateless7/exp \
--decoding-method modified_beam_search_LODR \
--beam-size 4 \
--lm-type rnn \

View File

@ -79,7 +79,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -116,7 +116,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
default="pruned_transducer_stateless7/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",

View File

@ -389,7 +389,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -50,7 +50,6 @@ import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
@ -89,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
symlink_or_copy,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -340,7 +340,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)
@ -601,7 +601,8 @@ def save_checkpoint(
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
epoch_basename = f"epoch-{params.cur_epoch}.pt"
filename = params.exp_dir / epoch_basename
save_checkpoint_impl(
filename=filename,
model=model,
@ -615,12 +616,14 @@ def save_checkpoint(
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
symlink_or_copy(
exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
symlink_or_copy(
exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
)
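# The symlink_or_copy helper imported above replaces the earlier copyfile
# calls. A minimal sketch of the intended behaviour, assuming it creates a
# relative symlink and falls back to copying; this is an illustration, not
# the actual icefall.utils.symlink_or_copy:
import os
from pathlib import Path
from shutil import copyfile


def symlink_or_copy_sketch(exp_dir: Path, src: str, dst: str) -> None:
    dst_path = Path(exp_dir) / dst
    if dst_path.is_symlink() or dst_path.exists():
        dst_path.unlink()  # replace a stale best-train/valid link
    try:
        # A relative target keeps the link valid if exp_dir is moved.
        os.symlink(src, dst_path)
    except OSError:
        copyfile(src=Path(exp_dir) / src, dst=dst_path)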
def compute_loss(

View File

@ -346,7 +346,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -342,7 +342,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -90,7 +90,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)

View File

@ -88,7 +88,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)

View File

@ -39,7 +39,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -65,7 +65,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",

View File

@ -77,7 +77,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -114,7 +114,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless7_streaming/exp",
help="The experiment dir",
)

View File

@ -355,7 +355,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -355,7 +355,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -88,7 +88,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)

View File

@ -78,7 +78,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -115,7 +115,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless7_streaming_multi/exp",
help="The experiment dir",
)

View File

@ -366,7 +366,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -348,7 +348,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)
@ -1127,7 +1127,16 @@ def run(rank, world_size, args):
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in model.named_parameters()]
)
optimizer = ScaledAdam(
model.parameters(),
lr=params.base_lr,
clipping_scale=2.0,
parameters_names=parameters_names,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

View File

@ -0,0 +1,775 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context-frames`,
whose value is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
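# Once exported, the metadata written by add_meta_data() can be read back,
# e.g. with onnxruntime (the file name below is illustrative):
import onnxruntime as ort

sess = ort.InferenceSession(
    "encoder-epoch-99-avg-1.onnx", providers=["CPUExecutionProvider"]
)
meta = sess.get_modelmeta().custom_metadata_map  # Dict[str, str]
print(meta["decode_chunk_len"], meta["num_encoder_layers"])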
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# 7 input frames are consumed by the encoder_embed subsampling, and the
# ConvNeXt module needs 3 extra subsampled frames, i.e. 2 * 3 input frames.
self.pad_length = 7 + 2 * 3
def forward(
self,
x: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
N = x.size(0)
T = self.chunk_size * 2 + self.pad_length
x_lens = torch.tensor([T] * N, device=x.device)
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=x,
x_lens=x_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
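# Worked example (hypothetical numbers): with left_context_len = 4 and
# processed_lens = [2], arange gives [0, 1, 2, 3]; (2 <= arange) is
# [False, False, True, True]; flip(1) yields [True, True, False, False],
# i.e. the two oldest cache slots, which hold no processed frames yet,
# are masked out.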
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2)
encoder_states = states[:-2]
logging.info(f"len_encoder_states={len(encoder_states)}")
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, new_states
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at a 50 Hz frame rate, after encoder_embed) for each sample in the batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
states.append(processed_lens)
return states
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
)
decode_chunk_len = encoder_model.chunk_size * 2
# The encoder_embed subsamples the features to (T - 7) // 2 frames
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
x = torch.rand(1, T, 80, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
logging.info(f"len(init_state): {len(init_state)}")
inputs = {}
input_names = ["x"]
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
logging.info(f"{name}.shape: {tensors[0].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
logging.info(f"{name}.shape: {tensors[1].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
logging.info(f"{name}.shape: {tensors[2].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
logging.info(f"{name}.shape: {tensors[3].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
logging.info(f"{name}.shape: {tensors[4].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
logging.info(f"{name}.shape: {tensors[5].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
ds = encoder_model.encoder.downsampling_factor
left_context_len = encoder_model.left_context_len
left_context_len = [left_context_len // k for k in ds]
left_context_len = ",".join(map(str, left_context_len))
query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "streaming zipformer2",
"decode_chunk_len": str(decode_chunk_len), # 32
"T": str(T), # 32+7+2*3=45
"num_encoder_layers": num_encoder_layers,
"encoder_dims": encoder_dims,
"cnn_module_kernels": cnn_module_kernels,
"left_context_len": left_context_len,
"query_head_dims": query_head_dims,
"value_head_dims": value_head_dims,
"num_heads": num_heads,
}
logging.info(f"meta_data: {meta_data}")
for i in range(len(init_state[:-2]) // 6):
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
embed_states = init_state[-2]
name = "embed_states"
logging.info(f"{name}.shape: {embed_states.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size,)
processed_lens = init_state[-1]
name = "processed_lens"
logging.info(f"{name}.shape: {processed_lens.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
logging.info(inputs)
logging.info(outputs)
logging.info(input_names)
logging.info(output_names)
torch.onnx.export(
encoder_model,
(x, init_state),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"x": {0: "N"},
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
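# A minimal sketch of driving the exported streaming encoder with
# onnxruntime, feeding the new_* state outputs back in each call. The file
# name and the feature source are assumptions; see
# ./onnx_pretrained-streaming.py for the real decoding loop. States are
# zero-initialized here, which is assumed to match get_init_states().
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "encoder-epoch-99-avg-1.onnx", providers=["CPUExecutionProvider"]
)


def zeros_for(inp) -> np.ndarray:
    # Replace the symbolic batch dim "N" with 1; match the declared dtype.
    shape = [1 if isinstance(d, str) else d for d in inp.shape]
    dtype = np.int64 if inp.type == "tensor(int64)" else np.float32
    return np.zeros(shape, dtype=dtype)


states = {inp.name: zeros_for(inp) for inp in sess.get_inputs() if inp.name != "x"}
out_names = [o.name for o in sess.get_outputs()]

for chunk in feature_chunks:  # assumed iterable of (1, T, 80) float32 fbanks
    outs = sess.run(None, {"x": chunk, **states})
    encoder_out = outs[0]  # (1, chunk_size, joiner_dim)
    # The remaining outputs are new_cached_*, new_embed_states and
    # new_processed_lens; strip the "new_" prefix to build the next feed.
    states = {
        name[len("new_"):]: val
        for name, val in zip(out_names, outs)
        if name.startswith("new_")
    }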

View File

@ -0,0 +1,624 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming zipformer2",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
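# A minimal sketch of wiring the three exported (non-streaming) models
# together for a single joiner evaluation with onnxruntime; file names and
# the dummy input are illustrative. See ./onnx_pretrained.py for the real
# greedy-search decoding loop.
import numpy as np
import onnxruntime as ort

opts = ["CPUExecutionProvider"]
encoder = ort.InferenceSession("encoder-epoch-99-avg-1.onnx", providers=opts)
decoder = ort.InferenceSession("decoder-epoch-99-avg-1.onnx", providers=opts)
joiner = ort.InferenceSession("joiner-epoch-99-avg-1.onnx", providers=opts)

context_size = int(decoder.get_modelmeta().custom_metadata_map["context_size"])

x = np.random.rand(1, 100, 80).astype(np.float32)  # (N, T, C) dummy fbanks
x_lens = np.array([100], dtype=np.int64)
encoder_out, encoder_out_lens = encoder.run(None, {"x": x, "x_lens": x_lens})

y = np.zeros((1, context_size), dtype=np.int64)  # decoder fed with blanks
decoder_out = decoder.run(None, {"y": y})[0]  # (N, joiner_dim)

logit = joiner.run(
    None, {"encoder_out": encoder_out[:, 0], "decoder_out": decoder_out}
)[0]
print(logit.shape)  # (1, vocab_size)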
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
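
A quick way to check that the int8 quantization above preserves behavior is to
run the fp32 and int8 joiners on the same input; since quantize_dynamic() is
told to quantize only MatMul ops, the outputs should agree up to quantization
noise. A minimal sketch (not part of the commit; the file names assume
--epoch 99 --avg 1):

    import numpy as np
    import onnxruntime as ort

    fp32 = ort.InferenceSession("joiner-epoch-99-avg-1.onnx")
    int8 = ort.InferenceSession("joiner-epoch-99-avg-1.int8.onnx")
    joiner_dim = int(fp32.get_modelmeta().custom_metadata_map["joiner_dim"])
    enc = np.random.rand(4, joiner_dim).astype(np.float32)
    dec = np.random.rand(4, joiner_dim).astype(np.float32)
    inputs = {"encoder_out": enc, "decoder_out": dec}
    (a,) = fp32.run(["logit"], inputs)
    (b,) = int8.run(["logit"], inputs)
    print("max abs diff:", np.abs(a - b).max())  # small but nonzero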

View File

@ -54,7 +54,7 @@ class AsrModel(nn.Module):
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape

View File

@ -0,0 +1,544 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script loads ONNX models exported by ./export-onnx-streaming.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file with the exported ONNX models
./zipformer/onnx_pretrained-streaming.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
logging.info(f"encoder_meta={encoder_meta}")
model_type = encoder_meta["model_type"]
assert model_type == "zipformer2", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
num_encoder_layers = encoder_meta["num_encoder_layers"]
encoder_dims = encoder_meta["encoder_dims"]
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
left_context_len = encoder_meta["left_context_len"]
query_head_dims = encoder_meta["query_head_dims"]
value_head_dims = encoder_meta["value_head_dims"]
num_heads = encoder_meta["num_heads"]
def to_int_list(s):
return list(map(int, s.split(",")))
num_encoder_layers = to_int_list(num_encoder_layers)
encoder_dims = to_int_list(encoder_dims)
cnn_module_kernels = to_int_list(cnn_module_kernels)
left_context_len = to_int_list(left_context_len)
query_head_dims = to_int_list(query_head_dims)
value_head_dims = to_int_list(value_head_dims)
num_heads = to_int_list(num_heads)
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"encoder_dims: {encoder_dims}")
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
logging.info(f"left_context_len: {left_context_len}")
logging.info(f"query_head_dims: {query_head_dims}")
logging.info(f"value_head_dims: {value_head_dims}")
logging.info(f"num_heads: {num_heads}")
num_encoders = len(num_encoder_layers)
self.states = []
for i in range(num_encoders):
num_layers = num_encoder_layers[i]
key_dim = query_head_dims[i] * num_heads[i]
embed_dim = encoder_dims[i]
nonlin_attn_head_dim = 3 * embed_dim // 4
value_dim = value_head_dims[i] * num_heads[i]
conv_left_pad = cnn_module_kernels[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(
left_context_len[i], batch_size, key_dim
).numpy()
cached_nonlin_attn = torch.zeros(
1, batch_size, left_context_len[i], nonlin_attn_head_dim
).numpy()
cached_val1 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_val2 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
self.states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
self.states.append(processed_lens)
self.num_encoders = num_encoders
self.segment = T
self.offset = decode_chunk_len
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
encoder_input = {"x": x.numpy()}
encoder_output = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
encoder_input[name] = tensors[0]
encoder_output.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
encoder_input[name] = tensors[1]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
encoder_input[name] = tensors[2]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
encoder_input[name] = tensors[3]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
encoder_input[name] = tensors[4]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
encoder_input[name] = tensors[5]
encoder_output.append(f"new_{name}")
for i in range(len(self.states[:-2]) // 6):
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
name = "embed_states"
embed_states = self.states[-2]
encoder_input[name] = embed_states
encoder_output.append(f"new_{name}")
# (batch_size,)
name = "processed_lens"
processed_lens = self.states[-1]
encoder_input[name] = processed_lens
encoder_output.append(f"new_{name}")
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
self.states = states
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> Tuple[List[int], torch.Tensor]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return a tuple containing the decoded results so far and the decoder
output for the last token; pass both back in for the next chunk.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
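# Annotation (not part of the original script): each pass of the inner loop
# reads `segment` = T feature frames starting at num_processed_frames but
# advances num_processed_frames by only `offset` = decode_chunk_len, so
# consecutive chunks overlap by T - decode_chunk_len frames; the overlap is
# the extra right context the encoder was exported with.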
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("▁", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
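
init_encoder_states() above builds one flat list holding six cached tensors per
encoder layer plus embed_states and processed_lens at the end; with the default
--num-encoder-layers "2,2,3,4,3,2" that is 16 * 6 + 2 = 98 state tensors. A
small sketch (not part of the commit; the file name is the one exported in
step 2) to list the matching ONNX input names:

    import onnxruntime as ort

    sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
    for i in sess.get_inputs():
        print(i.name, i.shape)
    # x, cached_key_0, cached_nonlin_attn_0, cached_val1_0, cached_val2_0,
    # cached_conv1_0, cached_conv2_0, ..., embed_states, processed_lens --
    # exactly the names built by _build_encoder_input_output() above.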

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/onnx_pretrained.py

View File

@ -26,6 +26,18 @@ import torch.nn as nn
from torch import Tensor
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace()
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
if not torch.jit.is_tracing():
return torch.logaddexp(x, y)
else:
return (x.exp() + y.exp()).log()
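# Annotation (not part of the original diff): for ordinary activation values
# the traced fallback matches torch.logaddexp closely, e.g. x = 1.0, y = 2.0
# gives log(e**1 + e**2) ~= 2.3133 on both paths. Unlike torch.logaddexp the
# fallback is not overflow-safe for large inputs; it exists only so that
# torch.jit.trace() can export the model to ONNX.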
class PiecewiseLinear(object):
"""
Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
def __float__(self):
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return float(self.default)
else:
ans = self.schedule(self.batch_count)
@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
def softmax(x: Tensor, dim: int):
if not x.requires_grad or torch.jit.is_scripting():
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
return x.softmax(dim=dim)
return SoftmaxFunction.apply(x, dim)
@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
self.alpha = alpha
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return x
return scale_grad(x, self.alpha)
@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
def _no_op(x: Tensor) -> Tensor:
if (torch.jit.is_scripting()):
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x
else:
# a no-op function that will have a node in the autograd graph,
@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-L activation.
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
if not x.requires_grad:
return k2.swoosh_l_forward(x)
else:
@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-R activation.
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
if not x.requires_grad:
return k2.swoosh_r_forward(x)
else:

View File

@ -27,6 +27,7 @@ from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
is_onnx: bool = False,
):
"""
Args:
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return:
Return a model without scaled layers.
"""
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# to replace torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items():
if "." in k:

View File

@ -81,7 +81,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)

View File

@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():

View File

@ -433,7 +433,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)

View File

@ -133,6 +133,7 @@ class Zipformer2(EncoderInterface):
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
num_encoder_layers = _to_tuple(num_encoder_layers)
self.num_encoder_layers = num_encoder_layers
self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
pos_head_dim = _to_tuple(pos_head_dim)
@ -258,7 +259,7 @@ class Zipformer2(EncoderInterface):
if not self.causal:
return -1, -1
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.chunk_size) == 1, self.chunk_size
chunk_size = self.chunk_size[0]
else:
@ -267,7 +268,7 @@ class Zipformer2(EncoderInterface):
if chunk_size == -1:
left_context_chunks = -1
else:
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.left_context_frames) == 1, self.left_context_frames
left_context_frames = self.left_context_frames[0]
else:
@ -301,14 +302,14 @@ class Zipformer2(EncoderInterface):
of frames in `embeddings` before padding.
"""
outputs = []
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
feature_masks = [1.0] * len(self.encoder_dim)
else:
feature_masks = self.get_feature_masks(x)
chunk_size, left_context_chunks = self.get_chunk_info()
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
# Exporting a model for simulated streaming decoding is not supported
attn_mask = None
else:
@ -334,7 +335,7 @@ class Zipformer2(EncoderInterface):
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
assert self.output_downsampling_factor == 2
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
lengths = (x_lens + 1) // 2
else:
with warnings.catch_warnings():
@ -372,7 +373,7 @@ class Zipformer2(EncoderInterface):
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
# c is chunk index for each frame, shape (seq_len,)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
c = t // chunk_size
else:
with warnings.catch_warnings():
@ -650,7 +651,7 @@ class Zipformer2EncoderLayer(nn.Module):
)
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return None
batch_size = x.shape[1]
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@ -695,7 +696,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_orig = src
# dropout rate for non-feedforward submodules
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
attention_skip_rate = 0.0
else:
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
@ -713,7 +714,7 @@ class Zipformer2EncoderLayer(nn.Module):
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
selected_attn_weights = attn_weights[0:1]
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif not self.training and random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to
@ -732,7 +733,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -740,7 +741,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask),
conv_skip_rate)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
ff2_skip_rate = 0.0
else:
ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
@ -754,7 +755,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -762,7 +763,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask),
conv_skip_rate)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
ff3_skip_rate = 0.0
else:
ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
@ -968,7 +969,7 @@ class Zipformer2Encoder(nn.Module):
pos_emb = self.encoder_pos(src)
output = src
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask
for i, mod in enumerate(self.layers):
@ -980,7 +981,7 @@ class Zipformer2Encoder(nn.Module):
src_key_padding_mask=src_key_padding_mask,
)
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask
return output
@ -1073,7 +1074,7 @@ class BypassModule(nn.Module):
# or (batch_size, num_channels,). This is actually the
# scale on the non-residual term, so 0 corresponds to bypassing
# this module.
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.bypass_scale
else:
ans = limit_param_value(self.bypass_scale,
@ -1229,12 +1230,11 @@ class SimpleDownsample(torch.nn.Module):
d_seq_len = (seq_len + ds - 1) // ds
# Pad to an exact multiple of self.downsample
if seq_len != d_seq_len * ds:
# right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds
# right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
@ -1322,11 +1322,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
@ -1524,7 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = torch.matmul(q, k)
use_pos_scores = False
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
# We can't put random.random() in the same line
use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@ -1542,16 +1538,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif self.training and random.random() < 0.1:
# This is a harder way of limiting the attention scores to not be
@ -1594,7 +1600,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# half-precision output for backprop purposes.
attn_weights = softmax(attn_scores, dim=-1)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif random.random() < 0.001 and not self.training:
self._print_attn_entropy(attn_weights)
@ -1672,15 +1678,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb)
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(k_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
attn_scores = attn_scores + pos_scores
@ -2136,7 +2153,7 @@ class ConvolutionModule(nn.Module):
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
if not torch.jit.is_scripting() and chunk_size >= 0:
if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
# Exporting a model for simulated streaming decoding is not supported
assert self.causal, "Must initialize model with causal=True if you use chunk_size"
x = self.depthwise_conv(x, chunk_size=chunk_size)
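
The torch.jit.is_tracing() branches added above replace the as_strided() trick
with an equivalent gather(): element [h, b, t, j] of the result is
pos_scores[h, b, t, seq_len - 1 - t + j] on both paths. A standalone check of
that equivalence (an illustration assuming time1 == seq_len, not part of the
commit):

    import torch

    num_heads, batch_size, seq_len = 2, 3, 5
    n = 2 * seq_len - 1
    pos_scores = torch.randn(num_heads, batch_size, seq_len, n)
    a = pos_scores.as_strided(
        (num_heads, batch_size, seq_len, seq_len),
        (pos_scores.stride(0), pos_scores.stride(1),
         pos_scores.stride(2) - pos_scores.stride(3), pos_scores.stride(3)),
        storage_offset=pos_scores.stride(3) * (seq_len - 1),
    )
    rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
    cols = torch.arange(seq_len)
    indexes = rows.repeat(batch_size * num_heads).unsqueeze(-1) + cols
    b = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
    b = b.reshape(num_heads, batch_size, seq_len, seq_len)
    assert torch.equal(a, b)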

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,155 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file computes fbank features of the MuST-C dataset.
It looks for manifests in the directory "in_dir" and writes
the generated features to "out_dir".
"""
import argparse
import logging
from pathlib import Path
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
FeatureSet,
LilcomChunkyWriter,
load_manifest,
)
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking main() directly (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--in-dir",
type=Path,
required=True,
help="Input manifest directory",
)
parser.add_argument(
"--out-dir",
type=Path,
required=True,
help="Output directory where generated fbank features are saved.",
)
parser.add_argument(
"--tgt-lang",
type=str,
required=True,
help="Target language, e.g., zh, de, fr.",
)
parser.add_argument(
"--num-jobs",
type=int,
default=1,
help="Number of jobs for computing features",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="""True to enable speed perturb with factors 0.9 and 1.1 on
the train subset. False (by default) to disable speed perturb.
""",
)
return parser.parse_args()
def compute_fbank_must_c(
in_dir: Path,
out_dir: Path,
tgt_lang: str,
num_jobs: int,
perturb_speed: bool,
):
out_dir.mkdir(parents=True, exist_ok=True)
extractor = Fbank(FbankConfig(num_mel_bins=80))
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
prefix = "must_c"
suffix = "jsonl.gz"
for p in parts:
logging.info(f"Processing {p}")
cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
if perturb_speed and p == "train":
cuts_path += "_sp"
cuts_path += ".jsonl.gz"
if Path(cuts_path).is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.jsonl.gz"
supervisions_filename = (
in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.jsonl.gz"
)
assert recordings_filename.is_file(), recordings_filename
assert supervisions_filename.is_file(), supervisions_filename
cut_set = CutSet.from_manifests(
recordings=load_manifest(recordings_filename),
supervisions=load_manifest(supervisions_filename),
)
if perturb_speed and p == "train":
logging.info("Speed perturbing for the train dataset")
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp"
else:
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=storage_path,
num_jobs=num_jobs,
storage_type=LilcomChunkyWriter,
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():
args = get_args()
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info(vars(args))
assert args.in_dir.is_dir(), args.in_dir
compute_fbank_must_c(
in_dir=args.in_dir,
out_dir=args.out_dir,
tgt_lang=args.tgt_lang,
num_jobs=args.num_jobs,
perturb_speed=args.perturb_speed,
)
if __name__ == "__main__":
main()

egs/must_c/ST/local/get_text.py Executable file
View File

@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file prints the text field of the supervisions in a cutset to the console.
"""
import argparse
from pathlib import Path
from lhotse import load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Input manifest",
)
return parser.parse_args()
def main():
args = get_args()
assert args.manifest.is_file(), args.manifest
cutset = load_manifest_lazy(args.manifest)
for c in cutset:
for sup in c.supervisions:
print(sup.text)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file generates words.txt from the given transcript file.
"""
import argparse
from pathlib import Path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"transcript",
type=Path,
help="Input transcript file",
)
return parser.parse_args()
def main():
args = get_args()
assert args.transcript.is_file(), args.transcript
word_set = set()
with open(args.transcript) as f:
for line in f:
words = line.strip().split()
for w in words:
word_set.add(w)
# Note: reserved* should be kept in sync with ./local/prepare_lang_bpe.py
reserved1 = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
reserved2 = ["#0", "<s>", "</s>"]
for w in reserved1 + reserved2:
assert w not in word_set, w
words = sorted(list(word_set))
words = reserved1 + words + reserved2
for i, w in enumerate(words):
print(w, i)
if __name__ == "__main__":
main()
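
For illustration (not part of the commit), with a transcript containing only
the words "aber", "das" and "zug", the printed words.txt would be:

    <eps> 0
    !SIL 1
    <SPOKEN_NOISE> 2
    <UNK> 3
    aber 4
    das 5
    zug 6
    #0 7
    <s> 8
    </s> 9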

View File

@ -0,0 +1,169 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import re
def normalize_punctuation(s: str, lang: str) -> str:
"""
This function implements
https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/normalize-punctuation.perl
Args:
s:
A string to be normalized.
lang:
The language to which `s` belongs
Returns:
Return a normalized string.
"""
# s/\r//g;
s = re.sub("\r", "", s)
# remove extra spaces
# s/\(/ \(/g;
s = re.sub("\(", " (", s) # add a space before (
# s/\)/\) /g; s/ +/ /g;
s = re.sub("\)", ") ", s) # add a space after )
s = re.sub(" +", " ", s) # convert multiple spaces to one
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
s = re.sub("\) ([\.\!\:\?\;\,])", r")\1", s)
# s/\( /\(/g;
s = re.sub("\( ", "(", s) # remove space after (
# s/ \)/\)/g;
s = re.sub(" \)", ")", s) # remove space before )
# s/(\d) \%/$1\%/g;
s = re.sub("(\d) \%", r"\1%", s) # remove space between a digit and %
# s/ :/:/g;
s = re.sub(" :", ":", s) # remove space before :
# s/ ;/;/g;
s = re.sub(" ;", ";", s) # remove space before ;
# normalize unicode punctuation
# s/\`/\'/g;
s = re.sub("`", "'", s) # replace ` with '
# s/\'\'/ \" /g;
s = re.sub("''", '"', s) # replace '' with "
# s/„/\"/g;
s = re.sub("„", '"', s) # replace „ with "
# s/“/\"/g;
s = re.sub("“", '"', s) # replace “ with "
# s/”/\"/g;
s = re.sub("”", '"', s) # replace ” with "
# s/–/-/g;
s = re.sub("–", "-", s) # replace – with -
# s/—/ - /g; s/ +/ /g;
s = re.sub("—", " - ", s)
s = re.sub(" +", " ", s) # convert multiple spaces to one
# s/´/\'/g;
s = re.sub("´", "'", s)
# s/([a-z])‘([a-z])/$1\'$2/gi;
s = re.sub("([a-z])‘([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
# s/([a-z])’([a-z])/$1\'$2/gi;
s = re.sub("([a-z])’([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
# s/‘/\'/g;
s = re.sub("‘", "'", s)
# s/‚/\'/g;
s = re.sub("‚", "'", s)
# s/’/\"/g;
s = re.sub("’", '"', s)
# s/''/\"/g;
s = re.sub("''", '"', s)
# s/´´/\"/g;
s = re.sub("´´", '"', s)
# s/…/.../g;
s = re.sub("…", "...", s)
# French quotes
# s/ « / \"/g;
s = re.sub(" « ", ' "', s)
# s/« /\"/g;
s = re.sub("« ", '"', s)
# s/«/\"/g;
s = re.sub("«", '"', s)
# s/ » /\" /g;
s = re.sub(" » ", '" ', s)
# s/ »/\"/g;
s = re.sub(" »", '"', s)
# s/»/\"/g;
s = re.sub("»", '"', s)
# handle pseudo-spaces; in the moses script each of these patterns
# starts with a non-breaking space (U+00A0), not a regular space
# s/ \%/\%/g;
s = re.sub("\u00a0%", "%", s)
# s/nº /nº /g;
s = re.sub("nº\u00a0", "nº ", s)
# s/ :/:/g;
s = re.sub("\u00a0:", ":", s)
# s/ ºC/ ºC/g;
s = re.sub("\u00a0ºC", " ºC", s)
# s/ cm/ cm/g;
s = re.sub("\u00a0cm", " cm", s)
# s/ \?/\?/g;
s = re.sub("\u00a0\\?", "?", s)
# s/ \!/\!/g;
s = re.sub("\u00a0!", "!", s)
# s/ ;/;/g;
s = re.sub("\u00a0;", ";", s)
# s/, /, /g; s/ +/ /g;
s = re.sub(",\u00a0", ", ", s)
s = re.sub(" +", " ", s)
if lang == "en":
# English "quotation," followed by comma, style
# s/\"([,\.]+)/$1\"/g;
s = re.sub('"([,\.]+)', r'\1"', s)
elif lang in ("cs", "cz"):
# Czech is confused
pass
else:
# German/Spanish/French "quotation", followed by comma, style
# s/,\"/\",/g;
s = re.sub(',"', '",', s)
# s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence
s = re.sub('(\.+)"(\s*[^<])', r'"\1\2', s)
if lang in ("de", "es", "cz", "cs", "fr"):
# s/(\d) (\d)/$1,$2/g;
s = re.sub("(\d) (\d)", r"\1,\2", s)
else:
# s/(\d) (\d)/$1.$2/g;
s = re.sub("(\d) (\d)", r"\1.\2", s)
return s
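
For illustration (not part of the commit), a sentence with typographic
punctuation is mapped to its plain ASCII form:

    from normalize_punctuation import normalize_punctuation

    s = 'Er sagte: „Hallo“ – na gut …'
    print(normalize_punctuation(s, lang="de"))
    # -> 'Er sagte: "Hallo" - na gut ...'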

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script normalizes transcripts from supervisions.
Usage:
./local/preprocess_must_c.py \
--manifest-dir ./data/manifests/v1.0/ \
--tgt-lang de
"""
import argparse
import logging
import re
from functools import partial
from pathlib import Path
from lhotse.recipes.utils import read_manifests_if_cached
from normalize_punctuation import normalize_punctuation
from remove_non_native_characters import remove_non_native_characters
from remove_punctuation import remove_punctuation
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-dir",
type=Path,
required=True,
help="Manifest directory",
)
parser.add_argument(
"--tgt-lang",
type=str,
required=True,
help="Target language, e.g., zh, de, fr.",
)
return parser.parse_args()
def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang)
remove_non_native_characters_lang = partial(
remove_non_native_characters, lang=tgt_lang
)
prefix = "must_c"
suffix = "jsonl.gz"
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
for p in parts:
logging.info(f"Processing {p}")
name = f"en-{tgt_lang}_{p}"
# norm: normalization
# rm: remove punctuation
dst_name = manifest_dir / f"must_c_supervisions_{name}_norm_rm.jsonl.gz"
if dst_name.is_file():
logging.info(f"{dst_name} exists - skipping")
continue
manifests = read_manifests_if_cached(
dataset_parts=name,
output_dir=manifest_dir,
prefix=prefix,
suffix=suffix,
types=("supervisions",),
)
if name not in manifests:
raise RuntimeError(f"Processing {p} failed.")
supervisions = manifests[name]["supervisions"]
supervisions = supervisions.transform_text(normalize_punctuation_lang)
supervisions = supervisions.transform_text(remove_punctuation)
supervisions = supervisions.transform_text(lambda x: x.lower())
supervisions = supervisions.transform_text(remove_non_native_characters_lang)
supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x))
supervisions.to_file(dst_name)
def main():
args = get_args()
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info(vars(args))
assert args.manifest_dir.is_dir(), args.manifest_dir
preprocess_must_c(
manifest_dir=args.manifest_dir,
tgt_lang=args.tgt_lang,
)
if __name__ == "__main__":
main()
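
For illustration (not part of the commit), the transform chain above maps a raw
German supervision like 'Er sagte: „Schön!“' to 'er sagte schoen': punctuation
is first normalized and then removed, the text is lowercased, umlauts and ß are
rewritten (ö -> oe, ß -> ss), and whitespace is collapsed.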

View File

@ -0,0 +1,21 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import re
def remove_non_native_characters(s: str, lang: str):
if lang == "de":
# ä -> ae
# ö -> oe
# ü -> ue
# ß -> ss
s = re.sub("ä", "ae", s)
s = re.sub("ö", "oe", s)
s = re.sub("ü", "ue", s)
s = re.sub("ß", "ss", s)
# keep only a-z and spaces
# note: ' is removed
s = re.sub(r"[^a-z\s]", "", s)
return s

View File

@ -0,0 +1,41 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import re
import string
def remove_punctuation(s: str) -> str:
"""
It implements https://github.com/espnet/espnet/blob/master/utils/remove_punctuation.pl
"""
# Remove punctuation except apostrophe
# s/<space>/spacemark/g; # for scoring
s = re.sub("<space>", "spacemark", s)
# s/'/apostrophe/g;
s = re.sub("'", "apostrophe", s)
# s/[[:punct:]]//g;
s = s.translate(str.maketrans("", "", string.punctuation))
# string.punctuation is the following string:
# !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
# See
# https://stackoverflow.com/questions/265960/best-way-to-strip-punctuation-from-a-string
# s/apostrophe/'/g;
s = re.sub("apostrophe", "'", s)
# s/spacemark/<space>/g; # for scoring
s = re.sub("spacemark", "<space>", s)
# collapse runs of whitespace and strip leading/trailing spaces
# s/\s+/ /g;
s = re.sub("\s+", " ", s)
# s/^\s+//;
s = re.sub("^\s+", "", s)
# s/\s+$//;
s = re.sub("\s+$", "", s)
return s

View File

@ -0,0 +1,197 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from normalize_punctuation import normalize_punctuation
def test_normalize_punctuation():
# s/\r//g;
s = "a\r\nb\r\n"
n = normalize_punctuation(s, lang="en")
assert "\r" not in n
assert len(s) - 2 == len(n), (len(s), len(n))
# s/\(/ \(/g;
s = "(ab (c"
n = normalize_punctuation(s, lang="en")
assert n == " (ab (c", n
# s/\)/\) /g;
s = "a)b c)"
n = normalize_punctuation(s, lang="en")
assert n == "a) b c) "
# s/ +/ /g;
s = " a b c d "
n = normalize_punctuation(s, lang="en")
assert n == " a b c d "
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
for i in ".!:?;,":
s = f"a) {i}"
n = normalize_punctuation(s, lang="en")
assert n == f"a){i}"
# s/\( /\(/g;
s = "a( b"
n = normalize_punctuation(s, lang="en")
assert n == "a (b", n
# s/ \)/\)/g;
s = "ab ) a"
n = normalize_punctuation(s, lang="en")
assert n == "ab) a", n
# s/(\d) \%/$1\%/g;
s = "1 %a"
n = normalize_punctuation(s, lang="en")
assert n == "1%a", n
# s/ :/:/g;
s = "a :"
n = normalize_punctuation(s, lang="en")
assert n == "a:", n
# s/ ;/;/g;
s = "a ;"
n = normalize_punctuation(s, lang="en")
assert n == "a;", n
# s/\`/\'/g;
s = "`a`"
n = normalize_punctuation(s, lang="en")
assert n == "'a'", n
# s/\'\'/ \" /g;
s = "''a''"
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/„/\"/g;
s = '„a"'
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/“/\"/g;
s = "“a„"
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/”/\"/g;
s = "“a”"
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/–/-/g;
s = "a–b"
n = normalize_punctuation(s, lang="en")
assert n == "a-b", n
# s/—/ - /g; s/ +/ /g;
s = "a—b"
n = normalize_punctuation(s, lang="en")
assert n == "a - b", n
# s/´/\'/g;
s = "a´b"
n = normalize_punctuation(s, lang="en")
assert n == "a'b", n
# s/([a-z])‘([a-z])/$1\'$2/gi;
# s/([a-z])’([a-z])/$1\'$2/gi;
for i in "‘’":
    s = f"a{i}B"
    n = normalize_punctuation(s, lang="en")
    assert n == "a'B", n
    s = f"A{i}B"
    n = normalize_punctuation(s, lang="en")
    assert n == "A'B", n
    s = f"A{i}b"
    n = normalize_punctuation(s, lang="en")
    assert n == "A'b", n
# s/‘/\'/g;
# s/‚/\'/g;
for i in "‘‚":
    s = f"a{i}b"
    n = normalize_punctuation(s, lang="en")
    assert n == "a'b", n
# s//\"/g;
s = ""
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/''/\"/g;
s = "''"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/´´/\"/g;
s = "´´"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/…/.../g;
s = "…"
n = normalize_punctuation(s, lang="en")
assert n == "...", n
# s/ « / \"/g;
s = "a « b"
n = normalize_punctuation(s, lang="en")
assert n == 'a "b', n
# s/« /\"/g;
s = "a « b"
n = normalize_punctuation(s, lang="en")
assert n == 'a "b', n
# s/«/\"/g;
s = "a«b"
n = normalize_punctuation(s, lang="en")
assert n == 'a"b', n
# s/ » /\" /g;
s = " » "
n = normalize_punctuation(s, lang="en")
assert n == '" ', n
# s/ »/\"/g;
s = " »"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/»/\"/g;
s = "»"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/ \%/\%/g; (the leading space is a non-breaking space, U+00A0)
s = "\u00a0%"
n = normalize_punctuation(s, lang="en")
assert n == "%", n
# s/ :/:/g;
s = "\u00a0:"
n = normalize_punctuation(s, lang="en")
assert n == ":", n
# s/(\d) (\d)/$1.$2/g;
s = "2 3"
n = normalize_punctuation(s, lang="en")
assert n == "2.3", n
# s/(\d) (\d)/$1,$2/g;
s = "2 3"
n = normalize_punctuation(s, lang="de")
assert n == "2,3", n
def main():
test_normalize_punctuation()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from remove_non_native_characters import remove_non_native_characters
def test_remove_non_native_characters():
s = "Ich heiße xxx好的01 fangjun".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "ich heisse xxx fangjun", n
s = "äÄ".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "aeae", n
s = "öÖ".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "oeoe", n
s = "üÜ".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "ueue", n
if __name__ == "__main__":
test_remove_non_native_characters()

View File

@ -0,0 +1,17 @@
#!/usr/bin/env python3
from remove_punctuation import remove_punctuation
def test_remove_punctuation():
s = "a,b'c!#"
n = remove_punctuation(s)
assert n == "ab'c", n
s = " ab " # remove leading and trailing spaces
n = remove_punctuation(s)
assert n == "ab", n
if __name__ == "__main__":
test_remove_punctuation()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py

egs/must_c/ST/prepare.sh Executable file
View File

@ -0,0 +1,173 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=10
stage=0
stop_stage=100
version=v1.0
tgt_lang=de
dl_dir=$PWD/download
must_c_dir=$dl_dir/must-c/$version/en-$tgt_lang/data
# We assume dl_dir (download dir) contains the following
# directories and files.
# - $dl_dir/must-c/$version/en-$tgt_lang/data/{dev,train,tst-COMMON,tst-HE}
#
# Please go to https://ict.fbk.eu/must-c-releases/
# to download and untar the dataset if you have not already done this.
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate
# data/lang_bpe_${tgt_lang}_xxx
# data/lang_bpe_${tgt_lang}_yyy
# if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
500
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ ! -d $must_c_dir ]; then
log "$must_c_dir does not exist"
exit 1
fi
for d in dev train tst-COMMON tst-HE; do
if [ ! -d $must_c_dir/$d ]; then
log "$must_c_dir/$d does not exist!"
exit 1
fi
done
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download musan"
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to $dl_dir/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare must-c $version manifest for target language $tgt_lang"
mkdir -p data/manifests/$version
if [ ! -e data/manifests/$version/.${tgt_lang}.manifests.done ]; then
lhotse prepare must-c \
-j $nj \
--tgt-lang $tgt_lang \
$dl_dir/must-c/$version/ \
data/manifests/$version/
touch data/manifests/$version/.${tgt_lang}.manifests.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Text normalization for $version with target language $tgt_lang"
if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then
./local/preprocess_must_c.py \
--manifest-dir ./data/manifests/$version/ \
--tgt-lang $tgt_lang
touch ./data/manifests/$version/.$tgt_lang.norm.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for $version with target language $tgt_lang"
mkdir -p data/fbank/$version/
if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then
./local/compute_fbank_must_c.py \
--in-dir ./data/manifests/$version/ \
--out-dir ./data/fbank/$version/ \
--tgt-lang $tgt_lang \
--num-jobs $nj
./local/compute_fbank_must_c.py \
--in-dir ./data/manifests/$version/ \
--out-dir ./data/fbank/$version/ \
--tgt-lang $tgt_lang \
--num-jobs $nj \
--perturb-speed 1
touch data/fbank/$version/.$tgt_lang.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang for $version with target language $tgt_lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/
mkdir -p $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/words.txt ]; then
./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
done
fi

egs/must_c/ST/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared


@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
default=300,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@ -364,7 +364,7 @@ class WenetSpeechAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats


@ -92,7 +92,7 @@ When training with the L subset, the streaming usage:
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64
(4) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
@ -112,8 +112,10 @@ When training with the L subset, the streaming usage:
import argparse
import glob
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -133,7 +135,8 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -307,6 +310,26 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score for each matched token of the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing list; one word/phrase per line.
Used only when --decoding_method is modified_beam_search.
""",
)
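# For illustration (hypothetical contents), a file passed via --context-file
# holds one biasing word/phrase per line, e.g.:
#     小米
#     北京海淀
# Each line is tokenized to char ids and compiled into a ContextGraph when
# --decoding_method is modified_beam_search.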
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
@ -362,6 +385,7 @@ def decode_one_batch(
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
@ -402,14 +426,13 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
@ -448,6 +471,7 @@ def decode_one_batch(
encoder_out=encoder_out,
beam=params.beam_size,
encoder_out_lens=encoder_out_lens,
context_graph=context_graph,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -509,7 +533,12 @@ def decode_one_batch(
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
key = f"beam_size_{params.beam_size}"
if params.has_contexts:
key += f"-context-score-{params.context_score}"
else:
key += "-no-context-words"
return {key: hyps}
def decode_dataset(
@ -518,6 +547,7 @@ def decode_dataset(
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
@ -567,6 +597,7 @@ def decode_dataset(
lexicon=lexicon,
decoding_graph=decoding_graph,
batch=batch,
context_graph=context_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
@ -646,6 +677,12 @@ def main():
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
)
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
@ -655,6 +692,10 @@ def main():
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += "-no-contexts-words"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -684,11 +725,15 @@ def main():
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
if params.simulate_streaming:
assert (
params.causal_convolution
@ -816,6 +861,19 @@ def main():
else:
decoding_graph = None
if params.decoding_method == "modified_beam_search":
if os.path.exists(params.context_file):
contexts_text = []
for line in open(params.context_file).readlines():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -833,15 +891,16 @@ def main():
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dl):
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
context_graph=context_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,


@ -23,6 +23,8 @@ from .checkpoint import (
save_checkpoint_with_global_batch_idx,
)
from .context_graph import ContextGraph, ContextState
from .decode import (
get_lattice,
nbest_decoding,

icefall/context_graph.py Normal file

@ -0,0 +1,412 @@
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple
class ContextState:
"""The state in ContextGraph"""
def __init__(
self,
id: int,
token: int,
token_score: float,
node_score: float,
local_node_score: float,
is_end: bool,
):
"""Create a ContextState.
Args:
id:
    The node id, used only for visualization now. A node id lies in
    [0, graph.num_nodes]; the id of the root node is always 0.
token:
    The token id.
token_score:
    The bonus for each token during decoding, which will hopefully
    boost the token enough to survive beam search.
node_score:
    The accumulated bonus from the root of the graph to the current
    node; it is used to calculate the score of the fail arc.
local_node_score:
    The accumulated bonus from the last ``end_node`` (a node with
    is_end == True) to the current node; it is also used to calculate
    the score of the fail arc. Note: the local_node_score of an
    ``end_node`` is 0.
is_end:
    True if the current token ends a context word/phrase.
"""
self.id = id
self.token = token
self.token_score = token_score
self.node_score = node_score
self.local_node_score = local_node_score
self.is_end = is_end
self.next = {}
self.fail = None
self.output = None
class ContextGraph:
"""The ContextGraph is modified from Aho-Corasick which is mainly
a Trie with a fail arc for each node.
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details
of Aho-Corasick algorithm.
A ContextGraph contains some words / phrases that we expect to boost their
scores during decoding. If the substring of a decoded sequence matches the word / phrase
in the ContextGraph, we will give the decoded sequence a bonus to make it survive
beam search.
"""
def __init__(self, context_score: float):
"""Initialize a ContextGraph with the given ``context_score``.
A root node will be created (**NOTE:** the token of root is hardcoded to -1).
Args:
context_score:
    The bonus score for each token (note: NOT for each word/phrase; a
    longer word/phrase therefore accumulates a larger total bonus, but
    it has to be matched in full).
"""
self.context_score = context_score
self.num_nodes = 0
self.root = ContextState(
id=self.num_nodes,
token=-1,
token_score=0,
node_score=0,
local_node_score=0,
is_end=False,
)
self.root.fail = self.root
def _fill_fail_output(self):
"""This function fills the fail arc for each trie node, it can be computed
in linear time by performing a breadth-first search starting from the root.
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
details of the algorithm.
"""
queue = deque()
for token, node in self.root.next.items():
node.fail = self.root
queue.append(node)
while queue:
current_node = queue.popleft()
for token, node in current_node.next.items():
fail = current_node.fail
if token in fail.next:
fail = fail.next[token]
else:
fail = fail.fail
while token not in fail.next:
fail = fail.fail
if fail.token == -1: # root
break
if token in fail.next:
fail = fail.next[token]
node.fail = fail
# fill the output arc
output = node.fail
while not output.is_end:
output = output.fail
if output.token == -1: # root
output = None
break
node.output = output
queue.append(node)
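# A small worked example (illustrative, not part of the original file): for
# the patterns "HE", "SHE" and "HERS" the trie is
#     root -H-> H -E-> HE*            (* marks an end node)
#     root -S-> S -H-> SH -E-> SHE*
#     HE -R-> HER -S-> HERS*
# and the BFS above assigns the fail arcs
#     fail(H) = fail(S) = fail(HE) = fail(HER) = root,
#     fail(SH) = H, fail(SHE) = HE, fail(HERS) = S
# plus one output arc, output(SHE) = HE, because HE is the longest proper
# suffix of SHE that is itself a complete pattern.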
def build(self, token_ids: List[List[int]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
See https://en.wikipedia.org/wiki/Trie for how to build a trie.
Args:
token_ids:
The given token lists to build the ContextGraph, it is a list of token list,
each token list contains the token ids for a word/phrase. The token id
could be an id of a char (modeling with single Chinese char) or an id
of a BPE (modeling with BPEs).
"""
for tokens in token_ids:
node = self.root
for i, token in enumerate(tokens):
if token not in node.next:
self.num_nodes += 1
is_end = i == len(tokens) - 1
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
token_score=self.context_score,
node_score=node.node_score + self.context_score,
local_node_score=0
if is_end
else (node.local_node_score + self.context_score),
is_end=is_end,
)
node = node.next[token]
self._fill_fail_output()
def forward_one_step(
self, state: ContextState, token: int
) -> Tuple[float, ContextState]:
"""Search the graph with given state and token.
Args:
state:
The given token containing trie node to start.
token:
The given token.
Returns:
Return a tuple of score and next state.
"""
node = None
score = 0
# token matched
if token in state.next:
node = state.next[token]
score = node.token_score
if state.is_end:
score += state.node_score
else:
# token not matched
# We trace along the fail arcs until the token matches or we reach the
# root of the graph.
node = state.fail
while token not in node.next:
node = node.fail
if node.token == -1: # root
break
if token in node.next:
node = node.next[token]
# The score of the fail path
score = node.node_score - state.local_node_score
assert node is not None
matched_score = 0
output = node.output
while output is not None:
matched_score += output.node_score
output = output.output
return (score + matched_score, node)
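# A worked example of the scoring (illustrative only), with context_score=1
# and the patterns "HE" and "SHE":
#   query "SHE": 'S' -> +1, 'H' -> +1, 'E' -> +1 for reaching SHE, and the
#   output arc of SHE adds node_score(HE) = 2, giving a total bonus of 5.
#   query "SHD": 'S' -> +1, 'H' -> +1, then 'D' falls back to the root and
#   scores node_score(root) - local_node_score(SH) = 0 - 2 = -2, cancelling
#   the provisional bonus of the unmatched prefix "SH".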
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize
the matching, the purpose is to subtract the added bonus score for the
state that is not the end of a word/phrase.
Args:
state:
The given state(trie node).
Returns:
Return a tuple of score and next state. If state is the end of a word/phrase
the score is zero, otherwise the score is the score of a implicit fail arc
to root. The next state is always root.
"""
# The score of the fail arc
score = -state.node_score
if state.is_end:
score = 0
return (score, self.root)
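# Illustration (continuing the example above): a sequence ending on the
# non-end node SH (node_score = 2) gets finalize() == (-2, root), removing
# the provisional bonus, while ending on the end node SHE gets (0, root).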
def draw(
self,
title: Optional[str] = None,
filename: Optional[str] = "",
symbol_table: Optional[Dict[int, str]] = None,
) -> "Digraph": # noqa
"""Visualize a ContextGraph via graphviz.
Render the ContextGraph as an image via graphviz and return the Digraph
object, optionally saving it to `filename`.
`filename` must have a suffix that graphviz understands, such as
`pdf`, `svg` or `png`.
Note:
You need to install graphviz to use this function::
pip install graphviz
Args:
title:
Title to be displayed in image, e.g. 'A simple FSA example'
filename:
  Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg' or
  'foo.pdf' (must have a suffix that graphviz understands).
symbol_table:
Map the token ids to symbols.
Returns:
A Digraph from graphviz.
"""
try:
import graphviz
except Exception:
print("You cannot use `to_dot` unless the graphviz package is installed.")
raise
graph_attr = {
"rankdir": "LR",
"size": "8.5,11",
"center": "1",
"orientation": "Portrait",
"ranksep": "0.4",
"nodesep": "0.25",
}
if title is not None:
graph_attr["label"] = title
default_node_attr = {
"shape": "circle",
"style": "bold",
"fontsize": "14",
}
final_state_attr = {
"shape": "doublecircle",
"style": "bold",
"fontsize": "14",
}
dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)
seen = set()
queue = deque()
queue.append(self.root)
# root id is always 0
dot.node("0", label="0", **default_node_attr)
dot.edge("0", "0", color="red")
seen.add(0)
while len(queue):
current_node = queue.popleft()
for token, node in current_node.next.items():
if node.id not in seen:
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
local_node_score = f"{node.local_node_score:.2f}".rstrip(
"0"
).rstrip(".")
label = f"{node.id}/({node_score},{local_node_score})"
if node.is_end:
dot.node(str(node.id), label=label, **final_state_attr)
else:
dot.node(str(node.id), label=label, **default_node_attr)
seen.add(node.id)
weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
label = str(token) if symbol_table is None else symbol_table[token]
dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
dot.edge(
str(node.id),
str(node.fail.id),
color="red",
)
if node.output is not None:
dot.edge(
str(node.id),
str(node.output.id),
color="green",
)
queue.append(node)
if filename:
_, extension = os.path.splitext(filename)
if extension == "" or extension[0] != ".":
raise ValueError(
"Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
filename
)
)
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
temp_fn = dot.render(
filename="temp",
directory=tmp_dir,
format=extension[1:],
cleanup=True,
)
shutil.move(temp_fn, filename)
return dot
if __name__ == "__main__":
contexts_str = [
"S",
"HE",
"SHE",
"SHELL",
"HIS",
"HERS",
"HELLO",
"THIS",
"THEM",
]
contexts = []
for s in contexts_str:
contexts.append([ord(x) for x in s])
context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
symbol_table = {}
for word in contexts_str:
    for char in word:
        symbol_table[ord(char)] = char
context_graph.draw(
title="Graph for: " + " / ".join(contexts_str),
filename="context_graph.pdf",
symbol_table=symbol_table,
)
queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
for query, expected_score in queries.items():
total_scores = 0
state = context_graph.root
for q in query:
score, state = context_graph.forward_one_step(state, ord(q))
total_scores += score
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
total_scores += score
assert total_scores == expected_score, (
total_scores,
expected_score,
query,
)


@ -28,6 +28,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
import k2
@ -1881,3 +1882,20 @@ def is_cjk(character):
]
]
)
def symlink_or_copy(exp_dir: Path, src: str, dst: str):
"""
In the experiment directory, create a symlink named dst that points to src.
If symlink creation fails (e.g., on Windows), fall back to copying the file."""
dir_fd = os.open(exp_dir, os.O_RDONLY)
try:
os.remove(dst, dir_fd=dir_fd)
except FileNotFoundError:
pass
try:
os.symlink(src=src, dst=dst, dir_fd=dir_fd)
except OSError:
copyfile(src=exp_dir / src, dst=exp_dir / dst)
os.close(dir_fd)
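A hedged usage sketch of symlink_or_copy (the directory and checkpoint names
are hypothetical; icefall presumably uses this helper to keep stable aliases
such as best-valid-loss.pt pointing at concrete checkpoint files):

from pathlib import Path
from icefall.utils import symlink_or_copy

exp_dir = Path("pruned_transducer_stateless5/exp")  # hypothetical exp dir
exp_dir.mkdir(parents=True, exist_ok=True)
(exp_dir / "epoch-30.pt").touch()  # stand-in for a real checkpoint
# Create (or refresh) exp_dir/best-valid-loss.pt -> epoch-30.pt; on platforms
# without symlink support this copies the file instead.
symlink_or_copy(exp_dir=exp_dir, src="epoch-30.pt", dst="best-valid-loss.pt")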