mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'master' into wenetspeech
This commit is contained in:
commit
a1b12cf4e9
@ -58,6 +58,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -76,6 +77,8 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -211,6 +214,26 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -222,6 +245,7 @@ def decode_one_batch(
|
||||
token_table: k2.SymbolTable,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -285,6 +309,7 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
else:
|
||||
hyp_tokens = []
|
||||
@ -324,7 +349,12 @@ def decode_one_batch(
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
key += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
key += "-no-context-words"
|
||||
return {key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -333,6 +363,7 @@ def decode_dataset(
|
||||
model: nn.Module,
|
||||
token_table: k2.SymbolTable,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -377,6 +408,7 @@ def decode_dataset(
|
||||
model=model,
|
||||
token_table=token_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
@ -407,16 +439,17 @@ def save_results(
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
|
||||
store_transcripts(filename=recog_path, texts=results_char)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
@ -457,6 +490,12 @@ def main():
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
|
||||
if os.path.exists(params.context_file):
|
||||
params.has_contexts = True
|
||||
else:
|
||||
params.has_contexts = False
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
@ -470,6 +509,10 @@ def main():
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += "-no-contexts-words"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -490,6 +533,11 @@ def main():
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
@ -586,6 +634,19 @@ def main():
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
if params.decoding_method == "modified_beam_search":
|
||||
if os.path.exists(params.context_file):
|
||||
contexts_text = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts_text.append(line.strip())
|
||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(contexts)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
context_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -608,6 +669,7 @@ def main():
|
||||
model=model,
|
||||
token_table=lexicon.token_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
||||
@ -577,9 +577,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -806,13 +803,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -859,7 +850,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -872,7 +862,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -580,9 +580,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -809,13 +806,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -862,7 +853,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -875,7 +865,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -567,9 +567,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -799,13 +796,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -852,7 +843,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -865,7 +855,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -512,9 +512,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -725,13 +722,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
# print(batch["supervisions"])
|
||||
@ -774,7 +765,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -787,7 +777,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -554,9 +554,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -779,13 +776,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -832,7 +823,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -845,7 +835,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -549,9 +549,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -770,13 +767,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -823,7 +814,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -836,7 +826,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -63,6 +63,14 @@ log() {
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if ! command -v ffmpeg &> /dev/null; then
|
||||
echo "This dataset requires ffmpeg"
|
||||
echo "Please install ffmpeg first"
|
||||
echo ""
|
||||
echo " sudo apt-get install ffmpeg"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
|
||||
@ -567,9 +567,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -799,13 +796,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -852,7 +843,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -865,7 +855,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -606,9 +606,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -835,13 +832,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -889,7 +880,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -902,7 +892,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -607,9 +607,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -836,13 +833,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -890,7 +881,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -903,7 +893,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -462,9 +462,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -674,13 +671,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -712,7 +703,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -725,7 +715,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -410,9 +410,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -675,13 +672,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
batch_name = batch["supervisions"]["uttid"]
|
||||
@ -736,7 +727,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -749,7 +739,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -550,9 +550,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -771,13 +768,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -819,7 +810,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -832,7 +822,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -550,9 +550,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -771,13 +768,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -819,7 +810,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -832,7 +822,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -552,9 +552,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -773,13 +770,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -821,7 +812,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -834,7 +824,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -107,7 +107,7 @@ fi
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
# to $dl_dir/musan
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.musan.done ]; then
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
|
||||
@ -1,117 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=16
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
# Split data/${lang}set to this number of pieces
|
||||
# This is to avoid OOM during feature extraction.
|
||||
num_splits=1000
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, they will be downloaded
|
||||
# by this script automatically.
|
||||
#
|
||||
# - $dl_dir/$release/$lang
|
||||
# This directory contains the following files downloaded from
|
||||
# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz
|
||||
#
|
||||
# - clips
|
||||
# - dev.tsv
|
||||
# - invalidated.tsv
|
||||
# - other.tsv
|
||||
# - reported.tsv
|
||||
# - test.tsv
|
||||
# - train.tsv
|
||||
# - validated.tsv
|
||||
|
||||
dl_dir=$PWD/download
|
||||
release=cv-corpus-13.0-2023-03-09
|
||||
lang=en
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# All files generated by this script are saved in "data/${lang}".
|
||||
# You can safely remove "data/${lang}" and rerun this script to regenerate it.
|
||||
mkdir -p data/${lang}
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/$release,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/$release $dl_dir/$release
|
||||
#
|
||||
if [ ! -d $dl_dir/$release/$lang/clips ]; then
|
||||
lhotse download commonvoice --languages $lang --release $release $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare CommonVoice manifest"
|
||||
# We assume that you have downloaded the CommonVoice corpus
|
||||
# to $dl_dir/$release
|
||||
mkdir -p data/${lang}/manifests
|
||||
if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then
|
||||
lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests
|
||||
touch data/${lang}/manifests/.cv-${lang}.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Preprocess CommonVoice manifest"
|
||||
if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then
|
||||
./local/preprocess_commonvoice.py --language $lang
|
||||
touch data/${lang}/fbank/.preprocess_complete
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Compute fbank for dev and test subsets of CommonVoice"
|
||||
mkdir -p data/${lang}/fbank
|
||||
if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then
|
||||
./local/compute_fbank_commonvoice_dev_test.py --language $lang
|
||||
touch data/${lang}/fbank/.cv-${lang}_dev_test.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Split train subset into ${num_splits} pieces"
|
||||
split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits}
|
||||
if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then
|
||||
lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir
|
||||
touch $split_dir/.cv-${lang}_train_split.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute features for train subset of CommonVoice"
|
||||
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
|
||||
./local/compute_fbank_commonvoice_splits.py \
|
||||
--num-workers $nj \
|
||||
--batch-duration 600 \
|
||||
--start 0 \
|
||||
--num-splits $num_splits \
|
||||
--language $lang
|
||||
touch data/${lang}/fbank/.cv-${lang}_train.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Combine features for train"
|
||||
if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then
|
||||
pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz")
|
||||
lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz
|
||||
fi
|
||||
fi
|
||||
@ -1,159 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=15
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, they will be downloaded
|
||||
# by this script automatically.
|
||||
#
|
||||
# - $dl_dir/GigaSpeech
|
||||
# You can find audio, dict, GigaSpeech.json inside it.
|
||||
# You can apply for the download credentials by following
|
||||
# https://github.com/SpeechColab/GigaSpeech#download
|
||||
|
||||
# Number of hours for GigaSpeech subsets
|
||||
# XL 10k hours
|
||||
# L 2.5k hours
|
||||
# M 1k hours
|
||||
# S 250 hours
|
||||
# XS 10 hours
|
||||
# DEV 12 hours
|
||||
# Test 40 hours
|
||||
|
||||
# Split XL subset to this number of pieces
|
||||
# This is to avoid OOM during feature extraction.
|
||||
num_splits=2000
|
||||
# We use lazy split from lhotse.
|
||||
# The XL subset (10k hours) contains 37956 cuts without speed perturbing.
|
||||
# We want to split it into 2000 splits, so each split
|
||||
# contains about 37956 / 2000 = 19 cuts. As a result, there will be 1998 splits.
|
||||
chunk_size=19 # number of cuts in each split. The last split may contain fewer cuts.
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
[ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech
|
||||
|
||||
# If you have pre-downloaded it to /path/to/GigaSpeech,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech
|
||||
#
|
||||
if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then
|
||||
# Check credentials.
|
||||
if [ ! -f $dl_dir/password ]; then
|
||||
echo -n "$0: Please apply for the download credentials by following"
|
||||
echo -n "https://github.com/SpeechColab/GigaSpeech#dataset-download"
|
||||
echo " and save it to $dl_dir/password."
|
||||
exit 1;
|
||||
fi
|
||||
PASSWORD=`cat $dl_dir/password 2>/dev/null`
|
||||
if [ -z "$PASSWORD" ]; then
|
||||
echo "$0: Error, $dl_dir/password is empty."
|
||||
exit 1;
|
||||
fi
|
||||
PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1`
|
||||
if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then
|
||||
echo "$0: Error, invalid $dl_dir/password."
|
||||
exit 1;
|
||||
fi
|
||||
# Download XL, DEV and TEST sets by default.
|
||||
lhotse download gigaspeech \
|
||||
--subset XL \
|
||||
--subset L \
|
||||
--subset M \
|
||||
--subset S \
|
||||
--subset XS \
|
||||
--subset DEV \
|
||||
--subset TEST \
|
||||
--host tsinghua \
|
||||
$dl_dir/password $dl_dir/GigaSpeech
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare GigaSpeech manifest (may take 30 minutes)"
|
||||
# We assume that you have downloaded the GigaSpeech corpus
|
||||
# to $dl_dir/GigaSpeech
|
||||
if [ ! -f data/manifests/.gigaspeech.done ]; then
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare gigaspeech \
|
||||
--subset XL \
|
||||
--subset L \
|
||||
--subset M \
|
||||
--subset S \
|
||||
--subset XS \
|
||||
--subset DEV \
|
||||
--subset TEST \
|
||||
-j $nj \
|
||||
$dl_dir/GigaSpeech data/manifests
|
||||
touch data/manifests/.gigaspeech.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Preprocess GigaSpeech manifest"
|
||||
if [ ! -f data/fbank/.gigaspeech_preprocess.done ]; then
|
||||
log "It may take 2 hours for this stage"
|
||||
./local/preprocess_gigaspeech.py
|
||||
touch data/fbank/.gigaspeech_preprocess.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
|
||||
if [ ! -f data/fbank/.gigaspeech_dev_test.done ]; then
|
||||
./local/compute_fbank_gigaspeech_dev_test.py
|
||||
touch data/fbank/.gigaspeech_dev_test.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Split XL subset into ${num_splits} pieces"
|
||||
split_dir=data/fbank/gigaspeech_XL_split_${num_splits}
|
||||
if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then
|
||||
lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $chunk_size
|
||||
touch $split_dir/.gigaspeech_XL_split.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute features for XL"
|
||||
# Note: The script supports --start and --stop options.
|
||||
# You can use several machines to compute the features in parallel.
|
||||
if [ ! -f data/fbank/.gigaspeech_XL.done ]; then
|
||||
./local/compute_fbank_gigaspeech_splits.py \
|
||||
--num-workers $nj \
|
||||
--batch-duration 600 \
|
||||
--num-splits $num_splits
|
||||
touch data/fbank/.gigaspeech_XL.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Combine features for XL (may take 15 hours)"
|
||||
if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then
|
||||
pieces=$(find data/fbank/gigaspeech_XL_split_${num_splits} -name "gigaspeech_cuts_XL.*.jsonl.gz")
|
||||
lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz
|
||||
fi
|
||||
fi
|
||||
@ -1,330 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=16
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, they will be downloaded
|
||||
# by this script automatically.
|
||||
#
|
||||
# - $dl_dir/LibriSpeech
|
||||
# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
|
||||
# You can download them from https://www.openslr.org/12
|
||||
#
|
||||
# - $dl_dir/lm
|
||||
# This directory contains the following files downloaded from
|
||||
# http://www.openslr.org/resources/11
|
||||
#
|
||||
# - 3-gram.pruned.1e-7.arpa.gz
|
||||
# - 3-gram.pruned.1e-7.arpa
|
||||
# - 4-gram.arpa.gz
|
||||
# - 4-gram.arpa
|
||||
# - librispeech-vocab.txt
|
||||
# - librispeech-lexicon.txt
|
||||
# - librispeech-lm-norm.txt.gz
|
||||
#
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
|
||||
# Split all dataset to this number of pieces and mix each dataset pieces
|
||||
# into multidataset pieces with shuffling.
|
||||
num_splits=1998
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate data/lang_bpe_xxx,
|
||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
# 5000
|
||||
# 2000
|
||||
# 1000
|
||||
500
|
||||
)
|
||||
|
||||
# multidataset list.
|
||||
# LibriSpeech and musan are required.
|
||||
# The others are optional.
|
||||
multidataset=(
|
||||
"gigaspeech",
|
||||
"commonvoice",
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
log "Dataset: LibriSpeech and musan"
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "Stage -1: Download LM"
|
||||
mkdir -p $dl_dir/lm
|
||||
if [ ! -e $dl_dir/lm/.done ]; then
|
||||
./local/download_lm.py --out-dir=$dl_dir/lm
|
||||
touch $dl_dir/lm/.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/LibriSpeech,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
|
||||
#
|
||||
if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
|
||||
lhotse download librispeech --full $dl_dir
|
||||
fi
|
||||
|
||||
# If you have pre-downloaded it to /path/to/musan,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/musan $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare LibriSpeech manifest"
|
||||
# We assume that you have downloaded the LibriSpeech corpus
|
||||
# to $dl_dir/LibriSpeech
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.librispeech.done ]; then
|
||||
lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
|
||||
touch data/manifests/.librispeech.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.musan.done ]; then
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
touch data/manifests/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Compute fbank for librispeech"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.librispeech.done ]; then
|
||||
./local/compute_fbank_librispeech.py --perturb-speed False
|
||||
touch data/fbank/.librispeech.done
|
||||
fi
|
||||
|
||||
if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then
|
||||
cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \
|
||||
shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ ! -e data/fbank/.librispeech-validated.done ]; then
|
||||
log "Validating data/fbank for LibriSpeech"
|
||||
parts=(
|
||||
train-clean-100
|
||||
train-clean-360
|
||||
train-other-500
|
||||
test-clean
|
||||
test-other
|
||||
dev-clean
|
||||
dev-other
|
||||
)
|
||||
for part in ${parts[@]}; do
|
||||
python3 ./local/validate_manifest.py \
|
||||
data/fbank/librispeech_cuts_${part}.jsonl.gz
|
||||
done
|
||||
touch data/fbank/.librispeech-validated.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.musan.done ]; then
|
||||
./local/compute_fbank_musan.py
|
||||
touch data/fbank/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare phone based lang"
|
||||
lang_dir=data/lang_phone
|
||||
mkdir -p $lang_dir
|
||||
|
||||
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
|
||||
cat - $dl_dir/lm/librispeech-lexicon.txt |
|
||||
sort | uniq > $lang_dir/lexicon.txt
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang.py --lang-dir $lang_dir
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L.fst ]; then
|
||||
log "Converting L.pt to L.fst"
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L.pt \
|
||||
$lang_dir/L.fst
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.fst ]; then
|
||||
log "Converting L_disambig.pt to L_disambig.fst"
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L_disambig.pt \
|
||||
$lang_dir/L_disambig.fst
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare BPE based lang"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
mkdir -p $lang_dir
|
||||
# We reuse words.txt from phone based lexicon
|
||||
# so that the two can share G.pt later.
|
||||
cp data/lang_phone/words.txt $lang_dir
|
||||
|
||||
if [ ! -f $lang_dir/transcript_words.txt ]; then
|
||||
log "Generate data for BPE training"
|
||||
files=$(
|
||||
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
|
||||
find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
|
||||
find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
|
||||
)
|
||||
for f in ${files[@]}; do
|
||||
cat $f | cut -d " " -f 2-
|
||||
done > $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
|
||||
log "Validating $lang_dir/lexicon.txt"
|
||||
./local/validate_bpe_lexicon.py \
|
||||
--lexicon $lang_dir/lexicon.txt \
|
||||
--bpe-model $lang_dir/bpe.model
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L.fst ]; then
|
||||
log "Converting L.pt to L.fst"
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L.pt \
|
||||
$lang_dir/L.fst
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.fst ]; then
|
||||
log "Converting L_disambig.pt to L_disambig.fst"
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L_disambig.pt \
|
||||
$lang_dir/L_disambig.fst
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
|
||||
mkdir -p data/lm
|
||||
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
|
||||
# It is used in building HLG
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="data/lang_phone/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
$dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
|
||||
fi
|
||||
|
||||
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
|
||||
# It is used for LM rescoring
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="data/lang_phone/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=4 \
|
||||
$dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Compile HLG"
|
||||
./local/compile_hlg.py --lang-dir data/lang_phone
|
||||
|
||||
# Note If ./local/compile_hlg.py throws OOM,
|
||||
# please switch to the following command
|
||||
#
|
||||
# ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
./local/compile_hlg.py --lang-dir $lang_dir
|
||||
|
||||
# Note If ./local/compile_hlg.py throws OOM,
|
||||
# please switch to the following command
|
||||
#
|
||||
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
||||
# Compile LG for RNN-T fast_beam_search decoding
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
log "Stage 9: Compile LG"
|
||||
./local/compile_lg.py --lang-dir data/lang_phone
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
./local/compile_lg.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Prepare the other datasets"
|
||||
# GigaSpeech
|
||||
if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then
|
||||
log "Dataset: GigaSpeech"
|
||||
./prepare_giga_speech.sh --stop_stage 5
|
||||
fi
|
||||
|
||||
# CommonVoice
|
||||
if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then
|
||||
log "Dataset: CommonVoice"
|
||||
./prepare_common_voice.sh
|
||||
fi
|
||||
fi
|
||||
@ -444,9 +444,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -649,13 +646,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -686,7 +677,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -698,7 +688,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -487,9 +487,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -692,13 +689,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -738,7 +729,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -750,7 +740,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -24,7 +24,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
from icefall import NgramLm, NgramLmStateCost
|
||||
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.lm_wrapper import LmScorer
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
@ -785,6 +785,9 @@ class Hypothesis:
|
||||
# N-gram LM state
|
||||
state_cost: Optional[NgramLmStateCost] = None
|
||||
|
||||
# Context graph state
|
||||
context_state: Optional[ContextState] = None
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return a string representation of self.ys"""
|
||||
@ -937,6 +940,7 @@ def modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
blank_penalty: float = 0.0,
|
||||
@ -989,6 +993,7 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
context_state=None if context_graph is None else context_graph.root,
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
@ -1011,6 +1016,7 @@ def modified_beam_search(
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
@ -1071,21 +1077,51 @@ def modified_beam_search(
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
context_score = 0
|
||||
new_context_state = None if context_graph is None else hyp.context_state
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
if context_graph is not None:
|
||||
(
|
||||
context_score,
|
||||
new_context_state,
|
||||
) = context_graph.forward_one_step(hyp.context_state, new_token)
|
||||
|
||||
new_log_prob = topk_log_probs[k] + context_score
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
timestamp=new_timestamp,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
|
||||
# finalize context_state, if the matched contexts do not reach final state
|
||||
# we need to add the score on the corresponding backoff arc
|
||||
if context_graph is not None:
|
||||
finalized_B = [HypothesisList() for _ in range(len(B))]
|
||||
for i, hyps in enumerate(B):
|
||||
for hyp in list(hyps):
|
||||
context_score, new_context_state = context_graph.finalize(
|
||||
hyp.context_state
|
||||
)
|
||||
finalized_B[i].add(
|
||||
Hypothesis(
|
||||
ys=hyp.ys,
|
||||
log_prob=hyp.log_prob + context_score,
|
||||
timestamp=hyp.timestamp,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
)
|
||||
B = finalized_B
|
||||
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
|
||||
@ -99,7 +99,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless3/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
@ -125,6 +125,7 @@ For example:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -146,6 +147,7 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import ContextGraph
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -353,6 +355,27 @@ def get_parser():
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -365,6 +388,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -494,6 +518,7 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
@ -548,7 +573,12 @@ def decode_one_batch(
|
||||
|
||||
return {key: (hyps, timestamps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
key += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
key += "-no-context-words"
|
||||
return {key: (hyps, timestamps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -558,6 +588,7 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -622,6 +653,7 @@ def decode_dataset(
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
@ -728,6 +760,12 @@ def main():
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
|
||||
if os.path.exists(params.context_file):
|
||||
params.has_contexts = True
|
||||
else:
|
||||
params.has_contexts = False
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
@ -750,6 +788,10 @@ def main():
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += "-no-context-words"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -881,6 +923,18 @@ def main():
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
if params.decoding_method == "modified_beam_search":
|
||||
if os.path.exists(params.context_file):
|
||||
contexts = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(sp.encode(contexts))
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
context_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -905,6 +959,7 @@ def main():
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
||||
@ -78,7 +78,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -115,7 +115,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless4/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--exp-dir pruned_transducer_stateless4/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
@ -37,7 +37,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--exp-dir pruned_transducer_stateless4/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
||||
@ -195,7 +195,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless4/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -296,7 +296,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
|
||||
@ -78,7 +78,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -115,7 +115,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless4/test_model.py
|
||||
python ./pruned_transducer_stateless5/test_model.py
|
||||
"""
|
||||
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
@ -328,7 +328,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -20,23 +20,23 @@
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
./pruned_transducer_stateless6/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless6/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless2/decode.py`,
|
||||
To use the generated file with `pruned_transducer_stateless6/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
./pruned_transducer_stateless6/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless6/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 100 \
|
||||
@ -65,7 +65,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -91,7 +91,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless6/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
|
||||
@ -267,7 +267,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -94,10 +94,10 @@ Usage:
|
||||
--max-states 64
|
||||
|
||||
(8) modified beam search with RNNLM shallow fusion
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||
--beam-size 4 \
|
||||
@ -110,11 +110,11 @@ Usage:
|
||||
--rnn-lm-tie-weights 1
|
||||
|
||||
(9) modified beam search with LM shallow fusion + LODR
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--max-duration 600 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--decoding-method modified_beam_search_LODR \
|
||||
--beam-size 4 \
|
||||
--lm-type rnn \
|
||||
|
||||
@ -79,7 +79,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -116,7 +116,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
|
||||
@ -389,7 +389,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -620,9 +620,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -896,13 +893,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -953,7 +944,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -966,7 +956,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import lhotse
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
|
||||
|
||||
class MultiDataset:
|
||||
def __init__(self, manifest_dir: str, cv_manifest_dir: str):
|
||||
"""
|
||||
Args:
|
||||
manifest_dir:
|
||||
It is expected to contain the following files:
|
||||
|
||||
- librispeech_cuts_train-all-shuf.jsonl.gz
|
||||
- gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz
|
||||
|
||||
cv_manifest_dir:
|
||||
It is expected to contain the following files:
|
||||
|
||||
- cv-en_cuts_train.jsonl.gz
|
||||
"""
|
||||
self.manifest_dir = Path(manifest_dir)
|
||||
self.cv_manifest_dir = Path(cv_manifest_dir)
|
||||
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset train cuts")
|
||||
|
||||
# LibriSpeech
|
||||
logging.info(f"Loading LibriSpeech in lazy mode")
|
||||
librispeech_cuts = load_manifest_lazy(
|
||||
self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
# GigaSpeech
|
||||
filenames = glob.glob(
|
||||
f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"
|
||||
)
|
||||
|
||||
pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
|
||||
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
||||
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
||||
|
||||
sorted_filenames = [f[1] for f in idx_filenames]
|
||||
|
||||
logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode")
|
||||
|
||||
gigaspeech_cuts = lhotse.combine(
|
||||
lhotse.load_manifest_lazy(p) for p in sorted_filenames
|
||||
)
|
||||
|
||||
# CommonVoice
|
||||
logging.info(f"Loading CommonVoice in lazy mode")
|
||||
commonvoice_cuts = load_manifest_lazy(
|
||||
self.cv_manifest_dir / f"cv-en_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts)
|
||||
@ -50,7 +50,6 @@ import copy
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
@ -66,7 +65,6 @@ from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from multidataset import MultiDataset
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
@ -90,6 +88,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
symlink_or_copy,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -341,7 +340,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -376,13 +375,6 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-multidataset",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use multidataset to train.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -578,9 +570,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -612,7 +601,8 @@ def save_checkpoint(
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
epoch_basename = f"epoch-{params.cur_epoch}.pt"
|
||||
filename = params.exp_dir / epoch_basename
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
@ -626,12 +616,14 @@ def save_checkpoint(
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
symlink_or_copy(
|
||||
exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
|
||||
)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
symlink_or_copy(
|
||||
exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
|
||||
)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
@ -811,13 +803,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -864,7 +850,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -877,7 +862,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
@ -1053,16 +1037,12 @@ def run(rank, world_size, args):
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
if params.use_multidataset:
|
||||
multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
|
||||
train_cuts = multidataset.train_cuts()
|
||||
if params.mini_libri:
|
||||
train_cuts = librispeech.train_clean_5_cuts()
|
||||
elif params.full_libri:
|
||||
train_cuts = librispeech.train_all_shuf_cuts()
|
||||
else:
|
||||
if params.mini_libri:
|
||||
train_cuts = librispeech.train_clean_5_cuts()
|
||||
elif params.full_libri:
|
||||
train_cuts = librispeech.train_all_shuf_cuts()
|
||||
else:
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
@ -1118,7 +1098,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.use_multidataset and not params.print_diagnostics:
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
|
||||
@ -346,7 +346,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -577,9 +577,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -830,13 +827,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -883,7 +874,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -896,7 +886,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -342,7 +342,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -570,9 +570,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -819,13 +816,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -872,7 +863,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -885,7 +875,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -90,7 +90,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
|
||||
@ -88,7 +88,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -65,7 +65,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
|
||||
@ -77,7 +77,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -114,7 +114,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
@ -355,7 +355,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -586,9 +586,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -807,13 +804,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -860,7 +851,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -873,7 +863,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -355,7 +355,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -587,9 +587,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -808,13 +805,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -861,7 +852,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -874,7 +864,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
0
egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
Normal file → Executable file
0
egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
Normal file → Executable file
@ -88,7 +88,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
|
||||
4
egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py
Normal file → Executable file
4
egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py
Normal file → Executable file
@ -78,7 +78,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -115,7 +115,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless7_streaming_multi/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
@ -366,7 +366,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -604,9 +604,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -921,7 +918,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -934,7 +930,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -348,7 +348,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -1127,7 +1127,16 @@ def run(rank, world_size, args):
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
|
||||
parameters_names = []
|
||||
parameters_names.append(
|
||||
[name_param_pair[0] for name_param_pair in model.named_parameters()]
|
||||
)
|
||||
optimizer = ScaledAdam(
|
||||
model.parameters(),
|
||||
lr=params.base_lr,
|
||||
clipping_scale=2.0,
|
||||
parameters_names=parameters_names,
|
||||
)
|
||||
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||
|
||||
|
||||
@ -627,14 +627,6 @@ def run(rank, world_size, args):
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
num_left = len(train_cuts)
|
||||
num_removed = num_in_total - num_left
|
||||
removed_percent = num_removed / num_in_total * 100
|
||||
|
||||
logging.info(f"Before removing short and long utterances: {num_in_total}")
|
||||
logging.info(f"After removing short and long utterances: {num_left}")
|
||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
|
||||
@ -654,20 +654,6 @@ def run(rank, world_size, args):
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
try:
|
||||
num_left = len(train_cuts)
|
||||
num_removed = num_in_total - num_left
|
||||
removed_percent = num_removed / num_in_total * 100
|
||||
|
||||
logging.info(f"Before removing short and long utterances: {num_in_total}")
|
||||
logging.info(f"After removing short and long utterances: {num_left}")
|
||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||
except TypeError as e:
|
||||
# You can ignore this error as previous versions of Lhotse work fine
|
||||
# for the above code. In recent versions of Lhotse, it uses
|
||||
# lazy filter, producing cutsets that don't have the __len__ method
|
||||
logging.info(str(e))
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
|
||||
@ -642,20 +642,6 @@ def run(rank, world_size, args):
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
try:
|
||||
num_left = len(train_cuts)
|
||||
num_removed = num_in_total - num_left
|
||||
removed_percent = num_removed / num_in_total * 100
|
||||
|
||||
logging.info(f"Before removing short and long utterances: {num_in_total}")
|
||||
logging.info(f"After removing short and long utterances: {num_left}")
|
||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||
except TypeError as e:
|
||||
# You can ignore this error as previous versions of Lhotse work fine
|
||||
# for the above code. In recent versions of Lhotse, it uses
|
||||
# lazy filter, producing cutsets that don't have the __len__ method
|
||||
logging.info(str(e))
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
|
||||
775
egs/librispeech/ASR/zipformer/export-onnx-streaming.py
Executable file
775
egs/librispeech/ASR/zipformer/export-onnx-streaming.py
Executable file
@ -0,0 +1,775 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx-streaming.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal True \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 64
|
||||
|
||||
The --chunk-size in training is "16,32,64,-1", so we select one of them
|
||||
(excluding -1) during streaming export. The same applies to `--left-context`,
|
||||
whose value is "64,128,256,-1".
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool, make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(
|
||||
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Zipformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder_proj = encoder_proj
|
||||
self.chunk_size = encoder.chunk_size[0]
|
||||
self.left_context_len = encoder.left_context_frames[0]
|
||||
self.pad_length = 7 + 2 * 3
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
states: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
N = x.size(0)
|
||||
T = self.chunk_size * 2 + self.pad_length
|
||||
x_lens = torch.tensor([T] * N, device=x.device)
|
||||
left_context_len = self.left_context_len
|
||||
|
||||
cached_embed_left_pad = states[-2]
|
||||
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
cached_left_pad=cached_embed_left_pad,
|
||||
)
|
||||
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
|
||||
# processed_mask is used to mask out initial states
|
||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||
x.size(0), left_context_len
|
||||
)
|
||||
processed_lens = states[-1] # (batch,)
|
||||
# (batch, left_context_size)
|
||||
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
|
||||
# Update processed lengths
|
||||
new_processed_lens = processed_lens + x_lens
|
||||
# (batch, left_context_size + chunk_size)
|
||||
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||
|
||||
x = x.permute(1, 0, 2)
|
||||
encoder_states = states[:-2]
|
||||
logging.info(f"len_encoder_states={len(encoder_states)}")
|
||||
(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
new_encoder_states,
|
||||
) = self.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=encoder_states,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2)
|
||||
encoder_out = self.encoder_proj(encoder_out)
|
||||
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||
|
||||
new_states = new_encoder_states + [
|
||||
new_cached_embed_left_pad,
|
||||
new_processed_lens,
|
||||
]
|
||||
|
||||
return encoder_out, new_states
|
||||
|
||||
def get_init_states(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
|
||||
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
|
||||
states[-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
states[-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
"""
|
||||
states = self.encoder.get_init_states(batch_size, device)
|
||||
|
||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||
states.append(embed_states)
|
||||
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
||||
states.append(processed_lens)
|
||||
|
||||
return states
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
encoder_model.encoder.__class__.forward = (
|
||||
encoder_model.encoder.__class__.streaming_forward
|
||||
)
|
||||
|
||||
decode_chunk_len = encoder_model.chunk_size * 2
|
||||
# The encoder_embed subsample features (T - 7) // 2
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
T = decode_chunk_len + encoder_model.pad_length
|
||||
|
||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||
init_state = encoder_model.get_init_states()
|
||||
num_encoders = len(encoder_model.encoder.encoder_dim)
|
||||
logging.info(f"num_encoders: {num_encoders}")
|
||||
logging.info(f"len(init_state): {len(init_state)}")
|
||||
|
||||
inputs = {}
|
||||
input_names = ["x"]
|
||||
|
||||
outputs = {}
|
||||
output_names = ["encoder_out"]
|
||||
|
||||
def build_inputs_outputs(tensors, i):
|
||||
assert len(tensors) == 6, len(tensors)
|
||||
|
||||
# (downsample_left, batch_size, key_dim)
|
||||
name = f"cached_key_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[0].shape}")
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
|
||||
name = f"cached_nonlin_attn_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[1].shape}")
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f"cached_val1_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[2].shape}")
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f"cached_val2_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[3].shape}")
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f"cached_conv1_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[4].shape}")
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f"cached_conv2_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[5].shape}")
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
|
||||
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
|
||||
ds = encoder_model.encoder.downsampling_factor
|
||||
left_context_len = encoder_model.left_context_len
|
||||
left_context_len = [left_context_len // k for k in ds]
|
||||
left_context_len = ",".join(map(str, left_context_len))
|
||||
query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
|
||||
value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
|
||||
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer2",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "streaming zipformer2",
|
||||
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||
"T": str(T), # 32+7+2*3=45
|
||||
"num_encoder_layers": num_encoder_layers,
|
||||
"encoder_dims": encoder_dims,
|
||||
"cnn_module_kernels": cnn_module_kernels,
|
||||
"left_context_len": left_context_len,
|
||||
"query_head_dims": query_head_dims,
|
||||
"value_head_dims": value_head_dims,
|
||||
"num_heads": num_heads,
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
for i in range(len(init_state[:-2]) // 6):
|
||||
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
|
||||
|
||||
# (batch_size, channels, left_pad, freq)
|
||||
embed_states = init_state[-2]
|
||||
name = "embed_states"
|
||||
logging.info(f"{name}.shape: {embed_states.shape}")
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (batch_size,)
|
||||
processed_lens = init_state[-1]
|
||||
name = "processed_lens"
|
||||
logging.info(f"{name}.shape: {processed_lens.shape}")
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
logging.info(inputs)
|
||||
logging.info(outputs)
|
||||
logging.info(input_names)
|
||||
logging.info(output_names)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, init_state),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes={
|
||||
"x": {0: "N"},
|
||||
"encoder_out": {0: "N"},
|
||||
**inputs,
|
||||
**outputs,
|
||||
},
|
||||
)
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_embed=model.encoder_embed,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=encoder_filename,
|
||||
model_output=encoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=decoder_filename,
|
||||
model_output=decoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=joiner_filename,
|
||||
model_output=joiner_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
624
egs/librispeech/ASR/zipformer/export-onnx.py
Executable file
624
egs/librispeech/ASR/zipformer/export-onnx.py
Executable file
@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
\
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal False \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1"
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||
use the exported ONNX models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool, make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(
|
||||
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Zipformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Please see the help information of Zipformer.forward
|
||||
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2)
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2)
|
||||
encoder_out = self.encoder_proj(encoder_out)
|
||||
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer2",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "non-streaming zipformer2",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_embed=model.encoder_embed,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=encoder_filename,
|
||||
model_output=encoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=decoder_filename,
|
||||
model_output=decoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=joiner_filename,
|
||||
model_output=joiner_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
@ -49,7 +49,7 @@ class Transducer(nn.Module):
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
|
||||
544
egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
Executable file
544
egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
Executable file
@ -0,0 +1,544 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script loads ONNX models exported by ./export-onnx-streaming.py
|
||||
and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx-streaming.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal True \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 64
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file with the exported ONNX models
|
||||
|
||||
./zipformer/onnx_pretrained-streaming.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
Note: Even though this script only supports decoding a single file,
|
||||
the exported ONNX models do support batch processing.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
def init_encoder_states(self, batch_size: int = 1):
|
||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
logging.info(f"encoder_meta={encoder_meta}")
|
||||
|
||||
model_type = encoder_meta["model_type"]
|
||||
assert model_type == "zipformer2", model_type
|
||||
|
||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||
T = int(encoder_meta["T"])
|
||||
|
||||
num_encoder_layers = encoder_meta["num_encoder_layers"]
|
||||
encoder_dims = encoder_meta["encoder_dims"]
|
||||
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
|
||||
left_context_len = encoder_meta["left_context_len"]
|
||||
query_head_dims = encoder_meta["query_head_dims"]
|
||||
value_head_dims = encoder_meta["value_head_dims"]
|
||||
num_heads = encoder_meta["num_heads"]
|
||||
|
||||
def to_int_list(s):
|
||||
return list(map(int, s.split(",")))
|
||||
|
||||
num_encoder_layers = to_int_list(num_encoder_layers)
|
||||
encoder_dims = to_int_list(encoder_dims)
|
||||
cnn_module_kernels = to_int_list(cnn_module_kernels)
|
||||
left_context_len = to_int_list(left_context_len)
|
||||
query_head_dims = to_int_list(query_head_dims)
|
||||
value_head_dims = to_int_list(value_head_dims)
|
||||
num_heads = to_int_list(num_heads)
|
||||
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"T: {T}")
|
||||
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||
logging.info(f"encoder_dims: {encoder_dims}")
|
||||
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
|
||||
logging.info(f"left_context_len: {left_context_len}")
|
||||
logging.info(f"query_head_dims: {query_head_dims}")
|
||||
logging.info(f"value_head_dims: {value_head_dims}")
|
||||
logging.info(f"num_heads: {num_heads}")
|
||||
|
||||
num_encoders = len(num_encoder_layers)
|
||||
|
||||
self.states = []
|
||||
for i in range(num_encoders):
|
||||
num_layers = num_encoder_layers[i]
|
||||
key_dim = query_head_dims[i] * num_heads[i]
|
||||
embed_dim = encoder_dims[i]
|
||||
nonlin_attn_head_dim = 3 * embed_dim // 4
|
||||
value_dim = value_head_dims[i] * num_heads[i]
|
||||
conv_left_pad = cnn_module_kernels[i] // 2
|
||||
|
||||
for layer in range(num_layers):
|
||||
cached_key = torch.zeros(
|
||||
left_context_len[i], batch_size, key_dim
|
||||
).numpy()
|
||||
cached_nonlin_attn = torch.zeros(
|
||||
1, batch_size, left_context_len[i], nonlin_attn_head_dim
|
||||
).numpy()
|
||||
cached_val1 = torch.zeros(
|
||||
left_context_len[i], batch_size, value_dim
|
||||
).numpy()
|
||||
cached_val2 = torch.zeros(
|
||||
left_context_len[i], batch_size, value_dim
|
||||
).numpy()
|
||||
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
|
||||
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
|
||||
self.states += [
|
||||
cached_key,
|
||||
cached_nonlin_attn,
|
||||
cached_val1,
|
||||
cached_val2,
|
||||
cached_conv1,
|
||||
cached_conv2,
|
||||
]
|
||||
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
|
||||
self.states.append(embed_states)
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
|
||||
self.states.append(processed_lens)
|
||||
|
||||
self.num_encoders = num_encoders
|
||||
|
||||
self.segment = T
|
||||
self.offset = decode_chunk_len
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def _build_encoder_input_output(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||
encoder_input = {"x": x.numpy()}
|
||||
encoder_output = ["encoder_out"]
|
||||
|
||||
def build_inputs_outputs(tensors, i):
|
||||
assert len(tensors) == 6, len(tensors)
|
||||
|
||||
# (downsample_left, batch_size, key_dim)
|
||||
name = f"cached_key_{i}"
|
||||
encoder_input[name] = tensors[0]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
|
||||
name = f"cached_nonlin_attn_{i}"
|
||||
encoder_input[name] = tensors[1]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f"cached_val1_{i}"
|
||||
encoder_input[name] = tensors[2]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f"cached_val2_{i}"
|
||||
encoder_input[name] = tensors[3]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f"cached_conv1_{i}"
|
||||
encoder_input[name] = tensors[4]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f"cached_conv2_{i}"
|
||||
encoder_input[name] = tensors[5]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
for i in range(len(self.states[:-2]) // 6):
|
||||
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
|
||||
|
||||
# (batch_size, channels, left_pad, freq)
|
||||
name = "embed_states"
|
||||
embed_states = self.states[-2]
|
||||
encoder_input[name] = embed_states
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size,)
|
||||
name = "processed_lens"
|
||||
processed_lens = self.states[-1]
|
||||
encoder_input[name] = processed_lens
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
return encoder_input, encoder_output
|
||||
|
||||
def _update_states(self, states: List[np.ndarray]):
|
||||
self.states = states
|
||||
|
||||
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||
T' is usually equal to ((T-7)//2+1)//2
|
||||
"""
|
||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||
|
||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||
|
||||
self._update_states(out[1:])
|
||||
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
context_size: int,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
) -> List[int]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (1, T, joiner_dim)
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
decoder_out:
|
||||
Optional. Decoder output of the previous chunk.
|
||||
hyp:
|
||||
Decoding results for previous chunks.
|
||||
Returns:
|
||||
Return the decoded results so far.
|
||||
"""
|
||||
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
else:
|
||||
assert hyp is not None, hyp
|
||||
|
||||
encoder_out = encoder_out.squeeze(0)
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t : t + 1]
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
waves = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||
wave_samples = torch.cat([waves, tail_padding])
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.segment
|
||||
offset = model.offset
|
||||
|
||||
context_size = model.context_size
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
chunk = int(1 * sample_rate) # 1 second
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
frames = frames.unsqueeze(0)
|
||||
encoder_out = model.run_encoder(frames)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model,
|
||||
encoder_out,
|
||||
context_size,
|
||||
decoder_out,
|
||||
hyp,
|
||||
)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(args.sound_file)
|
||||
logging.info(text)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/librispeech/ASR/zipformer/onnx_pretrained.py
Symbolic link
1
egs/librispeech/ASR/zipformer/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/onnx_pretrained.py
|
||||
@ -26,6 +26,18 @@ import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
|
||||
# 14 is not supported. Please feel free to request support or submit
|
||||
# a pull request on PyTorch GitHub.
|
||||
#
|
||||
# The following function is to solve the above error when exporting
|
||||
# models to ONNX via torch.jit.trace()
|
||||
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
return torch.logaddexp(x, y)
|
||||
else:
|
||||
return (x.exp() + y.exp()).log()
|
||||
|
||||
class PiecewiseLinear(object):
|
||||
"""
|
||||
Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
|
||||
@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
|
||||
|
||||
def __float__(self):
|
||||
batch_count = self.batch_count
|
||||
if batch_count is None or not self.training or torch.jit.is_scripting():
|
||||
if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return float(self.default)
|
||||
else:
|
||||
ans = self.schedule(self.batch_count)
|
||||
@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
def softmax(x: Tensor, dim: int):
|
||||
if not x.requires_grad or torch.jit.is_scripting():
|
||||
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x.softmax(dim=dim)
|
||||
|
||||
return SoftmaxFunction.apply(x, dim)
|
||||
@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
|
||||
self.alpha = alpha
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return x
|
||||
return scale_grad(x, self.alpha)
|
||||
|
||||
@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
|
||||
|
||||
|
||||
def _no_op(x: Tensor) -> Tensor:
|
||||
if (torch.jit.is_scripting()):
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
# a no-op function that will have a node in the autograd graph,
|
||||
@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swoosh-L activation.
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||
if not x.requires_grad:
|
||||
return k2.swoosh_l_forward(x)
|
||||
else:
|
||||
@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swoosh-R activation.
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||
return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||
if not x.requires_grad:
|
||||
return k2.swoosh_r_forward(x)
|
||||
else:
|
||||
|
||||
@ -27,6 +27,7 @@ from typing import List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
|
||||
from zipformer import CompactRelPositionalEncoding
|
||||
|
||||
|
||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
is_pnnx: bool = False,
|
||||
is_onnx: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
is_pnnx:
|
||||
True if we are going to export the model for PNNX.
|
||||
is_onnx:
|
||||
True if we are going to export the model for ONNX.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
|
||||
d[name] = nn.Identity()
|
||||
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
|
||||
# We want to recreate the positional encoding vector when
|
||||
# the input changes, so we have to use torch.jit.script()
|
||||
# to replace torch.jit.trace()
|
||||
d[name] = torch.jit.script(m)
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
|
||||
@ -81,7 +81,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
|
||||
@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return self.forward_internal(x)
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
|
||||
@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
x = self.out_norm(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
x_lens = (x_lens - 7) // 2
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
|
||||
@ -62,20 +62,20 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from zipformer import Zipformer2
|
||||
from scaling import ScheduledFloat
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from subsampling import Conv2dSubsampling
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
@ -84,40 +84,38 @@ from icefall.checkpoint import (
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
get_parameter_groups_with_lrs
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
def get_adjusted_batch_count(
|
||||
params: AttributeDict) -> float:
|
||||
def get_adjusted_batch_count(params: AttributeDict) -> float:
|
||||
# returns the number of batches we would have used so far if we had used the reference
|
||||
# duration. This is for purposes of set_batch_count().
|
||||
return (params.batch_idx_train * (params.max_duration * params.world_size) /
|
||||
params.ref_duration)
|
||||
return (
|
||||
params.batch_idx_train
|
||||
* (params.max_duration * params.world_size)
|
||||
/ params.ref_duration
|
||||
)
|
||||
|
||||
|
||||
def set_batch_count(
|
||||
model: Union[nn.Module, DDP], batch_count: float
|
||||
) -> None:
|
||||
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||
if isinstance(model, DDP):
|
||||
# get underlying nn.Module
|
||||
model = model.module
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'batch_count'):
|
||||
if hasattr(module, "batch_count"):
|
||||
module.batch_count = batch_count
|
||||
if hasattr(module, 'name'):
|
||||
if hasattr(module, "name"):
|
||||
module.name = name
|
||||
|
||||
|
||||
@ -154,35 +152,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
"--encoder-dim",
|
||||
type=str,
|
||||
default="192,256,384,512,384,256",
|
||||
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
|
||||
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--query-head-dim",
|
||||
type=str,
|
||||
default="32",
|
||||
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
|
||||
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--value-head-dim",
|
||||
type=str,
|
||||
default="12",
|
||||
help="Value dimension per head in encoder stacks: a single int or comma-separated list."
|
||||
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pos-head-dim",
|
||||
type=str,
|
||||
default="4",
|
||||
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
|
||||
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pos-dim",
|
||||
type=int,
|
||||
default="48",
|
||||
help="Positional-encoding embedding dimension"
|
||||
help="Positional-encoding embedding dimension",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -190,7 +188,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
type=str,
|
||||
default="192,192,256,256,256,192",
|
||||
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
||||
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
|
||||
"A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -230,7 +228,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
type=str,
|
||||
default="16,32,64,-1",
|
||||
help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
|
||||
" Must be just -1 if --causal=False"
|
||||
" Must be just -1 if --causal=False",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -239,7 +237,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
default="64,128,256,-1",
|
||||
help="Maximum left-contexts for causal training, measured in frames which will "
|
||||
"be converted to a number of chunks. If splitting into chunks, "
|
||||
"chunk left-context frames will be chosen randomly from this list; else not relevant."
|
||||
"chunk left-context frames will be chosen randomly from this list; else not relevant.",
|
||||
)
|
||||
|
||||
|
||||
@ -313,10 +311,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--base-lr",
|
||||
type=float,
|
||||
default=0.045,
|
||||
help="The base learning rate."
|
||||
"--base-lr", type=float, default=0.045, help="The base learning rate."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -340,15 +335,14 @@ def get_parser():
|
||||
type=float,
|
||||
default=600,
|
||||
help="Reference batch duration for purposes of adjusting batch counts for setting various "
|
||||
"schedules inside the model"
|
||||
"schedules inside the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -371,8 +365,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -415,7 +408,7 @@ def get_parser():
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -522,7 +515,7 @@ def get_params() -> AttributeDict:
|
||||
|
||||
|
||||
def _to_int_tuple(s: str):
|
||||
return tuple(map(int, s.split(',')))
|
||||
return tuple(map(int, s.split(",")))
|
||||
|
||||
|
||||
def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
||||
@ -537,7 +530,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
||||
encoder_embed = Conv2dSubsampling(
|
||||
in_channels=params.feature_dim,
|
||||
out_channels=_to_int_tuple(params.encoder_dim)[0],
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
)
|
||||
return encoder_embed
|
||||
|
||||
@ -596,7 +589,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
encoder_dim=int(max(params.encoder_dim.split(','))),
|
||||
encoder_dim=int(max(params.encoder_dim.split(","))),
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
@ -667,9 +660,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -748,11 +738,7 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -782,27 +768,24 @@ def compute_loss(
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
# to params.simple_loss scale by warm_step.
|
||||
simple_loss_scale = (
|
||||
s if batch_idx_train >= warm_step
|
||||
s
|
||||
if batch_idx_train >= warm_step
|
||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||
)
|
||||
pruned_loss_scale = (
|
||||
1.0 if batch_idx_train >= warm_step
|
||||
1.0
|
||||
if batch_idx_train >= warm_step
|
||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||
)
|
||||
|
||||
loss = (
|
||||
simple_loss_scale * simple_loss +
|
||||
pruned_loss_scale * pruned_loss
|
||||
)
|
||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -895,27 +878,24 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
saved_bad_model = False
|
||||
|
||||
def save_bad_model(suffix: str = ""):
|
||||
save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=0)
|
||||
save_checkpoint_impl(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
@ -963,7 +943,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -976,7 +955,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
@ -998,7 +976,9 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
@ -1008,8 +988,8 @@ def train_one_epoch(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"batch {batch_idx}, loss[{loss_info}], "
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
f"lr: {cur_lr:.2e}, " +
|
||||
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
f"lr: {cur_lr:.2e}, "
|
||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
@ -1020,9 +1000,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
if params.use_fp16:
|
||||
tb_writer.add_scalar(
|
||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||
@ -1039,7 +1017,9 @@ def train_one_epoch(
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
@ -1113,13 +1093,11 @@ def run(rank, world_size, args):
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank],
|
||||
find_unused_parameters=True)
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer = ScaledAdam(
|
||||
get_parameter_groups_with_lrs(
|
||||
model, lr=params.base_lr, include_names=True),
|
||||
lr=params.base_lr, # should have no effect
|
||||
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
|
||||
lr=params.base_lr, # should have no effect
|
||||
clipping_scale=2.0,
|
||||
)
|
||||
|
||||
@ -1139,7 +1117,7 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
2**22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
@ -1163,9 +1141,9 @@ def run(rank, world_size, args):
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 1.0 or c.duration > 20.0:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
)
|
||||
# logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
# )
|
||||
return False
|
||||
|
||||
# In pruned RNN-T, we require that T >= S
|
||||
@ -1216,8 +1194,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16,
|
||||
init_scale=1.0)
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1338,7 +1315,9 @@ def scan_pessimistic_batches_for_oom(
|
||||
)
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -133,6 +133,7 @@ class Zipformer2(EncoderInterface):
|
||||
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
|
||||
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
|
||||
num_encoder_layers = _to_tuple(num_encoder_layers)
|
||||
self.num_encoder_layers = num_encoder_layers
|
||||
self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
|
||||
self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
|
||||
pos_head_dim = _to_tuple(pos_head_dim)
|
||||
@ -258,7 +259,7 @@ class Zipformer2(EncoderInterface):
|
||||
if not self.causal:
|
||||
return -1, -1
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
assert len(self.chunk_size) == 1, self.chunk_size
|
||||
chunk_size = self.chunk_size[0]
|
||||
else:
|
||||
@ -267,7 +268,7 @@ class Zipformer2(EncoderInterface):
|
||||
if chunk_size == -1:
|
||||
left_context_chunks = -1
|
||||
else:
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
assert len(self.left_context_frames) == 1, self.left_context_frames
|
||||
left_context_frames = self.left_context_frames[0]
|
||||
else:
|
||||
@ -301,14 +302,14 @@ class Zipformer2(EncoderInterface):
|
||||
of frames in `embeddings` before padding.
|
||||
"""
|
||||
outputs = []
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
feature_masks = [1.0] * len(self.encoder_dim)
|
||||
else:
|
||||
feature_masks = self.get_feature_masks(x)
|
||||
|
||||
chunk_size, left_context_chunks = self.get_chunk_info()
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
# Not support exporting a model for simulating streaming decoding
|
||||
attn_mask = None
|
||||
else:
|
||||
@ -334,7 +335,7 @@ class Zipformer2(EncoderInterface):
|
||||
x = self.downsample_output(x)
|
||||
# class Downsample has this rounding behavior..
|
||||
assert self.output_downsampling_factor == 2
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
lengths = (x_lens + 1) // 2
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
@ -372,7 +373,7 @@ class Zipformer2(EncoderInterface):
|
||||
# t is frame index, shape (seq_len,)
|
||||
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
|
||||
# c is chunk index for each frame, shape (seq_len,)
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
c = t // chunk_size
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
@ -650,7 +651,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
|
||||
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
|
||||
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return None
|
||||
batch_size = x.shape[1]
|
||||
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
|
||||
@ -695,7 +696,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
src_orig = src
|
||||
|
||||
# dropout rate for non-feedforward submodules
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
attention_skip_rate = 0.0
|
||||
else:
|
||||
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
|
||||
@ -713,7 +714,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
|
||||
|
||||
selected_attn_weights = attn_weights[0:1]
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
pass
|
||||
elif not self.training and random.random() < float(self.const_attention_rate):
|
||||
# Make attention weights constant. The intention is to
|
||||
@ -732,7 +733,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
|
||||
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
conv_skip_rate = 0.0
|
||||
else:
|
||||
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
|
||||
@ -740,7 +741,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask),
|
||||
conv_skip_rate)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
ff2_skip_rate = 0.0
|
||||
else:
|
||||
ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
|
||||
@ -754,7 +755,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
|
||||
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
conv_skip_rate = 0.0
|
||||
else:
|
||||
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
|
||||
@ -762,7 +763,7 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask),
|
||||
conv_skip_rate)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
ff3_skip_rate = 0.0
|
||||
else:
|
||||
ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
|
||||
@ -968,7 +969,7 @@ class Zipformer2Encoder(nn.Module):
|
||||
pos_emb = self.encoder_pos(src)
|
||||
output = src
|
||||
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
output = output * feature_mask
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
@ -980,7 +981,7 @@ class Zipformer2Encoder(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
output = output * feature_mask
|
||||
|
||||
return output
|
||||
@ -1073,7 +1074,7 @@ class BypassModule(nn.Module):
|
||||
# or (batch_size, num_channels,). This is actually the
|
||||
# scale on the non-residual term, so 0 correponds to bypassing
|
||||
# this module.
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return self.bypass_scale
|
||||
else:
|
||||
ans = limit_param_value(self.bypass_scale,
|
||||
@ -1229,12 +1230,11 @@ class SimpleDownsample(torch.nn.Module):
|
||||
d_seq_len = (seq_len + ds - 1) // ds
|
||||
|
||||
# Pad to an exact multiple of self.downsample
|
||||
if seq_len != d_seq_len * ds:
|
||||
# right-pad src, repeating the last element.
|
||||
pad = d_seq_len * ds - seq_len
|
||||
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||
src = torch.cat((src, src_extra), dim=0)
|
||||
assert src.shape[0] == d_seq_len * ds
|
||||
# right-pad src, repeating the last element.
|
||||
pad = d_seq_len * ds - seq_len
|
||||
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||
src = torch.cat((src, src_extra), dim=0)
|
||||
assert src.shape[0] == d_seq_len * ds
|
||||
|
||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||
|
||||
@ -1322,11 +1322,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(0) >= T * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
|
||||
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
||||
@ -1524,7 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
attn_scores = torch.matmul(q, k)
|
||||
|
||||
use_pos_scores = False
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
# We can't put random.random() in the same line
|
||||
use_pos_scores = True
|
||||
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
||||
@ -1542,16 +1538,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||
# not, but let this code define which way round it is supposed to be.
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
if torch.jit.is_tracing():
|
||||
(num_heads, batch_size, time1, n) = pos_scores.shape
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(seq_len)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
pos_scores = pos_scores.reshape(-1, n)
|
||||
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
|
||||
else:
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
|
||||
attn_scores = attn_scores + pos_scores
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
pass
|
||||
elif self.training and random.random() < 0.1:
|
||||
# This is a harder way of limiting the attention scores to not be
|
||||
@ -1594,7 +1600,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# half-precision output for backprop purposes.
|
||||
attn_weights = softmax(attn_scores, dim=-1)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
pass
|
||||
elif random.random() < 0.001 and not self.training:
|
||||
self._print_attn_entropy(attn_weights)
|
||||
@ -1672,15 +1678,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
||||
# [where seq_len2 represents relative position.]
|
||||
pos_scores = torch.matmul(p, pos_emb)
|
||||
|
||||
if torch.jit.is_tracing():
|
||||
(num_heads, batch_size, time1, n) = pos_scores.shape
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(k_len)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
pos_scores = pos_scores.reshape(-1, n)
|
||||
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
|
||||
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||
# not, but let this code define which way round it is supposed to be.
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
else:
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
|
||||
attn_scores = attn_scores + pos_scores
|
||||
|
||||
@ -2136,7 +2153,7 @@ class ConvolutionModule(nn.Module):
|
||||
if src_key_padding_mask is not None:
|
||||
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
||||
|
||||
if not torch.jit.is_scripting() and chunk_size >= 0:
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
|
||||
# Not support exporting a model for simulated streaming decoding
|
||||
assert self.causal, "Must initialize model with causal=True if you use chunk_size"
|
||||
x = self.depthwise_conv(x, chunk_size=chunk_size)
|
||||
|
||||
@ -503,9 +503,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -741,13 +738,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -797,7 +788,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -810,7 +800,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -506,9 +506,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -748,15 +745,9 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
|
||||
if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -805,7 +796,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -818,7 +808,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
1
egs/must_c/ST/local/compute_fbank_musan.py
Symbolic link
1
egs/must_c/ST/local/compute_fbank_musan.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compute_fbank_musan.py
|
||||
155
egs/must_c/ST/local/compute_fbank_must_c.py
Executable file
155
egs/must_c/ST/local/compute_fbank_must_c.py
Executable file
@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This file computes fbank features of the MuST-C dataset.
|
||||
It looks for manifests in the directory "in_dir" and write
|
||||
generated features to "out_dir".
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
FeatureSet,
|
||||
LilcomChunkyWriter,
|
||||
load_manifest,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--in-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Input manifest directory",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Output directory where generated fbank features are saved.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tgt-lang",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Target language, e.g., zh, de, fr.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-jobs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of jobs for computing features",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--perturb-speed",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to enable speed perturb with factors 0.9 and 1.1 on
|
||||
the train subset. False (by default) to disable speed perturb.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_fbank_must_c(
|
||||
in_dir: Path,
|
||||
out_dir: Path,
|
||||
tgt_lang: str,
|
||||
num_jobs: int,
|
||||
perturb_speed: bool,
|
||||
):
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||
|
||||
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
|
||||
|
||||
prefix = "must_c"
|
||||
suffix = "jsonl.gz"
|
||||
for p in parts:
|
||||
logging.info(f"Processing {p}")
|
||||
|
||||
cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
|
||||
if perturb_speed and p == "train":
|
||||
cuts_path += "_sp"
|
||||
|
||||
cuts_path += ".jsonl.gz"
|
||||
|
||||
if Path(cuts_path).is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
continue
|
||||
|
||||
recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.jsonl.gz"
|
||||
supervisions_filename = (
|
||||
in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.jsonl.gz"
|
||||
)
|
||||
assert recordings_filename.is_file(), recordings_filename
|
||||
assert supervisions_filename.is_file(), supervisions_filename
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=load_manifest(recordings_filename),
|
||||
supervisions=load_manifest(supervisions_filename),
|
||||
)
|
||||
if perturb_speed and p == "train":
|
||||
logging.info("Speed perturbing for the train dataset")
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp"
|
||||
else:
|
||||
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
|
||||
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=storage_path,
|
||||
num_jobs=num_jobs,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
logging.info("About to split cuts into smaller chunks.")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
logging.info(f"Saved to {cuts_path}")
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
logging.info(vars(args))
|
||||
assert args.in_dir.is_dir(), args.in_dir
|
||||
|
||||
compute_fbank_must_c(
|
||||
in_dir=args.in_dir,
|
||||
out_dir=args.out_dir,
|
||||
tgt_lang=args.tgt_lang,
|
||||
num_jobs=args.num_jobs,
|
||||
perturb_speed=args.perturb_speed,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
34
egs/must_c/ST/local/get_text.py
Executable file
34
egs/must_c/ST/local/get_text.py
Executable file
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
"""
|
||||
This file prints the text field of supervisions from cutset to the console
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from lhotse import load_manifest_lazy
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"manifest",
|
||||
type=Path,
|
||||
help="Input manifest",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert args.manifest.is_file(), args.manifest
|
||||
|
||||
cutset = load_manifest_lazy(args.manifest)
|
||||
for c in cutset:
|
||||
for sup in c.supervisions:
|
||||
print(sup.text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
48
egs/must_c/ST/local/get_words.py
Executable file
48
egs/must_c/ST/local/get_words.py
Executable file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
"""
|
||||
This file generates words.txt from the given transcript file.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"transcript",
|
||||
type=Path,
|
||||
help="Input transcript file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert args.transcript.is_file(), args.transcript
|
||||
|
||||
word_set = set()
|
||||
with open(args.transcript) as f:
|
||||
for line in f:
|
||||
words = line.strip().split()
|
||||
for w in words:
|
||||
word_set.add(w)
|
||||
|
||||
# Note: reserved* should be kept in sync with ./local/prepare_lang_bpe.py
|
||||
reserved1 = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
|
||||
reserved2 = ["#0", "<s>", "</s>"]
|
||||
|
||||
for w in reserved1 + reserved2:
|
||||
assert w not in word_set, w
|
||||
|
||||
words = sorted(list(word_set))
|
||||
words = reserved1 + words + reserved2
|
||||
|
||||
for i, w in enumerate(words):
|
||||
print(w, i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
169
egs/must_c/ST/local/normalize_punctuation.py
Normal file
169
egs/must_c/ST/local/normalize_punctuation.py
Normal file
@ -0,0 +1,169 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
import re
|
||||
|
||||
|
||||
def normalize_punctuation(s: str, lang: str) -> str:
|
||||
"""
|
||||
This function implements
|
||||
https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/normalize-punctuation.perl
|
||||
|
||||
Args:
|
||||
s:
|
||||
A string to be normalized.
|
||||
lang:
|
||||
The language to which `s` belongs
|
||||
Returns:
|
||||
Return a normalized string.
|
||||
"""
|
||||
# s/\r//g;
|
||||
s = re.sub("\r", "", s)
|
||||
|
||||
# remove extra spaces
|
||||
# s/\(/ \(/g;
|
||||
s = re.sub("\(", " (", s) # add a space before (
|
||||
|
||||
# s/\)/\) /g; s/ +/ /g;
|
||||
s = re.sub("\)", ") ", s) # add a space after )
|
||||
s = re.sub(" +", " ", s) # convert multiple spaces to one
|
||||
|
||||
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
|
||||
s = re.sub("\) ([\.\!\:\?\;\,])", r")\1", s)
|
||||
|
||||
# s/\( /\(/g;
|
||||
s = re.sub("\( ", "(", s) # remove space after (
|
||||
|
||||
# s/ \)/\)/g;
|
||||
s = re.sub(" \)", ")", s) # remove space before )
|
||||
|
||||
# s/(\d) \%/$1\%/g;
|
||||
s = re.sub("(\d) \%", r"\1%", s) # remove space between a digit and %
|
||||
|
||||
# s/ :/:/g;
|
||||
s = re.sub(" :", ":", s) # remove space before :
|
||||
|
||||
# s/ ;/;/g;
|
||||
s = re.sub(" ;", ";", s) # remove space before ;
|
||||
|
||||
# normalize unicode punctuation
|
||||
# s/\`/\'/g;
|
||||
s = re.sub("`", "'", s) # replace ` with '
|
||||
|
||||
# s/\'\'/ \" /g;
|
||||
s = re.sub("''", '"', s) # replace '' with "
|
||||
|
||||
# s/„/\"/g;
|
||||
s = re.sub("„", '"', s) # replace „ with "
|
||||
|
||||
# s/“/\"/g;
|
||||
s = re.sub("“", '"', s) # replace “ with "
|
||||
|
||||
# s/”/\"/g;
|
||||
s = re.sub("”", '"', s) # replace ” with "
|
||||
|
||||
# s/–/-/g;
|
||||
s = re.sub("–", "-", s) # replace – with -
|
||||
|
||||
# s/—/ - /g; s/ +/ /g;
|
||||
s = re.sub("—", " - ", s)
|
||||
s = re.sub(" +", " ", s) # convert multiple spaces to one
|
||||
|
||||
# s/´/\'/g;
|
||||
s = re.sub("´", "'", s)
|
||||
|
||||
# s/([a-z])‘([a-z])/$1\'$2/gi;
|
||||
s = re.sub("([a-z])‘([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
|
||||
|
||||
# s/([a-z])’([a-z])/$1\'$2/gi;
|
||||
s = re.sub("([a-z])’([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
|
||||
|
||||
# s/‘/\'/g;
|
||||
s = re.sub("‘", "'", s)
|
||||
|
||||
# s/‚/\'/g;
|
||||
s = re.sub("‚", "'", s)
|
||||
|
||||
# s/’/\"/g;
|
||||
s = re.sub("’", '"', s)
|
||||
|
||||
# s/''/\"/g;
|
||||
s = re.sub("''", '"', s)
|
||||
|
||||
# s/´´/\"/g;
|
||||
s = re.sub("´´", '"', s)
|
||||
|
||||
# s/…/.../g;
|
||||
s = re.sub("…", "...", s)
|
||||
|
||||
# French quotes
|
||||
|
||||
# s/ « / \"/g;
|
||||
s = re.sub(" « ", ' "', s)
|
||||
|
||||
# s/« /\"/g;
|
||||
s = re.sub("« ", '"', s)
|
||||
|
||||
# s/«/\"/g;
|
||||
s = re.sub("«", '"', s)
|
||||
|
||||
# s/ » /\" /g;
|
||||
s = re.sub(" » ", '" ', s)
|
||||
|
||||
# s/ »/\"/g;
|
||||
s = re.sub(" »", '"', s)
|
||||
|
||||
# s/»/\"/g;
|
||||
s = re.sub("»", '"', s)
|
||||
|
||||
# handle pseudo-spaces
|
||||
|
||||
# s/ \%/\%/g;
|
||||
s = re.sub(" %", r"%", s)
|
||||
|
||||
# s/nº /nº /g;
|
||||
s = re.sub("nº ", "nº ", s)
|
||||
|
||||
# s/ :/:/g;
|
||||
s = re.sub(" :", ":", s)
|
||||
|
||||
# s/ ºC/ ºC/g;
|
||||
s = re.sub(" ºC", " ºC", s)
|
||||
|
||||
# s/ cm/ cm/g;
|
||||
s = re.sub(" cm", " cm", s)
|
||||
|
||||
# s/ \?/\?/g;
|
||||
s = re.sub(" \?", "\?", s)
|
||||
|
||||
# s/ \!/\!/g;
|
||||
s = re.sub(" \!", "\!", s)
|
||||
|
||||
# s/ ;/;/g;
|
||||
s = re.sub(" ;", ";", s)
|
||||
|
||||
# s/, /, /g; s/ +/ /g;
|
||||
s = re.sub(", ", ", ", s)
|
||||
s = re.sub(" +", " ", s)
|
||||
|
||||
if lang == "en":
|
||||
# English "quotation," followed by comma, style
|
||||
# s/\"([,\.]+)/$1\"/g;
|
||||
s = re.sub('"([,\.]+)', r'\1"', s)
|
||||
elif lang in ("cs", "cz"):
|
||||
# Czech is confused
|
||||
pass
|
||||
else:
|
||||
# German/Spanish/French "quotation", followed by comma, style
|
||||
# s/,\"/\",/g;
|
||||
s = re.sub(',"', '",', s)
|
||||
|
||||
# s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence
|
||||
s = re.sub('(\.+)"(\s*[^<])', r'"\1\2', s)
|
||||
|
||||
if lang in ("de", "es", "cz", "cs", "fr"):
|
||||
# s/(\d) (\d)/$1,$2/g;
|
||||
s = re.sub("(\d) (\d)", r"\1,\2", s)
|
||||
else:
|
||||
# s/(\d) (\d)/$1.$2/g;
|
||||
s = re.sub("(\d) (\d)", r"\1.\2", s)
|
||||
|
||||
return s
|
||||
1
egs/must_c/ST/local/prepare_lang.py
Symbolic link
1
egs/must_c/ST/local/prepare_lang.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang.py
|
||||
1
egs/must_c/ST/local/prepare_lang_bpe.py
Symbolic link
1
egs/must_c/ST/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang_bpe.py
|
||||
96
egs/must_c/ST/local/preprocess_must_c.py
Executable file
96
egs/must_c/ST/local/preprocess_must_c.py
Executable file
@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
"""
|
||||
This script normalizes transcripts from supervisions.
|
||||
|
||||
Usage:
|
||||
./local/preprocess_must_c.py \
|
||||
--manifest-dir ./data/manifests/v1.0/ \
|
||||
--tgt-lang de
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from normalize_punctuation import normalize_punctuation
|
||||
from remove_non_native_characters import remove_non_native_characters
|
||||
from remove_punctuation import remove_punctuation
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Manifest directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tgt-lang",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Target language, e.g., zh, de, fr.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
|
||||
normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang)
|
||||
remove_non_native_characters_lang = partial(
|
||||
remove_non_native_characters, lang=tgt_lang
|
||||
)
|
||||
|
||||
prefix = "must_c"
|
||||
suffix = "jsonl.gz"
|
||||
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
|
||||
for p in parts:
|
||||
logging.info(f"Processing {p}")
|
||||
name = f"en-{tgt_lang}_{p}"
|
||||
|
||||
# norm: normalization
|
||||
# rm: remove punctuation
|
||||
dst_name = manifest_dir / f"must_c_supervisions_{name}_norm_rm.jsonl.gz"
|
||||
if dst_name.is_file():
|
||||
logging.info(f"{dst_name} exists - skipping")
|
||||
continue
|
||||
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=name,
|
||||
output_dir=manifest_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
types=("supervisions",),
|
||||
)
|
||||
if name not in manifests:
|
||||
raise RuntimeError(f"Processing {p} failed.")
|
||||
|
||||
supervisions = manifests[name]["supervisions"]
|
||||
supervisions = supervisions.transform_text(normalize_punctuation_lang)
|
||||
supervisions = supervisions.transform_text(remove_punctuation)
|
||||
supervisions = supervisions.transform_text(lambda x: x.lower())
|
||||
supervisions = supervisions.transform_text(remove_non_native_characters_lang)
|
||||
supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x))
|
||||
|
||||
supervisions.to_file(dst_name)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
logging.info(vars(args))
|
||||
assert args.manifest_dir.is_dir(), args.manifest_dir
|
||||
|
||||
preprocess_must_c(
|
||||
manifest_dir=args.manifest_dir,
|
||||
tgt_lang=args.tgt_lang,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
21
egs/must_c/ST/local/remove_non_native_characters.py
Executable file
21
egs/must_c/ST/local/remove_non_native_characters.py
Executable file
@ -0,0 +1,21 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def remove_non_native_characters(s: str, lang: str):
|
||||
if lang == "de":
|
||||
# ä -> ae
|
||||
# ö -> oe
|
||||
# ü -> ue
|
||||
# ß -> ss
|
||||
|
||||
s = re.sub("ä", "ae", s)
|
||||
s = re.sub("ö", "oe", s)
|
||||
s = re.sub("ü", "ue", s)
|
||||
s = re.sub("ß", "ss", s)
|
||||
# keep only a-z and spaces
|
||||
# note: ' is removed
|
||||
s = re.sub(r"[^a-z\s]", "", s)
|
||||
|
||||
return s
|
||||
41
egs/must_c/ST/local/remove_punctuation.py
Normal file
41
egs/must_c/ST/local/remove_punctuation.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
import re
|
||||
import string
|
||||
|
||||
|
||||
def remove_punctuation(s: str) -> str:
|
||||
"""
|
||||
It implements https://github.com/espnet/espnet/blob/master/utils/remove_punctuation.pl
|
||||
"""
|
||||
|
||||
# Remove punctuation except apostrophe
|
||||
# s/<space>/spacemark/g; # for scoring
|
||||
s = re.sub("<space>", "spacemark", s)
|
||||
|
||||
# s/'/apostrophe/g;
|
||||
s = re.sub("'", "apostrophe", s)
|
||||
|
||||
# s/[[:punct:]]//g;
|
||||
s = s.translate(str.maketrans("", "", string.punctuation))
|
||||
# string punctuation returns the following string
|
||||
# !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
|
||||
# See
|
||||
# https://stackoverflow.com/questions/265960/best-way-to-strip-punctuation-from-a-string
|
||||
|
||||
# s/apostrophe/'/g;
|
||||
s = re.sub("apostrophe", "'", s)
|
||||
|
||||
# s/spacemark/<space>/g; # for scoring
|
||||
s = re.sub("spacemark", "<space>", s)
|
||||
|
||||
# remove whitespace
|
||||
# s/\s+/ /g;
|
||||
s = re.sub("\s+", " ", s)
|
||||
|
||||
# s/^\s+//;
|
||||
s = re.sub("^\s+", "", s)
|
||||
|
||||
# s/\s+$//;
|
||||
s = re.sub("\s+$", "", s)
|
||||
|
||||
return s
|
||||
197
egs/must_c/ST/local/test_normalize_punctuation.py
Executable file
197
egs/must_c/ST/local/test_normalize_punctuation.py
Executable file
@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
from normalize_punctuation import normalize_punctuation
|
||||
|
||||
|
||||
def test_normalize_punctuation():
|
||||
# s/\r//g;
|
||||
s = "a\r\nb\r\n"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert "\r" not in n
|
||||
assert len(s) - 2 == len(n), (len(s), len(n))
|
||||
|
||||
# s/\(/ \(/g;
|
||||
s = "(ab (c"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == " (ab (c", n
|
||||
|
||||
# s/\)/\) /g;
|
||||
s = "a)b c)"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a) b c) "
|
||||
|
||||
# s/ +/ /g;
|
||||
s = " a b c d "
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == " a b c d "
|
||||
|
||||
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
|
||||
for i in ".!:?;,":
|
||||
s = f"a) {i}"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == f"a){i}"
|
||||
|
||||
# s/\( /\(/g;
|
||||
s = "a( b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a (b", n
|
||||
|
||||
# s/ \)/\)/g;
|
||||
s = "ab ) a"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "ab) a", n
|
||||
|
||||
# s/(\d) \%/$1\%/g;
|
||||
s = "1 %a"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "1%a", n
|
||||
|
||||
# s/ :/:/g;
|
||||
s = "a :"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a:", n
|
||||
|
||||
# s/ ;/;/g;
|
||||
s = "a ;"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a;", n
|
||||
|
||||
# s/\`/\'/g;
|
||||
s = "`a`"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "'a'", n
|
||||
|
||||
# s/\'\'/ \" /g;
|
||||
s = "''a''"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/„/\"/g;
|
||||
s = '„a"'
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/“/\"/g;
|
||||
s = "“a„"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/”/\"/g;
|
||||
s = "“a”"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"a"', n
|
||||
|
||||
# s/–/-/g;
|
||||
s = "a–b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a-b", n
|
||||
|
||||
# s/—/ - /g; s/ +/ /g;
|
||||
s = "a—b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a - b", n
|
||||
|
||||
# s/´/\'/g;
|
||||
s = "a´b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a'b", n
|
||||
|
||||
# s/([a-z])‘([a-z])/$1\'$2/gi;
|
||||
for i in "‘’":
|
||||
s = f"a{i}B"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a'B", n
|
||||
|
||||
s = f"A{i}B"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "A'B", n
|
||||
|
||||
s = f"A{i}b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "A'b", n
|
||||
|
||||
# s/‘/\'/g;
|
||||
# s/‚/\'/g;
|
||||
for i in "‘‚":
|
||||
s = f"a{i}b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "a'b", n
|
||||
|
||||
# s/’/\"/g;
|
||||
s = "’"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/''/\"/g;
|
||||
s = "''"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/´´/\"/g;
|
||||
s = "´´"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/…/.../g;
|
||||
s = "…"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "...", n
|
||||
|
||||
# s/ « / \"/g;
|
||||
s = "a « b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == 'a "b', n
|
||||
|
||||
# s/« /\"/g;
|
||||
s = "a « b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == 'a "b', n
|
||||
|
||||
# s/«/\"/g;
|
||||
s = "a«b"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == 'a"b', n
|
||||
|
||||
# s/ » /\" /g;
|
||||
s = " » "
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '" ', n
|
||||
|
||||
# s/ »/\"/g;
|
||||
s = " »"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/»/\"/g;
|
||||
s = "»"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == '"', n
|
||||
|
||||
# s/ \%/\%/g;
|
||||
s = " %"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "%", n
|
||||
|
||||
# s/ :/:/g;
|
||||
s = " :"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == ":", n
|
||||
|
||||
# s/(\d) (\d)/$1.$2/g;
|
||||
s = "2 3"
|
||||
n = normalize_punctuation(s, lang="en")
|
||||
assert n == "2.3", n
|
||||
|
||||
# s/(\d) (\d)/$1,$2/g;
|
||||
s = "2 3"
|
||||
n = normalize_punctuation(s, lang="de")
|
||||
assert n == "2,3", n
|
||||
|
||||
|
||||
def main():
|
||||
test_normalize_punctuation()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
26
egs/must_c/ST/local/test_remove_non_native_characters.py
Executable file
26
egs/must_c/ST/local/test_remove_non_native_characters.py
Executable file
@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
from remove_non_native_characters import remove_non_native_characters
|
||||
|
||||
|
||||
def test_remove_non_native_characters():
|
||||
s = "Ich heiße xxx好的01 fangjun".lower()
|
||||
n = remove_non_native_characters(s, lang="de")
|
||||
assert n == "ich heisse xxx fangjun", n
|
||||
|
||||
s = "äÄ".lower()
|
||||
n = remove_non_native_characters(s, lang="de")
|
||||
assert n == "aeae", n
|
||||
|
||||
s = "öÖ".lower()
|
||||
n = remove_non_native_characters(s, lang="de")
|
||||
assert n == "oeoe", n
|
||||
|
||||
s = "üÜ".lower()
|
||||
n = remove_non_native_characters(s, lang="de")
|
||||
assert n == "ueue", n
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_remove_non_native_characters()
|
||||
17
egs/must_c/ST/local/test_remove_punctuation.py
Executable file
17
egs/must_c/ST/local/test_remove_punctuation.py
Executable file
@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from remove_punctuation import remove_punctuation
|
||||
|
||||
|
||||
def test_remove_punctuation():
|
||||
s = "a,b'c!#"
|
||||
n = remove_punctuation(s)
|
||||
assert n == "ab'c", n
|
||||
|
||||
s = " ab " # remove leading and trailing spaces
|
||||
n = remove_punctuation(s)
|
||||
assert n == "ab", n
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_remove_punctuation()
|
||||
1
egs/must_c/ST/local/train_bpe_model.py
Symbolic link
1
egs/must_c/ST/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/train_bpe_model.py
|
||||
1
egs/must_c/ST/local/validate_bpe_lexicon.py
Symbolic link
1
egs/must_c/ST/local/validate_bpe_lexicon.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/validate_bpe_lexicon.py
|
||||
173
egs/must_c/ST/prepare.sh
Executable file
173
egs/must_c/ST/prepare.sh
Executable file
@ -0,0 +1,173 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=10
|
||||
stage=0
|
||||
stop_stage=100
|
||||
|
||||
version=v1.0
|
||||
tgt_lang=de
|
||||
dl_dir=$PWD/download
|
||||
|
||||
must_c_dir=$dl_dir/must-c/$version/en-$tgt_lang/data
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files.
|
||||
# - $dl_dir/must-c/$version/en-$tgt_lang/data/{dev,train,tst-COMMON,tst-HE}
|
||||
#
|
||||
# Please go to https://ict.fbk.eu/must-c-releases/
|
||||
# to download and untar the dataset if you have not already done this.
|
||||
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate
|
||||
# data/lang_bpe_${tgt_lang}_xxx
|
||||
# data/lang_bpe_${tgt_lang}_yyy
|
||||
# if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
# 5000
|
||||
# 2000
|
||||
# 1000
|
||||
500
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ ! -d $must_c_dir ]; then
|
||||
log "$must_c_dir does not exist"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for d in dev train tst-COMMON tst-HE; do
|
||||
if [ ! -d $must_c_dir/$d ]; then
|
||||
log "$must_c_dir/$d does not exist!"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download musan"
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to $dl_dir/musan
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.musan.done ]; then
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
touch data/manifests/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare must-c $version manifest for target language $tgt_lang"
|
||||
mkdir -p data/manifests/$version
|
||||
if [ ! -e data/manifests/$version/.${tgt_lang}.manifests.done ]; then
|
||||
lhotse prepare must-c \
|
||||
-j $nj \
|
||||
--tgt-lang $tgt_lang \
|
||||
$dl_dir/must-c/$version/ \
|
||||
data/manifests/$version/
|
||||
|
||||
touch data/manifests/$version/.${tgt_lang}.manifests.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Text normalization for $version with target language $tgt_lang"
|
||||
if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then
|
||||
./local/preprocess_must_c.py \
|
||||
--manifest-dir ./data/manifests/$version/ \
|
||||
--tgt-lang $tgt_lang
|
||||
touch ./data/manifests/$version/.$tgt_lang.norm.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.musan.done ]; then
|
||||
./local/compute_fbank_musan.py
|
||||
touch data/fbank/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute fbank for $version with target language $tgt_lang"
|
||||
mkdir -p data/fbank/$version/
|
||||
if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then
|
||||
./local/compute_fbank_must_c.py \
|
||||
--in-dir ./data/manifests/$version/ \
|
||||
--out-dir ./data/fbank/$version/ \
|
||||
--tgt-lang $tgt_lang \
|
||||
--num-jobs $nj
|
||||
|
||||
./local/compute_fbank_must_c.py \
|
||||
--in-dir ./data/manifests/$version/ \
|
||||
--out-dir ./data/fbank/$version/ \
|
||||
--tgt-lang $tgt_lang \
|
||||
--num-jobs $nj \
|
||||
--perturb-speed 1
|
||||
|
||||
touch data/fbank/$version/.$tgt_lang.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare BPE based lang for $version with target language $tgt_lang"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/
|
||||
mkdir -p $lang_dir
|
||||
if [ ! -f $lang_dir/transcript_words.txt ]; then
|
||||
./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/words.txt ]; then
|
||||
./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
|
||||
log "Validating $lang_dir/lexicon.txt"
|
||||
./local/validate_bpe_lexicon.py \
|
||||
--lexicon $lang_dir/lexicon.txt \
|
||||
--bpe-model $lang_dir/bpe.model
|
||||
fi
|
||||
done
|
||||
fi
|
||||
1
egs/must_c/ST/shared
Symbolic link
1
egs/must_c/ST/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared
|
||||
1
egs/peoples_speech/ASR/local/compute_fbank_musan.py
Symbolic link
1
egs/peoples_speech/ASR/local/compute_fbank_musan.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compute_fbank_musan.py
|
||||
154
egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py
Executable file
154
egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py
Executable file
@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
set_caching_enabled,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of dataloading workers used for reading the audio.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-duration",
|
||||
type=float,
|
||||
default=600.0,
|
||||
help="The maximum number of audio seconds in a batch."
|
||||
"Determines batch size dynamically.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-splits",
|
||||
type=int,
|
||||
required=True,
|
||||
help="The number of splits of the train subset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Process pieces starting from this number (inclusive).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--stop",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Stop processing pieces until this number (exclusive).",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_fbank_peoples_speech_splits(args):
|
||||
subsets = ("dirty", "dirty_sa", "clean", "clean_sa")
|
||||
num_splits = args.num_splits
|
||||
output_dir = f"data/fbank/peoples_speech_train_split"
|
||||
output_dir = Path(output_dir)
|
||||
assert output_dir.exists(), f"{output_dir} does not exist!"
|
||||
|
||||
num_digits = 8
|
||||
|
||||
start = args.start
|
||||
stop = args.stop
|
||||
if stop < start:
|
||||
stop = num_splits
|
||||
|
||||
stop = min(stop, num_splits)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||
set_caching_enabled(False)
|
||||
|
||||
for partition in subsets:
|
||||
for i in range(start, stop):
|
||||
idx = f"{i + 1}".zfill(num_digits)
|
||||
logging.info(f"Processing {partition}: {idx}")
|
||||
|
||||
cuts_path = output_dir / f"peoples_speech_cuts_{partition}.{idx}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
continue
|
||||
|
||||
raw_cuts_path = (
|
||||
output_dir / f"peoples_speech_cuts_{partition}_raw.{idx}.jsonl.gz"
|
||||
)
|
||||
|
||||
logging.info(f"Loading {raw_cuts_path}")
|
||||
cut_set = CutSet.from_file(raw_cuts_path)
|
||||
|
||||
logging.info("Splitting cuts into smaller chunks.")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/peoples_speech_feats_{partition}_{idx}",
|
||||
num_workers=args.num_workers,
|
||||
batch_duration=args.batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
||||
def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
compute_fbank_peoples_speech_splits(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
93
egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py
Executable file
93
egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py
Executable file
@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the People's Speech dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from filter_cuts import filter_cuts
|
||||
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_peoples_speech_valid_test():
|
||||
src_dir = Path(f"data/manifests")
|
||||
output_dir = Path(f"data/fbank")
|
||||
num_workers = 42
|
||||
batch_duration = 600
|
||||
|
||||
subsets = ("validation", "test")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
for partition in subsets:
|
||||
cuts_path = output_dir / f"peoples_speech_cuts_{partition}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
|
||||
raw_cuts_path = output_dir / f"peoples_speech_cuts_{partition}_raw.jsonl.gz"
|
||||
|
||||
logging.info(f"Loading {raw_cuts_path}")
|
||||
cut_set = CutSet.from_file(raw_cuts_path)
|
||||
|
||||
logging.info("Splitting cuts into smaller chunks")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/peoples_speech_feats_{partition}",
|
||||
num_workers=num_workers,
|
||||
batch_duration=batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
compute_fbank_peoples_speech_valid_test()
|
||||
1
egs/peoples_speech/ASR/local/filter_cuts.py
Symbolic link
1
egs/peoples_speech/ASR/local/filter_cuts.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/filter_cuts.py
|
||||
1
egs/peoples_speech/ASR/local/prepare_lang_bpe.py
Symbolic link
1
egs/peoples_speech/ASR/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang_bpe.py
|
||||
123
egs/peoples_speech/ASR/local/preprocess_peoples_speech.py
Executable file
123
egs/peoples_speech/ASR/local/preprocess_peoples_speech.py
Executable file
@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from lhotse import CutSet, SupervisionSegment
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
help="""Dataset parts to compute fbank. If None, we will use all""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def normalize_text(utt: str) -> str:
|
||||
utt = re.sub(r"[{0}]+".format("-"), " ", utt)
|
||||
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
||||
|
||||
|
||||
def preprocess_peoples_speech(dataset: Optional[str] = None):
|
||||
src_dir = Path(f"data/manifests")
|
||||
output_dir = Path(f"data/fbank")
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
if dataset is None:
|
||||
dataset_parts = (
|
||||
"validation",
|
||||
"test",
|
||||
"dirty",
|
||||
"dirty_sa",
|
||||
"clean",
|
||||
"clean_sa",
|
||||
)
|
||||
else:
|
||||
dataset_parts = dataset.split(" ", -1)
|
||||
|
||||
logging.info("Loading manifest, it may takes 8 minutes")
|
||||
prefix = f"peoples_speech"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
for partition, m in manifests.items():
|
||||
logging.info(f"Processing {partition}")
|
||||
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
|
||||
if raw_cuts_path.is_file():
|
||||
logging.info(f"{partition} already exists - skipping")
|
||||
continue
|
||||
|
||||
logging.info(f"Normalizing text in {partition}")
|
||||
i = 0
|
||||
for sup in m["supervisions"]:
|
||||
text = str(sup.text)
|
||||
orig_text = text
|
||||
sup.text = normalize_text(sup.text)
|
||||
text = str(sup.text)
|
||||
if i < 10 and len(orig_text) != len(text):
|
||||
logging.info(
|
||||
f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
|
||||
)
|
||||
i += 1
|
||||
|
||||
# Create long-recording cut manifests.
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
).resample(16000)
|
||||
|
||||
# Run data augmentation that needs to be done in the
|
||||
# time domain.
|
||||
logging.info(f"Saving to {raw_cuts_path}")
|
||||
cut_set.to_file(raw_cuts_path)
|
||||
|
||||
|
||||
def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
preprocess_peoples_speech(dataset=args.dataset)
|
||||
logging.info("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
egs/peoples_speech/ASR/local/train_bpe_model.py
Symbolic link
1
egs/peoples_speech/ASR/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/train_bpe_model.py
|
||||
1
egs/peoples_speech/ASR/local/validate_bpe_lexicon.py
Symbolic link
1
egs/peoples_speech/ASR/local/validate_bpe_lexicon.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/validate_bpe_lexicon.py
|
||||
247
egs/peoples_speech/ASR/prepare.sh
Executable file
247
egs/peoples_speech/ASR/prepare.sh
Executable file
@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
nj=32
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
# Split data/set to a number of pieces
|
||||
# This is to avoid OOM during feature extraction.
|
||||
num_per_split=4000
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, they will be downloaded
|
||||
# by this script automatically.
|
||||
#
|
||||
# - $dl_dir/peoples_speech
|
||||
# This directory contains the following files downloaded from
|
||||
# https://huggingface.co/datasets/MLCommons/peoples_speech
|
||||
#
|
||||
# - test
|
||||
# - train
|
||||
# - validation
|
||||
#
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate data/lang_bpe_xxx,
|
||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
# 5000
|
||||
# 2000
|
||||
# 1000
|
||||
500
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/peoples_speech,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/peoples_speech $dl_dir/peoples_speech
|
||||
#
|
||||
if [ ! -d $dl_dir/peoples_speech/train ]; then
|
||||
git lfs install
|
||||
git clone https://huggingface.co/datasets/MLCommons/peoples_speech
|
||||
fi
|
||||
|
||||
# If you have pre-downloaded it to /path/to/musan,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/musan $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare People's Speech manifest"
|
||||
# We assume that you have downloaded the People's Speech corpus
|
||||
# to $dl_dir/peoples_speech
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.peoples_speech.done ]; then
|
||||
lhotse prepare peoples-speech -j $nj $dl_dir/peoples_speech data/manifests
|
||||
touch data/manifests/.peoples_speech.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare musan manifest"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to data/musan
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.musan.done ]; then
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
touch data/manifests/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Preprocess People's Speech manifest"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.preprocess_complete ]; then
|
||||
./local/preprocess_peoples_speech.py
|
||||
touch data/fbank/.preprocess_complete
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for valid and test subsets of People's Speech"
|
||||
if [ ! -e data/fbank/.peoples_speech_valid_test.done ]; then
|
||||
./local/compute_fbank_peoples_speech_valid_test.py
|
||||
touch data/fbank/.peoples_speech_valid_test.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Split train subset into pieces"
|
||||
split_dir=data/fbank/peoples_speech_train_split
|
||||
if [ ! -e $split_dir/.peoples_speech_dirty_split.done ]; then
|
||||
lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz $split_dir $num_per_split
|
||||
touch $split_dir/.peoples_speech_dirty_split.done
|
||||
fi
|
||||
|
||||
if [ ! -e $split_dir/.peoples_speech_dirty_sa_split.done ]; then
|
||||
lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz $split_dir $num_per_split
|
||||
touch $split_dir/.peoples_speech_dirty_sa_split.done
|
||||
fi
|
||||
|
||||
if [ ! -e $split_dir/.peoples_speech_clean_split.done ]; then
|
||||
lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz $split_dir $num_per_split
|
||||
touch $split_dir/.peoples_speech_clean_split.done
|
||||
fi
|
||||
|
||||
if [ ! -e $split_dir/.peoples_speech_clean_sa_split.done ]; then
|
||||
lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz $split_dir $num_per_split
|
||||
touch $split_dir/.peoples_speech_clean_sa_split.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Compute features for train subset of People's Speech"
|
||||
if [ ! -e data/fbank/.peoples_speech_train.done ]; then
|
||||
./local/compute_fbank_peoples_speech_splits.py \
|
||||
--num-workers $nj \
|
||||
--batch-duration 600 \
|
||||
--start 0 \
|
||||
--num-splits 2000
|
||||
touch data/fbank/.peoples_speech_train.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Compute fbank for musan"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.musan.done ]; then
|
||||
./local/compute_fbank_musan.py
|
||||
touch data/fbank/.musan.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Prepare BPE based lang"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
mkdir -p $lang_dir
|
||||
|
||||
if [ ! -f $lang_dir/transcript_words.txt ]; then
|
||||
log "Generate data for BPE training"
|
||||
file=$(
|
||||
find "data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz"
|
||||
find "data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz"
|
||||
find "data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz"
|
||||
find "data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz"
|
||||
)
|
||||
gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
|
||||
|
||||
# Ensure space only appears once
|
||||
sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
|
||||
sed -i 's/ +/ /g' $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/words.txt ]; then
|
||||
cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' > $lang_dir/words.txt
|
||||
(echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
|
||||
cat - $lang_dir/words.txt | sort | uniq | awk '
|
||||
BEGIN {
|
||||
print "<eps> 0";
|
||||
}
|
||||
{
|
||||
if ($1 == "<s>") {
|
||||
print "<s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
if ($1 == "</s>") {
|
||||
print "</s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
printf("%s %d\n", $1, NR);
|
||||
}
|
||||
END {
|
||||
printf("#0 %d\n", NR+1);
|
||||
printf("<s> %d\n", NR+2);
|
||||
printf("</s> %d\n", NR+3);
|
||||
}' > $lang_dir/words || exit 1;
|
||||
mv $lang_dir/words $lang_dir/words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
|
||||
log "Validating $lang_dir/lexicon.txt"
|
||||
./local/validate_bpe_lexicon.py \
|
||||
--lexicon $lang_dir/lexicon.txt \
|
||||
--bpe-model $lang_dir/bpe.model
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L.fst ]; then
|
||||
log "Converting L.pt to L.fst"
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L.pt \
|
||||
$lang_dir/L.fst
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.fst ]; then
|
||||
log "Converting L_disambig.pt to L_disambig.fst"
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L_disambig.pt \
|
||||
$lang_dir/L_disambig.fst
|
||||
fi
|
||||
done
|
||||
fi
|
||||
1
egs/peoples_speech/ASR/shared
Symbolic link
1
egs/peoples_speech/ASR/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared/
|
||||
@ -443,9 +443,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -648,13 +645,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -685,7 +676,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -697,7 +687,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -513,9 +513,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -726,13 +723,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
# print(batch["supervisions"])
|
||||
@ -775,7 +766,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -788,7 +778,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -566,9 +566,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -798,13 +795,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -851,7 +842,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -864,7 +854,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
|
||||
@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=300,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
@ -364,7 +364,7 @@ class WenetSpeechAsrDataModule:
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
logging.info("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
|
||||
@ -92,7 +92,7 @@ When training with the L subset, the streaming usage:
|
||||
--causal-convolution 1 \
|
||||
--decode-chunk-size 16 \
|
||||
--left-context 64
|
||||
|
||||
|
||||
(4) modified beam search with RNNLM shallow fusion
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 35 \
|
||||
@ -112,8 +112,10 @@ When training with the L subset, the streaming usage:
|
||||
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -133,7 +135,8 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -307,6 +310,26 @@ def get_parser():
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-shallow-fusion",
|
||||
type=str2bool,
|
||||
@ -362,6 +385,7 @@ def decode_one_batch(
|
||||
lexicon: Lexicon,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
@ -402,14 +426,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
@ -448,6 +471,7 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
@ -509,7 +533,12 @@ def decode_one_batch(
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
key += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
key += "-no-context-words"
|
||||
return {key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -518,6 +547,7 @@ def decode_dataset(
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
@ -567,6 +597,7 @@ def decode_dataset(
|
||||
lexicon=lexicon,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
context_graph=context_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
@ -646,6 +677,12 @@ def main():
|
||||
"modified_beam_search_lm_shallow_fusion",
|
||||
"modified_beam_search_LODR",
|
||||
)
|
||||
|
||||
if os.path.exists(params.context_file):
|
||||
params.has_contexts = True
|
||||
else:
|
||||
params.has_contexts = False
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
@ -655,6 +692,10 @@ def main():
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += "-no-contexts-words"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -684,11 +725,15 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
@ -816,6 +861,19 @@ def main():
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
if params.decoding_method == "modified_beam_search":
|
||||
if os.path.exists(params.context_file):
|
||||
contexts_text = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts_text.append(line.strip())
|
||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(contexts)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
context_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -833,15 +891,16 @@ def main():
|
||||
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
|
||||
|
||||
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
|
||||
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
|
||||
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user