Merge branch 'master' into wenetspeech

pkufool 2023-06-13 16:18:28 +08:00
commit a1b12cf4e9
106 changed files with 4482 additions and 1290 deletions

View File

@@ -58,6 +58,7 @@ Usage:
import argparse
import logging
+import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -76,6 +77,8 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model

+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -211,6 +214,26 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )

    add_model_arguments(parser)

    return parser
@@ -222,6 +245,7 @@ def decode_one_batch(
    token_table: k2.SymbolTable,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -285,6 +309,7 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
+            context_graph=context_graph,
        )
    else:
        hyp_tokens = []
@@ -324,7 +349,12 @@
            ): hyps
        }
    else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}

def decode_dataset(
@@ -333,6 +363,7 @@ def decode_dataset(
    model: nn.Module,
    token_table: k2.SymbolTable,
    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.
@@ -377,6 +408,7 @@ def decode_dataset(
            model=model,
            token_table=token_table,
            decoding_graph=decoding_graph,
+            context_graph=context_graph,
            batch=batch,
        )
@@ -407,16 +439,17 @@ def save_results(
    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        # we compute CER for aishell dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        store_transcripts(filename=recog_path, texts=results_char)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -457,6 +490,12 @@ def main():
        "fast_beam_search",
        "modified_beam_search",
    )

+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
@@ -470,6 +509,10 @@ def main():
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-contexts-words"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -490,6 +533,11 @@ def main():
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )

    logging.info(params)

    logging.info("About to create model")
@@ -586,6 +634,19 @@ def main():
    else:
        decoding_graph = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
@@ -608,6 +669,7 @@ def main():
        model=model,
        token_table=lexicon.token_table,
        decoding_graph=decoding_graph,
+        context_graph=context_graph,
    )

    save_results(
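
The wiring above is spread over many small hunks, but the flow it implements is short. Below is a minimal sketch of the same decode-time setup, as a reading aid rather than the recipe's code: it uses only names that appear in this diff (ContextGraph, CharCtcTrainingGraphCompiler.texts_to_ids), while the standalone helper function and the biasing-list filename are hypothetical.

import os

from icefall import ContextGraph
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler


def maybe_build_context_graph(context_file, graph_compiler, context_score=2.0):
    """Build a biasing graph for modified_beam_search, or return None."""
    if not os.path.exists(context_file):
        return None
    with open(context_file) as f:
        # One biasing word/phrase per line, as --context-file expects.
        phrases = [line.strip() for line in f if line.strip()]
    # Map phrases to token ids with the recipe's char-based compiler.
    token_ids = graph_compiler.texts_to_ids(phrases)
    graph = ContextGraph(context_score)  # per-token bonus, cf. --context-score
    graph.build(token_ids)
    return graph

The resulting graph is then threaded unchanged through decode_dataset into modified_beam_search, which is why every signature in the hunks above gains the same optional context_graph argument.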

View File

@@ -577,9 +577,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -806,13 +803,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -859,7 +850,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -872,7 +862,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
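
The same deletion recurs in every train.py touched by this commit, so it is worth spelling out once what is being removed. The deleted lines implemented mid-epoch resume: a checkpoint recorded cur_batch_idx, and on restart the loop skipped the batches the interrupted epoch had already consumed. A runnable toy reconstruction, assembled purely from the deleted lines above (params and train_dl are stand-ins):

# Toy reconstruction of the removed resume logic; illustrative only.
params = {"cur_batch_idx": 3}  # as previously restored from a checkpoint
train_dl = ["b0", "b1", "b2", "b3", "b4", "b5"]  # stand-in for the dataloader

cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
    if batch_idx < cur_batch_idx:
        continue  # fast-forward over already-seen batches
    cur_batch_idx = batch_idx
    print(f"training on {batch}")  # the real loop ran a training step here

After this commit a run restored from a mid-epoch checkpoint simply starts the epoch's dataloader from the beginning; the fast-forward bookkeeping disappears from saving, loading, and the training loop alike.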

View File

@@ -580,9 +580,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -809,13 +806,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -862,7 +853,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -875,7 +865,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -567,9 +567,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -799,13 +796,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -852,7 +843,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -865,7 +855,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -512,9 +512,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -725,13 +722,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
        # print(batch["supervisions"])
@@ -774,7 +765,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -787,7 +777,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -554,9 +554,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -779,13 +776,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -832,7 +823,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -845,7 +835,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -549,9 +549,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -770,13 +767,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -823,7 +814,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -836,7 +826,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -63,6 +63,14 @@ log() {
log "dl_dir: $dl_dir"

+if ! command -v ffmpeg &> /dev/null; then
+    echo "This dataset requires ffmpeg"
+    echo "Please install ffmpeg first"
+    echo ""
+    echo "  sudo apt-get install ffmpeg"
+    exit 1
+fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
    log "Stage 0: Download data"

View File

@@ -567,9 +567,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -799,13 +796,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -852,7 +843,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -865,7 +855,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -606,9 +606,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -835,13 +832,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -889,7 +880,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -902,7 +892,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -607,9 +607,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -836,13 +833,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -890,7 +881,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -903,7 +893,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -462,9 +462,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -674,13 +671,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -712,7 +703,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -725,7 +715,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -410,9 +410,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -675,13 +672,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
        batch_name = batch["supervisions"]["uttid"]
@@ -736,7 +727,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -749,7 +739,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -550,9 +550,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -771,13 +768,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -819,7 +810,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -832,7 +822,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -550,9 +550,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -771,13 +768,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -819,7 +810,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -832,7 +822,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -552,9 +552,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -773,13 +770,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -821,7 +812,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -834,7 +824,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -107,7 +107,7 @@ fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
    log "Stage 2: Prepare musan manifest"
    # We assume that you have downloaded the musan corpus
-    # to data/musan
+    # to $dl_dir/musan
    mkdir -p data/manifests
    if [ ! -e data/manifests/.musan.done ]; then
        lhotse prepare musan $dl_dir/musan data/manifests

View File

@@ -1,117 +0,0 @@
#!/usr/bin/env bash
set -eou pipefail
nj=16
stage=-1
stop_stage=100
# Split data/${lang}set to this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=1000
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/$release/$lang
# This directory contains the following files downloaded from
# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz
#
# - clips
# - dev.tsv
# - invalidated.tsv
# - other.tsv
# - reported.tsv
# - test.tsv
# - train.tsv
# - validated.tsv
dl_dir=$PWD/download
release=cv-corpus-13.0-2023-03-09
lang=en
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data/${lang}".
# You can safely remove "data/${lang}" and rerun this script to regenerate it.
mkdir -p data/${lang}
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/$release,
# you can create a symlink
#
# ln -sfv /path/to/$release $dl_dir/$release
#
if [ ! -d $dl_dir/$release/$lang/clips ]; then
lhotse download commonvoice --languages $lang --release $release $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare CommonVoice manifest"
# We assume that you have downloaded the CommonVoice corpus
# to $dl_dir/$release
mkdir -p data/${lang}/manifests
if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then
lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests
touch data/${lang}/manifests/.cv-${lang}.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Preprocess CommonVoice manifest"
if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then
./local/preprocess_commonvoice.py --language $lang
touch data/${lang}/fbank/.preprocess_complete
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for dev and test subsets of CommonVoice"
mkdir -p data/${lang}/fbank
if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then
./local/compute_fbank_commonvoice_dev_test.py --language $lang
touch data/${lang}/fbank/.cv-${lang}_dev_test.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Split train subset into ${num_splits} pieces"
split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits}
if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then
lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir
touch $split_dir/.cv-${lang}_train_split.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute features for train subset of CommonVoice"
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
./local/compute_fbank_commonvoice_splits.py \
--num-workers $nj \
--batch-duration 600 \
--start 0 \
--num-splits $num_splits \
--language $lang
touch data/${lang}/fbank/.cv-${lang}_train.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Combine features for train"
if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then
pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz")
lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz
fi
fi

View File

@@ -1,159 +0,0 @@
#!/usr/bin/env bash
set -eou pipefail
nj=15
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/GigaSpeech
# You can find audio, dict, GigaSpeech.json inside it.
# You can apply for the download credentials by following
# https://github.com/SpeechColab/GigaSpeech#download
# Number of hours for GigaSpeech subsets
# XL 10k hours
# L 2.5k hours
# M 1k hours
# S 250 hours
# XS 10 hours
# DEV 12 hours
# Test 40 hours
# Split XL subset to this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=2000
# We use lazy split from lhotse.
# The XL subset (10k hours) contains 37956 cuts without speed perturbing.
# We want to split it into 2000 splits, so each split
# contains about 37956 / 2000 = 19 cuts. As a result, there will be 1998 splits.
chunk_size=19 # number of cuts in each split. The last split may contain fewer cuts.
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
[ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech
# If you have pre-downloaded it to /path/to/GigaSpeech,
# you can create a symlink
#
# ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech
#
if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then
# Check credentials.
if [ ! -f $dl_dir/password ]; then
echo -n "$0: Please apply for the download credentials by following"
echo -n "https://github.com/SpeechColab/GigaSpeech#dataset-download"
echo " and save it to $dl_dir/password."
exit 1;
fi
PASSWORD=`cat $dl_dir/password 2>/dev/null`
if [ -z "$PASSWORD" ]; then
echo "$0: Error, $dl_dir/password is empty."
exit 1;
fi
PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1`
if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then
echo "$0: Error, invalid $dl_dir/password."
exit 1;
fi
# Download XL, DEV and TEST sets by default.
lhotse download gigaspeech \
--subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
--host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare GigaSpeech manifest (may take 30 minutes)"
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
if [ ! -f data/manifests/.gigaspeech.done ]; then
mkdir -p data/manifests
lhotse prepare gigaspeech \
--subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
-j $nj \
$dl_dir/GigaSpeech data/manifests
touch data/manifests/.gigaspeech.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Preprocess GigaSpeech manifest"
if [ ! -f data/fbank/.gigaspeech_preprocess.done ]; then
log "It may take 2 hours for this stage"
./local/preprocess_gigaspeech.py
touch data/fbank/.gigaspeech_preprocess.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
if [ ! -f data/fbank/.gigaspeech_dev_test.done ]; then
./local/compute_fbank_gigaspeech_dev_test.py
touch data/fbank/.gigaspeech_dev_test.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Split XL subset into ${num_splits} pieces"
split_dir=data/fbank/gigaspeech_XL_split_${num_splits}
if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then
lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $chunk_size
touch $split_dir/.gigaspeech_XL_split.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute features for XL"
# Note: The script supports --start and --stop options.
# You can use several machines to compute the features in parallel.
if [ ! -f data/fbank/.gigaspeech_XL.done ]; then
./local/compute_fbank_gigaspeech_splits.py \
--num-workers $nj \
--batch-duration 600 \
--num-splits $num_splits
touch data/fbank/.gigaspeech_XL.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Combine features for XL (may take 15 hours)"
if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then
pieces=$(find data/fbank/gigaspeech_XL_split_${num_splits} -name "gigaspeech_cuts_XL.*.jsonl.gz")
lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz
fi
fi
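
The chunk_size comment near the top of this script compresses a small calculation; here is a worked check of its arithmetic, using only the numbers stated in the comment itself:

import math

num_cuts = 37956   # XL cuts without speed perturbing, per the comment
num_splits = 2000  # requested number of pieces
chunk_size = round(num_cuts / num_splits)  # 18.978 -> 19 cuts per piece
print(chunk_size)                          # 19
print(math.ceil(num_cuts / chunk_size))    # 1998 actual splits, as stated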

View File

@@ -1,330 +0,0 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=16
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/LibriSpeech
# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
# You can download them from https://www.openslr.org/12
#
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
#
# - 3-gram.pruned.1e-7.arpa.gz
# - 3-gram.pruned.1e-7.arpa
# - 4-gram.arpa.gz
# - 4-gram.arpa
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
# Split all dataset to this number of pieces and mix each dataset pieces
# into multidataset pieces with shuffling.
num_splits=1998
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
500
)
# multidataset list.
# LibriSpeech and musan are required.
# The others are optional.
multidataset=(
"gigaspeech",
"commonvoice",
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
log "Dataset: LibriSpeech and musan"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download LM"
mkdir -p $dl_dir/lm
if [ ! -e $dl_dir/lm/.done ]; then
./local/download_lm.py --out-dir=$dl_dir/lm
touch $dl_dir/lm/.done
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/LibriSpeech,
# you can create a symlink
#
# ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
#
if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
lhotse download librispeech --full $dl_dir
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LibriSpeech manifest"
# We assume that you have downloaded the LibriSpeech corpus
# to $dl_dir/LibriSpeech
mkdir -p data/manifests
if [ ! -e data/manifests/.librispeech.done ]; then
lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
touch data/manifests/.librispeech.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for librispeech"
mkdir -p data/fbank
if [ ! -e data/fbank/.librispeech.done ]; then
./local/compute_fbank_librispeech.py --perturb-speed False
touch data/fbank/.librispeech.done
fi
if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \
<(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \
shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz
fi
if [ ! -e data/fbank/.librispeech-validated.done ]; then
log "Validating data/fbank for LibriSpeech"
parts=(
train-clean-100
train-clean-360
train-other-500
test-clean
test-other
dev-clean
dev-other
)
for part in ${parts[@]}; do
python3 ./local/validate_manifest.py \
data/fbank/librispeech_cuts_${part}.jsonl.gz
done
touch data/fbank/.librispeech-validated.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
files=$(
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
$dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
# Note If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
# Note If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
done
fi
# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile LG"
./local/compile_lg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Prepare the other datasets"
# GigaSpeech
if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then
log "Dataset: GigaSpeech"
./prepare_giga_speech.sh --stop_stage 5
fi
# CommonVoice
if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then
log "Dataset: CommonVoice"
./prepare_common_voice.sh
fi
fi

View File

@@ -444,9 +444,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -649,13 +646,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -686,7 +677,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -698,7 +688,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -487,9 +487,6 @@ def load_checkpoint_if_available(
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]

    return saved_params
@@ -692,13 +689,7 @@ def train_one_epoch(
    tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
@@ -738,7 +729,6 @@ def train_one_epoch(
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
-            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
@@ -750,7 +740,6 @@ def train_one_epoch(
                scaler=scaler,
                rank=rank,
            )
-            del params.cur_batch_idx
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,

View File

@@ -24,7 +24,7 @@ import sentencepiece as spm
import torch
from model import Transducer

-from icefall import NgramLm, NgramLmStateCost
+from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel
@@ -785,6 +785,9 @@ class Hypothesis:
    # N-gram LM state
    state_cost: Optional[NgramLmStateCost] = None

+    # Context graph state
+    context_state: Optional[ContextState] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
@@ -937,6 +940,7 @@ def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
+    context_graph: Optional[ContextGraph] = None,
    beam: int = 4,
    temperature: float = 1.0,
    blank_penalty: float = 0.0,
@@ -989,6 +993,7 @@ def modified_beam_search(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                context_state=None if context_graph is None else context_graph.root,
                timestamp=[],
            )
        )
@@ -1011,6 +1016,7 @@ def modified_beam_search(
        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
@@ -1071,21 +1077,51 @@ def modified_beam_search(
            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]

                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
+                context_score = 0
+                new_context_state = None if context_graph is None else hyp.context_state
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)
+                    if context_graph is not None:
+                        (
+                            context_score,
+                            new_context_state,
+                        ) = context_graph.forward_one_step(hyp.context_state, new_token)

-                new_log_prob = topk_log_probs[k]
+                new_log_prob = topk_log_probs[k] + context_score

                new_hyp = Hypothesis(
-                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    timestamp=new_timestamp,
+                    context_state=new_context_state,
                )
                B[i].add(new_hyp)

    B = B + finalized_B

+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B

    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
    sorted_ans = [h.ys[context_size:] for h in best_hyps]
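
The contract modified_beam_search relies on is visible in the hunks above: forward_one_step(state, token) returns a (score, next_state) pair for each emitted token, and finalize(state) settles the provisional bonus of a partial match, per the backoff comment. Below is a deliberately simplified toy satisfying that contract, to make the scoring behaviour concrete. It is NOT icefall's ContextGraph (the real one handles overlapping matches with proper backoff arcs); all names here are invented for illustration.

class ToyContextState:
    """A trie node; `acc` is the bonus provisionally awarded along the path."""

    def __init__(self, acc=0.0, is_end=False):
        self.children = {}
        self.acc = acc
        self.is_end = is_end


class ToyContextGraph:
    def __init__(self, score=2.0):  # mirrors the --context-score default
        self.score = score
        self.root = ToyContextState()

    def build(self, token_id_lists):
        for ids in token_id_lists:
            node = self.root
            for tok in ids:
                if tok not in node.children:
                    node.children[tok] = ToyContextState(acc=node.acc + self.score)
                node = node.children[tok]
            node.is_end = True  # a full biasing phrase ends here

    def forward_one_step(self, state, token):
        nxt = state.children.get(token)
        if nxt is None:
            # Match abandoned: refund the provisional bonus (backoff).
            return -state.acc, self.root
        if nxt.is_end:
            # Full phrase matched: keep the bonus, restart from the root.
            return self.score, self.root
        return self.score, nxt  # extend the partial match

    def finalize(self, state):
        # Hypothesis ends mid-phrase: take back the provisional bonus.
        return -state.acc, self.root


g = ToyContextGraph(score=2.0)
g.build([[5, 6, 7]])                       # one phrase of three token ids
bonus, st = g.forward_one_step(g.root, 5)  # +2.0, partial match begun
refund, _ = g.finalize(st)                 # -2.0, match never completed

Each non-blank token thus pays out its bonus immediately, and the finalize pass at the end of decoding claws the bonus back from hypotheses whose match never reached the end of a phrase, which is exactly the asymmetry the new code introduces around new_log_prob.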

View File

@@ -99,7 +99,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless3/exp",
        help="The experiment dir",
    )

View File

@ -125,6 +125,7 @@ For example:
import argparse import argparse
import logging import logging
import math import math
import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -146,6 +147,7 @@ from beam_search import (
) )
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall import ContextGraph
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
@ -353,6 +355,27 @@ def get_parser():
Used only when the decoding method is fast_beam_search_nbest, Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
) )
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score of each token for the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing lists, one word/phrase each line
Used only when --decoding_method is modified_beam_search.
""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -365,6 +388,7 @@ def decode_one_batch(
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: ) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -494,6 +518,7 @@ def decode_one_batch(
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
context_graph=context_graph,
return_timestamps=True, return_timestamps=True,
) )
else: else:
@ -548,7 +573,12 @@ def decode_one_batch(
return {key: (hyps, timestamps)} return {key: (hyps, timestamps)}
else: else:
return {f"beam_size_{params.beam_size}": (hyps, timestamps)} key = f"beam_size_{params.beam_size}"
if params.has_contexts:
key += f"-context-score-{params.context_score}"
else:
key += "-no-context-words"
return {key: (hyps, timestamps)}
def decode_dataset( def decode_dataset(
@ -558,6 +588,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: ) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
"""Decode dataset. """Decode dataset.
@ -622,6 +653,7 @@ def decode_dataset(
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
word_table=word_table, word_table=word_table,
batch=batch, batch=batch,
context_graph=context_graph,
) )
for name, (hyps, timestamps_hyp) in hyps_dict.items(): for name, (hyps, timestamps_hyp) in hyps_dict.items():
@ -728,6 +760,12 @@ def main():
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0: if params.iter > 0:
@ -750,6 +788,10 @@ def main():
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += "-no-context-words"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -881,6 +923,18 @@ def main():
decoding_graph = None decoding_graph = None
word_table = None word_table = None
if params.decoding_method == "modified_beam_search":
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
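For illustration, the graph construction above can be exercised on its own. A minimal sketch, assuming a hypothetical biasing list context_words.txt; only ContextGraph, its build() call, and the BPE model path are taken from this recipe:

import sentencepiece as spm

from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

# One biasing word/phrase per line, mirroring --context-file.
with open("context_words.txt") as f:  # hypothetical file
    contexts = [line.strip() for line in f if line.strip()]

context_graph = ContextGraph(2.0)  # 2.0 plays the role of --context-score
context_graph.build(sp.encode(contexts))

Roughly speaking, each modified_beam_search hypothesis then tracks a state in this graph, and tokens that extend one of the biasing phrases earn the per-token bonus, which is why the result keys and file suffixes above record the context score.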
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -905,6 +959,7 @@ def main():
sp=sp, sp=sp,
word_table=word_table, word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
context_graph=context_graph,
) )
save_results( save_results(


@ -78,7 +78,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for decoding. help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )
@ -115,7 +115,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless4/exp",
help="The experiment dir", help="The experiment dir",
) )


@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 1 \ --start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 --max-duration 300
@ -37,7 +37,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 1 \ --start-epoch 1 \
--use-fp16 1 \ --use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
@ -195,7 +195,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless4/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -296,7 +296,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )


@ -87,7 +87,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )


@ -84,7 +84,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )


@ -78,7 +78,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for decoding. help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )
@ -115,7 +115,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless5/exp",
help="The experiment dir", help="The experiment dir",
) )


@ -20,7 +20,7 @@
To run this file, do: To run this file, do:
cd icefall/egs/librispeech/ASR cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py python ./pruned_transducer_stateless5/test_model.py
""" """
from train import get_params, get_transducer_model from train import get_params, get_transducer_model


@ -328,7 +328,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )


@ -20,23 +20,23 @@
# to a single one using model averaging. # to a single one using model averaging.
""" """
Usage: Usage:
./pruned_transducer_stateless2/export.py \ ./pruned_transducer_stateless6/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless6/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
It will generate a file exp_dir/pretrained.pt It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless2/decode.py`, To use the generated file with `pruned_transducer_stateless6/decode.py`,
you can do: you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless6/decode.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless6/exp \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--max-duration 100 \ --max-duration 100 \
@ -65,7 +65,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )
@ -91,7 +91,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless6/exp",
help="""It specifies the directory where all training related help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
""", """,


@ -267,7 +267,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )


@ -94,10 +94,10 @@ Usage:
--max-states 64 --max-states 64
(8) modified beam search with RNNLM shallow fusion (8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless7/decode.py \
--epoch 35 \ --epoch 35 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method modified_beam_search_lm_shallow_fusion \ --decoding-method modified_beam_search_lm_shallow_fusion \
--beam-size 4 \ --beam-size 4 \
@ -110,11 +110,11 @@ Usage:
--rnn-lm-tie-weights 1 --rnn-lm-tie-weights 1
(9) modified beam search with LM shallow fusion + LODR (9) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless7/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--max-duration 600 \ --max-duration 600 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless7/exp \
--decoding-method modified_beam_search_LODR \ --decoding-method modified_beam_search_LODR \
--beam-size 4 \ --beam-size 4 \
--lm-type rnn \ --lm-type rnn \


@ -79,7 +79,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )
@ -116,7 +116,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless5/exp", default="pruned_transducer_stateless7/exp",
help="""It specifies the directory where all training related help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
""", """,


@ -389,7 +389,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -620,9 +620,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -896,13 +893,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -953,7 +944,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -966,7 +956,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,


@ -1,77 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import re
from pathlib import Path
import lhotse
from lhotse import CutSet, load_manifest_lazy
class MultiDataset:
def __init__(self, manifest_dir: str, cv_manifest_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files:
- librispeech_cuts_train-all-shuf.jsonl.gz
- gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz
cv_manifest_dir:
It is expected to contain the following files:
- cv-en_cuts_train.jsonl.gz
"""
self.manifest_dir = Path(manifest_dir)
self.cv_manifest_dir = Path(cv_manifest_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
# LibriSpeech
logging.info(f"Loading LibriSpeech in lazy mode")
librispeech_cuts = load_manifest_lazy(
self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
# GigaSpeech
filenames = glob.glob(
f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"
)
pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode")
gigaspeech_cuts = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
# CommonVoice
logging.info(f"Loading CommonVoice in lazy mode")
commonvoice_cuts = load_manifest_lazy(
self.cv_manifest_dir / f"cv-en_cuts_train.jsonl.gz"
)
return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts)


@ -50,7 +50,6 @@ import copy
import logging import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import k2 import k2
@ -66,7 +65,6 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from multidataset import MultiDataset
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
@ -90,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch, filter_uneven_sized_batch,
setup_logger, setup_logger,
str2bool, str2bool,
symlink_or_copy,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -341,7 +340,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -376,13 +375,6 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
parser.add_argument(
"--use-multidataset",
type=str2bool,
default=False,
help="Whether to use multidataset to train.",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -578,9 +570,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -612,7 +601,8 @@ def save_checkpoint(
""" """
if rank != 0: if rank != 0:
return return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" epoch_basename = f"epoch-{params.cur_epoch}.pt"
filename = params.exp_dir / epoch_basename
save_checkpoint_impl( save_checkpoint_impl(
filename=filename, filename=filename,
model=model, model=model,
@ -626,12 +616,14 @@ def save_checkpoint(
) )
if params.best_train_epoch == params.cur_epoch: if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt" symlink_or_copy(
copyfile(src=filename, dst=best_train_filename) exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
)
if params.best_valid_epoch == params.cur_epoch: if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt" symlink_or_copy(
copyfile(src=filename, dst=best_valid_filename) exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
)
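The switch from copyfile to symlink_or_copy above avoids duplicating a large checkpoint file for best-train-loss.pt and best-valid-loss.pt. The helper is imported from icefall.utils; a hedged sketch of what such a helper could look like (the real implementation may differ):

import os
import shutil
from pathlib import Path

def symlink_or_copy(exp_dir: Path, src: str, dst: str) -> None:
    """Point exp_dir/dst at exp_dir/src via a relative symlink,
    falling back to a plain copy on filesystems without symlink support."""
    dst_path = Path(exp_dir) / dst
    if dst_path.is_symlink() or dst_path.exists():
        dst_path.unlink()
    try:
        os.symlink(src, dst_path)  # src is relative, so the link survives moving exp_dir
    except OSError:
        shutil.copyfile(Path(exp_dir) / src, dst_path)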
def compute_loss( def compute_loss(
@ -811,13 +803,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -864,7 +850,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -877,7 +862,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,
@ -1053,10 +1037,6 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.use_multidataset:
multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
train_cuts = multidataset.train_cuts()
else:
if params.mini_libri: if params.mini_libri:
train_cuts = librispeech.train_clean_5_cuts() train_cuts = librispeech.train_clean_5_cuts()
elif params.full_libri: elif params.full_libri:
@ -1118,7 +1098,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts() valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts) valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.use_multidataset and not params.print_diagnostics: if not params.print_diagnostics:
scan_pessimistic_batches_for_oom( scan_pessimistic_batches_for_oom(
model=model, model=model,
train_dl=train_dl, train_dl=train_dl,


@ -346,7 +346,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -577,9 +577,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -830,13 +827,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -883,7 +874,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -896,7 +886,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,


@ -342,7 +342,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -570,9 +570,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -819,13 +816,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -872,7 +863,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -885,7 +875,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,


@ -90,7 +90,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )


@ -88,7 +88,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )


@ -39,7 +39,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )
@ -65,7 +65,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
""", """,


@ -77,7 +77,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for decoding. help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )
@ -114,7 +114,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless7_streaming/exp",
help="The experiment dir", help="The experiment dir",
) )


@ -355,7 +355,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -586,9 +586,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -807,13 +804,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -860,7 +851,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -873,7 +863,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,


@ -355,7 +355,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -587,9 +587,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -808,13 +805,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -861,7 +852,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -874,7 +864,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,


@ -88,7 +88,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for averaging. help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )


@ -78,7 +78,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for decoding. help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )
@ -115,7 +115,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless2/exp", default="pruned_transducer_stateless7_streaming_multi/exp",
help="The experiment dir", help="The experiment dir",
) )


@ -366,7 +366,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -604,9 +604,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -921,7 +918,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -934,7 +930,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,


@ -348,7 +348,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -1127,7 +1127,16 @@ def run(rank, world_size, args):
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True) model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in model.named_parameters()]
)
optimizer = ScaledAdam(
model.parameters(),
lr=params.base_lr,
clipping_scale=2.0,
parameters_names=parameters_names,
)
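# Hedged note: parameters_names appears to feed ScaledAdam's gradient-clipping
# diagnostics (reporting which parameters dominate the gradient norm); see
# optim.py in this recipe for the authoritative behaviour.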
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)


@ -627,14 +627,6 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = librispeech.train_dataloaders(train_cuts) train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts() valid_cuts = librispeech.dev_clean_cuts()


@ -654,20 +654,6 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts) train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts() valid_cuts = librispeech.dev_clean_cuts()


@ -642,20 +642,6 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts) train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts() valid_cuts = librispeech.dev_clean_cuts()


@ -0,0 +1,775 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to
`--left-context-frames`, whose value in training is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
self.pad_length = 7 + 2 * 3
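# 7 input frames are consumed by the (T - 7) // 2 subsampling in encoder_embed,
# and 2 * 3 more give the ConvNeXt module its 3 subsampled frames of right
# padding (see the comments in export_encoder_model_onnx below).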
def forward(
self,
x: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
N = x.size(0)
T = self.chunk_size * 2 + self.pad_length
x_lens = torch.tensor([T] * N, device=x.device)
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=x,
x_lens=x_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2)
encoder_states = states[:-2]
logging.info(f"len_encoder_states={len(encoder_states)}")
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, new_states
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at a 50 Hz frame rate, after encoder_embed) for each sample in the batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
states.append(processed_lens)
return states
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
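For orientation before the export functions, a hedged sketch of how the three wrappers cooperate for one streaming chunk plus a single greedy step (shapes assume the defaults above; real decoding, e.g. ./onnx_pretrained-streaming.py, also handles blanks and hypothesis state):

# encoder, decoder, joiner: instances of the wrappers defined above
states = encoder.get_init_states(batch_size=1)
x = torch.rand(1, encoder.chunk_size * 2 + encoder.pad_length, 80)  # one fbank chunk
encoder_out, states = encoder(x, states)          # (1, chunk_size, joiner_dim)

y = torch.zeros(1, 2, dtype=torch.int64)          # decoder input, context_size=2
decoder_out = decoder(y)                          # (1, joiner_dim)
logit = joiner(encoder_out[:, 0], decoder_out)    # (1, vocab_size)
next_token = logit.argmax(dim=-1)                 # greedy pick for the first frame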
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
)
decode_chunk_len = encoder_model.chunk_size * 2
# encoder_embed subsamples the features: T -> (T - 7) // 2 frames
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
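# e.g. with chunk_size=16: decode_chunk_len = 32 and T = 32 + 7 + 2 * 3 = 45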
x = torch.rand(1, T, 80, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
logging.info(f"len(init_state): {len(init_state)}")
inputs = {}
input_names = ["x"]
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
logging.info(f"{name}.shape: {tensors[0].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
logging.info(f"{name}.shape: {tensors[1].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
logging.info(f"{name}.shape: {tensors[2].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
logging.info(f"{name}.shape: {tensors[3].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
logging.info(f"{name}.shape: {tensors[4].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
logging.info(f"{name}.shape: {tensors[5].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
ds = encoder_model.encoder.downsampling_factor
left_context_len = encoder_model.left_context_len
left_context_len = [left_context_len // k for k in ds]
left_context_len = ",".join(map(str, left_context_len))
query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "streaming zipformer2",
"decode_chunk_len": str(decode_chunk_len), # 32
"T": str(T), # 32+7+2*3=45
"num_encoder_layers": num_encoder_layers,
"encoder_dims": encoder_dims,
"cnn_module_kernels": cnn_module_kernels,
"left_context_len": left_context_len,
"query_head_dims": query_head_dims,
"value_head_dims": value_head_dims,
"num_heads": num_heads,
}
logging.info(f"meta_data: {meta_data}")
for i in range(len(init_state[:-2]) // 6):
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
embed_states = init_state[-2]
name = "embed_states"
logging.info(f"{name}.shape: {embed_states.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size,)
processed_lens = init_state[-1]
name = "processed_lens"
logging.info(f"{name}.shape: {processed_lens.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
logging.info(inputs)
logging.info(outputs)
logging.info(input_names)
logging.info(output_names)
torch.onnx.export(
encoder_model,
(x, init_state),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"x": {0: "N"},
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
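# Note: dynamic quantization here rewrites only the MatMul nodes; the graph
# inputs and outputs are unchanged, so the int8 models can be loaded as
# drop-in replacements for the float32 ones.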
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,624 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
\
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
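# Readback sketch (assuming `import onnxruntime as ort` and a hypothetical
# file name): after add_meta_data("encoder.onnx", {"version": "1"}), the
# pairs are available as
#   ort.InferenceSession("encoder.onnx").get_modelmeta().custom_metadata_map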
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_embed:
The subsampling module that embeds input features for the encoder.
encoder_proj:
The encoder projection layer from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
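# Note: the joiner's encoder_proj and decoder_proj are folded into
# OnnxEncoder and OnnxDecoder above, so this wrapper only sums the
# already-projected outputs and applies tanh followed by output_linear.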
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming zipformer2",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -49,7 +49,7 @@ class Transducer(nn.Module):
encoder: encoder:
It is the transcription network in the paper. It accepts It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,). `logit_lens` of shape (N,).
decoder: decoder:
It is the prediction network in the paper. Its input shape It is the prediction network in the paper. Its input shape

View File

@ -0,0 +1,544 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script loads ONNX models exported by ./export-onnx-streaming.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file with the exported ONNX models
./zipformer/onnx_pretrained-streaming.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
logging.info(f"encoder_meta={encoder_meta}")
model_type = encoder_meta["model_type"]
assert model_type == "zipformer2", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
num_encoder_layers = encoder_meta["num_encoder_layers"]
encoder_dims = encoder_meta["encoder_dims"]
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
left_context_len = encoder_meta["left_context_len"]
query_head_dims = encoder_meta["query_head_dims"]
value_head_dims = encoder_meta["value_head_dims"]
num_heads = encoder_meta["num_heads"]
def to_int_list(s):
return list(map(int, s.split(",")))
num_encoder_layers = to_int_list(num_encoder_layers)
encoder_dims = to_int_list(encoder_dims)
cnn_module_kernels = to_int_list(cnn_module_kernels)
left_context_len = to_int_list(left_context_len)
query_head_dims = to_int_list(query_head_dims)
value_head_dims = to_int_list(value_head_dims)
num_heads = to_int_list(num_heads)
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"encoder_dims: {encoder_dims}")
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
logging.info(f"left_context_len: {left_context_len}")
logging.info(f"query_head_dims: {query_head_dims}")
logging.info(f"value_head_dims: {value_head_dims}")
logging.info(f"num_heads: {num_heads}")
num_encoders = len(num_encoder_layers)
self.states = []
for i in range(num_encoders):
num_layers = num_encoder_layers[i]
key_dim = query_head_dims[i] * num_heads[i]
embed_dim = encoder_dims[i]
nonlin_attn_head_dim = 3 * embed_dim // 4
value_dim = value_head_dims[i] * num_heads[i]
conv_left_pad = cnn_module_kernels[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(
left_context_len[i], batch_size, key_dim
).numpy()
cached_nonlin_attn = torch.zeros(
1, batch_size, left_context_len[i], nonlin_attn_head_dim
).numpy()
cached_val1 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_val2 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
self.states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
self.states.append(processed_lens)
self.num_encoders = num_encoders
self.segment = T
self.offset = decode_chunk_len
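# Each encoder call consumes `segment` (= T) feature frames while the
# stream advances by `offset` (= decode_chunk_len) frames, so consecutive
# chunks overlap by T - decode_chunk_len frames.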
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
encoder_input = {"x": x.numpy()}
encoder_output = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
encoder_input[name] = tensors[0]
encoder_output.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
encoder_input[name] = tensors[1]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
encoder_input[name] = tensors[2]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
encoder_input[name] = tensors[3]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
encoder_input[name] = tensors[4]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
encoder_input[name] = tensors[5]
encoder_output.append(f"new_{name}")
for i in range(len(self.states[:-2]) // 6):
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
name = "embed_states"
embed_states = self.states[-2]
encoder_input[name] = embed_states
encoder_output.append(f"new_{name}")
# (batch_size,)
name = "processed_lens"
processed_lens = self.states[-1]
encoder_input[name] = processed_lens
encoder_output.append(f"new_{name}")
return encoder_input, encoder_output
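# Each six-tensor group is indexed by its overall layer position across
# all stacks, producing inputs like "cached_key_0", "cached_nonlin_attn_0",
# ..., followed by "embed_states" and "processed_lens"; every input has a
# matching "new_*" output.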
def _update_states(self, states: List[np.ndarray]):
self.states = states
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
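For example, T = 103 gives T' = ((103 - 7) // 2 + 1) // 2 = 24.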
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> Tuple[List[int], torch.Tensor]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return the decoded results so far and the decoder output, to be reused for the next chunk.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("▁", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/onnx_pretrained.py

View File

@ -26,6 +26,18 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function works around the above error when exporting
# models to ONNX via torch.jit.trace().
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
if not torch.jit.is_tracing():
return torch.logaddexp(x, y)
else:
return (x.exp() + y.exp()).log()
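# Caution: unlike torch.logaddexp, this fallback can overflow for large
# inputs (x.exp() is unbounded); it is only reached while tracing,
# i.e. during ONNX export via torch.jit.trace().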
class PiecewiseLinear(object): class PiecewiseLinear(object):
""" """
Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
def __float__(self): def __float__(self):
batch_count = self.batch_count batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting(): if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return float(self.default) return float(self.default)
else: else:
ans = self.schedule(self.batch_count) ans = self.schedule(self.batch_count)
@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
def softmax(x: Tensor, dim: int): def softmax(x: Tensor, dim: int):
if not x.requires_grad or torch.jit.is_scripting(): if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
return x.softmax(dim=dim) return x.softmax(dim=dim)
return SoftmaxFunction.apply(x, dim) return SoftmaxFunction.apply(x, dim)
@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
self.alpha = alpha self.alpha = alpha
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return x return x
return scale_grad(x, self.alpha) return scale_grad(x, self.alpha)
@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
def _no_op(x: Tensor) -> Tensor: def _no_op(x: Tensor) -> Tensor:
if (torch.jit.is_scripting()): if torch.jit.is_scripting() or torch.jit.is_tracing():
return x return x
else: else:
# a no-op function that will have a node in the autograd graph, # a no-op function that will have a node in the autograd graph,
@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)), """Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1). that we approximate closely with x * sigmoid(x-1).
""" """
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0) return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x) return DoubleSwishFunction.apply(x)
@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-L activation. """Return Swoosh-L activation.
""" """
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
if not x.requires_grad: if not x.requires_grad:
return k2.swoosh_l_forward(x) return k2.swoosh_l_forward(x)
else: else:
@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-R activation. """Return Swoosh-R activation.
""" """
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
if not x.requires_grad: if not x.requires_grad:
return k2.swoosh_r_forward(x) return k2.swoosh_r_forward(x)
else: else:

View File

@ -27,6 +27,7 @@ from typing import List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling import Balancer, Dropout3, ScaleGrad, Whiten from scaling import Balancer, Dropout3, ScaleGrad, Whiten
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
model: nn.Module, model: nn.Module,
inplace: bool = False, inplace: bool = False,
is_pnnx: bool = False, is_pnnx: bool = False,
is_onnx: bool = False,
): ):
""" """
Args: Args:
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
If False, the input model is copied and we modify the copied version. If False, the input model is copied and we modify the copied version.
is_pnnx: is_pnnx:
True if we are going to export the model for PNNX. True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return: Return:
Return a model without scaled layers. Return a model without scaled layers.
""" """
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
for name, m in model.named_modules(): for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity() d[name] = nn.Identity()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# to replace torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items(): for k, v in d.items():
if "." in k: if "." in k:

View File

@ -81,7 +81,7 @@ def get_parser():
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for decoding. help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0. Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""", You can specify --avg to use more checkpoints for model averaging.""",
) )

View File

@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
) )
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x) return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate) layerdrop_rate = float(self.layerdrop_rate)
@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
x = self.out_norm(x) x = self.out_norm(x)
x = self.dropout(x) x = self.dropout(x)
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2 x_lens = (x_lens - 7) // 2
else: else:
with warnings.catch_warnings(): with warnings.catch_warnings():

View File

@ -62,20 +62,20 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer2
from scaling import ScheduledFloat
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from subsampling import Conv2dSubsampling
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
from icefall import diagnostics from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import load_checkpoint, remove_checkpoints
@ -84,40 +84,38 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx, save_checkpoint_with_global_batch_idx,
update_averaged_model, update_averaged_model,
) )
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
get_parameter_groups_with_lrs
) )
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def get_adjusted_batch_count( def get_adjusted_batch_count(params: AttributeDict) -> float:
params: AttributeDict) -> float:
# returns the number of batches we would have used so far if we had used the reference # returns the number of batches we would have used so far if we had used the reference
# duration. This is for purposes of set_batch_count(). # duration. This is for purposes of set_batch_count().
return (params.batch_idx_train * (params.max_duration * params.world_size) / return (
params.ref_duration) params.batch_idx_train
* (params.max_duration * params.world_size)
/ params.ref_duration
)
def set_batch_count( def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
model: Union[nn.Module, DDP], batch_count: float
) -> None:
if isinstance(model, DDP): if isinstance(model, DDP):
# get underlying nn.Module # get underlying nn.Module
model = model.module model = model.module
for name, module in model.named_modules(): for name, module in model.named_modules():
if hasattr(module, 'batch_count'): if hasattr(module, "batch_count"):
module.batch_count = batch_count module.batch_count = batch_count
if hasattr(module, 'name'): if hasattr(module, "name"):
module.name = name module.name = name
@ -154,35 +152,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dim", "--encoder-dim",
type=str, type=str,
default="192,256,384,512,384,256", default="192,256,384,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list." help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
) )
parser.add_argument( parser.add_argument(
"--query-head-dim", "--query-head-dim",
type=str, type=str,
default="32", default="32",
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list." help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
) )
parser.add_argument( parser.add_argument(
"--value-head-dim", "--value-head-dim",
type=str, type=str,
default="12", default="12",
help="Value dimension per head in encoder stacks: a single int or comma-separated list." help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
) )
parser.add_argument( parser.add_argument(
"--pos-head-dim", "--pos-head-dim",
type=str, type=str,
default="4", default="4",
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list." help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
) )
parser.add_argument( parser.add_argument(
"--pos-dim", "--pos-dim",
type=int, type=int,
default="48", default="48",
help="Positional-encoding embedding dimension" help="Positional-encoding embedding dimension",
) )
parser.add_argument( parser.add_argument(
@ -190,7 +188,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str, type=str,
default="192,192,256,256,256,192", default="192,192,256,256,256,192",
help="Unmasked dimensions in the encoders, relates to augmentation during training. " help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim." "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
) )
parser.add_argument( parser.add_argument(
@ -230,7 +228,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
type=str, type=str,
default="16,32,64,-1", default="16,32,64,-1",
help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
" Must be just -1 if --causal=False" " Must be just -1 if --causal=False",
) )
parser.add_argument( parser.add_argument(
@ -239,7 +237,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="64,128,256,-1", default="64,128,256,-1",
help="Maximum left-contexts for causal training, measured in frames which will " help="Maximum left-contexts for causal training, measured in frames which will "
"be converted to a number of chunks. If splitting into chunks, " "be converted to a number of chunks. If splitting into chunks, "
"chunk left-context frames will be chosen randomly from this list; else not relevant." "chunk left-context frames will be chosen randomly from this list; else not relevant.",
) )
@ -313,10 +311,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--base-lr", "--base-lr", type=float, default=0.045, help="The base learning rate."
type=float,
default=0.045,
help="The base learning rate."
) )
parser.add_argument( parser.add_argument(
@ -340,15 +335,14 @@ def get_parser():
type=float, type=float,
default=600, default=600,
help="Reference batch duration for purposes of adjusting batch counts for setting various " help="Reference batch duration for purposes of adjusting batch counts for setting various "
"schedules inside the model" "schedules inside the model",
) )
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -371,8 +365,7 @@ def get_parser():
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)" "part.",
"part.",
) )
parser.add_argument( parser.add_argument(
@ -415,7 +408,7 @@ def get_parser():
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0. end of each epoch where `xxx` is the epoch number counting from 1.
""", """,
) )
@ -522,7 +515,7 @@ def get_params() -> AttributeDict:
def _to_int_tuple(s: str): def _to_int_tuple(s: str):
return tuple(map(int, s.split(','))) return tuple(map(int, s.split(",")))
def get_encoder_embed(params: AttributeDict) -> nn.Module: def get_encoder_embed(params: AttributeDict) -> nn.Module:
@ -537,7 +530,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
encoder_embed = Conv2dSubsampling( encoder_embed = Conv2dSubsampling(
in_channels=params.feature_dim, in_channels=params.feature_dim,
out_channels=_to_int_tuple(params.encoder_dim)[0], out_channels=_to_int_tuple(params.encoder_dim)[0],
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
) )
return encoder_embed return encoder_embed
@ -596,7 +589,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
joiner=joiner, joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(','))), encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim, joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
@ -667,9 +660,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -748,11 +738,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = ( device = model.device if isinstance(model, DDP) else next(model.parameters()).device
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -782,27 +768,24 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step. # to params.simple_loss scale by warm_step.
simple_loss_scale = ( simple_loss_scale = (
s if batch_idx_train >= warm_step s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
) )
pruned_loss_scale = ( pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step 1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step) else 0.1 + 0.9 * (batch_idx_train / warm_step)
) )
loss = ( loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss
)
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -895,12 +878,11 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
saved_bad_model = False saved_bad_model = False
def save_bad_model(suffix: str = ""): def save_bad_model(suffix: str = ""):
save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model, model=model,
model_avg=model_avg, model_avg=model_avg,
params=params, params=params,
@ -908,14 +890,12 @@ def train_one_epoch(
scheduler=scheduler, scheduler=scheduler,
sampler=train_dl.sampler, sampler=train_dl.sampler,
scaler=scaler, scaler=scaler,
rank=0) rank=0,
)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0: if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params)) set_batch_count(model, get_adjusted_batch_count(params))
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -963,7 +943,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -976,7 +955,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,
@ -998,7 +976,9 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())
@ -1008,8 +988,8 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, " f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], " f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, " + f"lr: {cur_lr:.2e}, "
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
) )
if tb_writer is not None: if tb_writer is not None:
@ -1020,9 +1000,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if params.use_fp16: if params.use_fp16:
tb_writer.add_scalar( tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train "train/grad_scale", cur_grad_scale, params.batch_idx_train
@ -1039,7 +1017,9 @@ def train_one_epoch(
) )
model.train() model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None: if tb_writer is not None:
valid_info.write_summary( valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train tb_writer, "train/valid_", params.batch_idx_train
@ -1113,12 +1093,10 @@ def run(rank, world_size, args):
model.to(device) model.to(device)
if world_size > 1: if world_size > 1:
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank], model = DDP(model, device_ids=[rank], find_unused_parameters=True)
find_unused_parameters=True)
optimizer = ScaledAdam( optimizer = ScaledAdam(
get_parameter_groups_with_lrs( get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect lr=params.base_lr, # should have no effect
clipping_scale=2.0, clipping_scale=2.0,
) )
@ -1139,7 +1117,7 @@ def run(rank, world_size, args):
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( opts = diagnostics.TensorDiagnosticOptions(
2 ** 22 2**22
) # allow 4 megabytes per sub-module ) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts) diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -1163,9 +1141,9 @@ def run(rank, world_size, args):
# an utterance duration distribution for your dataset to select # an utterance duration distribution for your dataset to select
# the threshold # the threshold
if c.duration < 1.0 or c.duration > 20.0: if c.duration < 1.0 or c.duration > 20.0:
logging.warning( # logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
) # )
return False return False
# In pruned RNN-T, we require that T >= S # In pruned RNN-T, we require that T >= S
@ -1216,8 +1194,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1338,7 +1315,9 @@ def scan_pessimistic_batches_for_oom(
) )
display_and_save_batch(batch, params=params, sp=sp) display_and_save_batch(batch, params=params, sp=sp)
raise raise
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def main(): def main():
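The reflowed AMP lines above keep the same logic: GradScaler is created with init_scale=1.0, its state travels with checkpoints, and training aborts once the loss scale collapses below 1.0e-05 (a symptom of persistent inf/nan gradients). A minimal self-contained sketch of that pattern; the toy model, optimizer, and use_fp16 flag here are stand-ins, not the script's own objects:

import torch
from torch.cuda.amp import GradScaler, autocast

use_fp16 = torch.cuda.is_available()  # fp16 training only makes sense on CUDA
device = "cuda" if use_fp16 else "cpu"
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_fp16, init_scale=1.0)

for step in range(3):
    x = torch.randn(4, 8, device=device)
    with autocast(enabled=use_fp16):
        loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scaling is a no-op when disabled
    scaler.step(optimizer)
    scaler.update()
    if use_fp16:
        cur_grad_scale = scaler._scale.item()  # same private attr the script logs
        if cur_grad_scale < 1.0e-05:
            raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")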

View File

@ -133,6 +133,7 @@ class Zipformer2(EncoderInterface):
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
num_encoder_layers = _to_tuple(num_encoder_layers) num_encoder_layers = _to_tuple(num_encoder_layers)
self.num_encoder_layers = num_encoder_layers
self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
pos_head_dim = _to_tuple(pos_head_dim) pos_head_dim = _to_tuple(pos_head_dim)
@ -258,7 +259,7 @@ class Zipformer2(EncoderInterface):
if not self.causal: if not self.causal:
return -1, -1 return -1, -1
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.chunk_size) == 1, self.chunk_size assert len(self.chunk_size) == 1, self.chunk_size
chunk_size = self.chunk_size[0] chunk_size = self.chunk_size[0]
else: else:
@ -267,7 +268,7 @@ class Zipformer2(EncoderInterface):
if chunk_size == -1: if chunk_size == -1:
left_context_chunks = -1 left_context_chunks = -1
else: else:
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.left_context_frames) == 1, self.left_context_frames assert len(self.left_context_frames) == 1, self.left_context_frames
left_context_frames = self.left_context_frames[0] left_context_frames = self.left_context_frames[0]
else: else:
@ -301,14 +302,14 @@ class Zipformer2(EncoderInterface):
of frames in `embeddings` before padding. of frames in `embeddings` before padding.
""" """
outputs = [] outputs = []
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
feature_masks = [1.0] * len(self.encoder_dim) feature_masks = [1.0] * len(self.encoder_dim)
else: else:
feature_masks = self.get_feature_masks(x) feature_masks = self.get_feature_masks(x)
chunk_size, left_context_chunks = self.get_chunk_info() chunk_size, left_context_chunks = self.get_chunk_info()
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
# Not support exporting a model for simulating streaming decoding # Not support exporting a model for simulating streaming decoding
attn_mask = None attn_mask = None
else: else:
@ -334,7 +335,7 @@ class Zipformer2(EncoderInterface):
x = self.downsample_output(x) x = self.downsample_output(x)
# class Downsample has this rounding behavior.. # class Downsample has this rounding behavior..
assert self.output_downsampling_factor == 2 assert self.output_downsampling_factor == 2
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
lengths = (x_lens + 1) // 2 lengths = (x_lens + 1) // 2
else: else:
with warnings.catch_warnings(): with warnings.catch_warnings():
@ -372,7 +373,7 @@ class Zipformer2(EncoderInterface):
# t is frame index, shape (seq_len,) # t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device) t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
# c is chunk index for each frame, shape (seq_len,) # c is chunk index for each frame, shape (seq_len,)
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
c = t // chunk_size c = t // chunk_size
else: else:
with warnings.catch_warnings(): with warnings.catch_warnings():
@ -650,7 +651,7 @@ class Zipformer2EncoderLayer(nn.Module):
) )
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return None return None
batch_size = x.shape[1] batch_size = x.shape[1]
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@ -695,7 +696,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_orig = src src_orig = src
# dropout rate for non-feedforward submodules # dropout rate for non-feedforward submodules
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
attention_skip_rate = 0.0 attention_skip_rate = 0.0
else: else:
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
@ -713,7 +714,7 @@ class Zipformer2EncoderLayer(nn.Module):
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
selected_attn_weights = attn_weights[0:1] selected_attn_weights = attn_weights[0:1]
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
pass pass
elif not self.training and random.random() < float(self.const_attention_rate): elif not self.training and random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to # Make attention weights constant. The intention is to
@ -732,7 +733,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0 conv_skip_rate = 0.0
else: else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -740,7 +741,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask), src_key_padding_mask=src_key_padding_mask),
conv_skip_rate) conv_skip_rate)
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
ff2_skip_rate = 0.0 ff2_skip_rate = 0.0
else: else:
ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
@ -754,7 +755,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0 conv_skip_rate = 0.0
else: else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -762,7 +763,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask), src_key_padding_mask=src_key_padding_mask),
conv_skip_rate) conv_skip_rate)
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
ff3_skip_rate = 0.0 ff3_skip_rate = 0.0
else: else:
ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
@ -968,7 +969,7 @@ class Zipformer2Encoder(nn.Module):
pos_emb = self.encoder_pos(src) pos_emb = self.encoder_pos(src)
output = src output = src
if not torch.jit.is_scripting(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask output = output * feature_mask
for i, mod in enumerate(self.layers): for i, mod in enumerate(self.layers):
@ -980,7 +981,7 @@ class Zipformer2Encoder(nn.Module):
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
) )
if not torch.jit.is_scripting(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask output = output * feature_mask
return output return output
@ -1073,7 +1074,7 @@ class BypassModule(nn.Module):
# or (batch_size, num_channels,). This is actually the # or (batch_size, num_channels,). This is actually the
# scale on the non-residual term, so 0 correponds to bypassing # scale on the non-residual term, so 0 correponds to bypassing
# this module. # this module.
if torch.jit.is_scripting() or not self.training: if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.bypass_scale return self.bypass_scale
else: else:
ans = limit_param_value(self.bypass_scale, ans = limit_param_value(self.bypass_scale,
@ -1229,7 +1230,6 @@ class SimpleDownsample(torch.nn.Module):
d_seq_len = (seq_len + ds - 1) // ds d_seq_len = (seq_len + ds - 1) // ds
# Pad to an exact multiple of self.downsample # Pad to an exact multiple of self.downsample
if seq_len != d_seq_len * ds:
# right-pad src, repeating the last element. # right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
@ -1322,10 +1322,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1: if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
@ -1524,7 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = torch.matmul(q, k) attn_scores = torch.matmul(q, k)
use_pos_scores = False use_pos_scores = False
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
# We can't put random.random() in the same line # We can't put random.random() in the same line
use_pos_scores = True use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate): elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@ -1542,6 +1538,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# the following .as_strided() expression converts the last axis of pos_scores from relative # the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or # to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be. # not, but let this code define which way round it is supposed to be.
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
(pos_scores.stride(0), (pos_scores.stride(0),
pos_scores.stride(1), pos_scores.stride(1),
@ -1551,7 +1557,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = attn_scores + pos_scores attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
pass pass
elif self.training and random.random() < 0.1: elif self.training and random.random() < 0.1:
# This is a harder way of limiting the attention scores to not be # This is a harder way of limiting the attention scores to not be
@ -1594,7 +1600,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# half-precision output for backprop purposes. # half-precision output for backprop purposes.
attn_weights = softmax(attn_scores, dim=-1) attn_weights = softmax(attn_scores, dim=-1)
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
pass pass
elif random.random() < 0.001 and not self.training: elif random.random() < 0.001 and not self.training:
self._print_attn_entropy(attn_weights) self._print_attn_entropy(attn_weights)
@ -1672,9 +1678,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.] # [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb) pos_scores = torch.matmul(p, pos_emb)
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(k_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
# the following .as_strided() expression converts the last axis of pos_scores from relative # the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or # to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be. # not, but let this code define which way round it is supposed to be.
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
(pos_scores.stride(0), (pos_scores.stride(0),
pos_scores.stride(1), pos_scores.stride(1),
@ -2136,7 +2153,7 @@ class ConvolutionModule(nn.Module):
if src_key_padding_mask is not None: if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
if not torch.jit.is_scripting() and chunk_size >= 0: if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
# Not support exporting a model for simulated streaming decoding # Not support exporting a model for simulated streaming decoding
assert self.causal, "Must initialize model with causal=True if you use chunk_size" assert self.causal, "Must initialize model with causal=True if you use chunk_size"
x = self.depthwise_conv(x, chunk_size=chunk_size) x = self.depthwise_conv(x, chunk_size=chunk_size)
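The torch.jit.is_tracing() branches added above swap the as_strided() relative-to-absolute position conversion for an explicit torch.gather, since tracing cannot record the stride arithmetic. A sketch checking that the two paths agree, assuming a contiguous pos_scores, the non-streaming case time1 == seq_len, and the full as_strided() call from zipformer.py (shown truncated in the hunk):

import torch

num_heads, batch_size, seq_len = 2, 3, 5
n = 2 * seq_len - 1  # number of relative offsets
pos_scores = torch.randn(num_heads, batch_size, seq_len, n)

# Original path: a strided view that shifts each time row by one position.
strided = pos_scores.as_strided(
    (num_heads, batch_size, seq_len, seq_len),
    (
        pos_scores.stride(0),
        pos_scores.stride(1),
        pos_scores.stride(2) - pos_scores.stride(3),
        pos_scores.stride(3),
    ),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# Tracing-friendly path: build the same shifted indexes and gather them.
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
gathered = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
gathered = gathered.reshape(num_heads, batch_size, seq_len, seq_len)

assert torch.equal(strided, gathered)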

View File

@ -503,9 +503,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -741,13 +738,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -797,7 +788,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -810,7 +800,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,

View File

@ -506,9 +506,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -748,15 +745,9 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -805,7 +796,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -818,7 +808,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,155 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file computes fbank features of the MuST-C dataset.
It looks for manifests in the directory "in_dir" and write
generated features to "out_dir".
"""
import argparse
import logging
from pathlib import Path
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
FeatureSet,
LilcomChunkyWriter,
load_manifest,
)
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--in-dir",
type=Path,
required=True,
help="Input manifest directory",
)
parser.add_argument(
"--out-dir",
type=Path,
required=True,
help="Output directory where generated fbank features are saved.",
)
parser.add_argument(
"--tgt-lang",
type=str,
required=True,
help="Target language, e.g., zh, de, fr.",
)
parser.add_argument(
"--num-jobs",
type=int,
default=1,
help="Number of jobs for computing features",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="""True to enable speed perturb with factors 0.9 and 1.1 on
the train subset. False (by default) to disable speed perturb.
""",
)
return parser.parse_args()
def compute_fbank_must_c(
in_dir: Path,
out_dir: Path,
tgt_lang: str,
num_jobs: int,
perturb_speed: bool,
):
out_dir.mkdir(parents=True, exist_ok=True)
extractor = Fbank(FbankConfig(num_mel_bins=80))
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
prefix = "must_c"
suffix = "jsonl.gz"
for p in parts:
logging.info(f"Processing {p}")
cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
if perturb_speed and p == "train":
cuts_path += "_sp"
cuts_path += ".jsonl.gz"
if Path(cuts_path).is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.jsonl.gz"
supervisions_filename = (
in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.jsonl.gz"
)
assert recordings_filename.is_file(), recordings_filename
assert supervisions_filename.is_file(), supervisions_filename
cut_set = CutSet.from_manifests(
recordings=load_manifest(recordings_filename),
supervisions=load_manifest(supervisions_filename),
)
if perturb_speed and p == "train":
logging.info("Speed perturbing for the train dataset")
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp"
else:
storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=storage_path,
num_jobs=num_jobs,
storage_type=LilcomChunkyWriter,
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():
args = get_args()
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info(vars(args))
assert args.in_dir.is_dir(), args.in_dir
compute_fbank_must_c(
in_dir=args.in_dir,
out_dir=args.out_dir,
tgt_lang=args.tgt_lang,
num_jobs=args.num_jobs,
perturb_speed=args.perturb_speed,
)
if __name__ == "__main__":
main()
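For reference, prepare.sh below (stage 5) invokes this script as:

./local/compute_fbank_must_c.py \
  --in-dir ./data/manifests/v1.0/ \
  --out-dir ./data/fbank/v1.0/ \
  --tgt-lang de \
  --num-jobs 10

and then once more with --perturb-speed 1 to generate the speed-perturbed train cuts.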

egs/must_c/ST/local/get_text.py Executable file
View File

@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file prints the text field of supervisions from cutset to the console
"""
import argparse
from lhotse import load_manifest_lazy
from pathlib import Path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Input manifest",
)
return parser.parse_args()
def main():
args = get_args()
assert args.manifest.is_file(), args.manifest
cutset = load_manifest_lazy(args.manifest)
for c in cutset:
for sup in c.supervisions:
print(sup.text)
if __name__ == "__main__":
main()
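prepare.sh below (stage 6) uses this script to produce the BPE training text, e.g.:

./local/get_text.py ./data/fbank/v1.0/must_c_feats_en-de_train.jsonl.gz > transcript_words.txt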

View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file generates words.txt from the given transcript file.
"""
import argparse
from pathlib import Path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"transcript",
type=Path,
help="Input transcript file",
)
return parser.parse_args()
def main():
args = get_args()
assert args.transcript.is_file(), args.transcript
word_set = set()
with open(args.transcript) as f:
for line in f:
words = line.strip().split()
for w in words:
word_set.add(w)
# Note: reserved* should be kept in sync with ./local/prepare_lang_bpe.py
reserved1 = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
reserved2 = ["#0", "<s>", "</s>"]
for w in reserved1 + reserved2:
assert w not in word_set, w
words = sorted(list(word_set))
words = reserved1 + words + reserved2
for i, w in enumerate(words):
print(w, i)
if __name__ == "__main__":
main()
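For example, a transcript containing only the two words "hallo welt" yields the following mapping (IDs are positional):

<eps> 0
!SIL 1
<SPOKEN_NOISE> 2
<UNK> 3
hallo 4
welt 5
#0 6
<s> 7
</s> 8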

View File

@ -0,0 +1,169 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import re
def normalize_punctuation(s: str, lang: str) -> str:
"""
This function implements
https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/normalize-punctuation.perl
Args:
s:
A string to be normalized.
lang:
The language to which `s` belongs
Returns:
Return a normalized string.
"""
# s/\r//g;
s = re.sub("\r", "", s)
# remove extra spaces
# s/\(/ \(/g;
s = re.sub("\(", " (", s) # add a space before (
# s/\)/\) /g; s/ +/ /g;
s = re.sub("\)", ") ", s) # add a space after )
s = re.sub(" +", " ", s) # convert multiple spaces to one
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
s = re.sub("\) ([\.\!\:\?\;\,])", r")\1", s)
# s/\( /\(/g;
s = re.sub("\( ", "(", s) # remove space after (
# s/ \)/\)/g;
s = re.sub(" \)", ")", s) # remove space before )
# s/(\d) \%/$1\%/g;
s = re.sub("(\d) \%", r"\1%", s) # remove space between a digit and %
# s/ :/:/g;
s = re.sub(" :", ":", s) # remove space before :
# s/ ;/;/g;
s = re.sub(" ;", ";", s) # remove space before ;
# normalize unicode punctuation
# s/\`/\'/g;
s = re.sub("`", "'", s) # replace ` with '
# s/\'\'/ \" /g;
s = re.sub("''", '"', s) # replace '' with "
# s/„/\"/g;
s = re.sub("", '"', s) # replace „ with "
# s/“/\"/g;
s = re.sub("", '"', s) # replace “ with "
# s/”/\"/g;
s = re.sub("", '"', s) # replace ” with "
# s//-/g;
s = re.sub("", "-", s) # replace with -
# s/—/ - /g; s/ +/ /g;
s = re.sub("", " - ", s)
s = re.sub(" +", " ", s) # convert multiple spaces to one
# s/´/\'/g;
s = re.sub("´", "'", s)
# s/([a-z])([a-z])/$1\'$2/gi;
s = re.sub("([a-z])([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
# s/([a-z])([a-z])/$1\'$2/gi;
s = re.sub("([a-z])([a-z])", r"\1'\2", s, flags=re.IGNORECASE)
# s//\'/g;
s = re.sub("", "'", s)
# s//\'/g;
s = re.sub("", "'", s)
# s//\"/g;
s = re.sub("", '"', s)
# s/''/\"/g;
s = re.sub("''", '"', s)
# s/´´/\"/g;
s = re.sub("´´", '"', s)
# s/…/.../g;
s = re.sub("", "...", s)
# French quotes
# s/ « / \"/g;
s = re.sub(" « ", ' "', s)
# s/« /\"/g;
s = re.sub("« ", '"', s)
# s/«/\"/g;
s = re.sub("«", '"', s)
# s/ » /\" /g;
s = re.sub(" » ", '" ', s)
# s/ »/\"/g;
s = re.sub(" »", '"', s)
# s/»/\"/g;
s = re.sub("»", '"', s)
# handle pseudo-spaces
# s/ \%/\%/g;
s = re.sub(" %", r"%", s)
# s/nº /nº /g;
s = re.sub("nº ", "nº ", s) # the Perl pattern contains a non-breaking pseudo-space
# s/ :/:/g;
s = re.sub(" :", ":", s)
# s/ ºC/ ºC/g;
s = re.sub(" ºC", " ºC", s)
# s/ cm/ cm/g;
s = re.sub(" cm", " cm", s)
# s/ \?/\?/g;
s = re.sub(" \?", "?", s) # remove space before ?
# s/ \!/\!/g;
s = re.sub(" \!", "!", s) # remove space before !
# s/ ;/;/g;
s = re.sub(" ;", ";", s)
# s/, /, /g; s/ +/ /g;
s = re.sub(", ", ", ", s)
s = re.sub(" +", " ", s)
if lang == "en":
# English "quotation," followed by comma, style
# s/\"([,\.]+)/$1\"/g;
s = re.sub('"([,\.]+)', r'\1"', s)
elif lang in ("cs", "cz"):
# Czech is confused
pass
else:
# German/Spanish/French "quotation", followed by comma, style
# s/,\"/\",/g;
s = re.sub(',"', '",', s)
# s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence
s = re.sub('(\.+)"(\s*[^<])', r'"\1\2', s)
if lang in ("de", "es", "cz", "cs", "fr"):
# s/(\d) (\d)/$1,$2/g;
s = re.sub("(\d) (\d)", r"\1,\2", s)
else:
# s/(\d) (\d)/$1.$2/g;
s = re.sub("(\d) (\d)", r"\1.\2", s)
return s
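A hand-traced example of the function above, combining the quote/dash normalization with the pseudo-space cleanup:

from normalize_punctuation import normalize_punctuation

s = normalize_punctuation('Er sagte: „Hallo“ — wirklich ?', lang="de")
assert s == 'Er sagte: "Hallo" - wirklich?', s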

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script normalizes transcripts from supervisions.
Usage:
./local/preprocess_must_c.py \
--manifest-dir ./data/manifests/v1.0/ \
--tgt-lang de
"""
import argparse
import logging
import re
from functools import partial
from pathlib import Path
from lhotse.recipes.utils import read_manifests_if_cached
from normalize_punctuation import normalize_punctuation
from remove_non_native_characters import remove_non_native_characters
from remove_punctuation import remove_punctuation
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-dir",
type=Path,
required=True,
help="Manifest directory",
)
parser.add_argument(
"--tgt-lang",
type=str,
required=True,
help="Target language, e.g., zh, de, fr.",
)
return parser.parse_args()
def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang)
remove_non_native_characters_lang = partial(
remove_non_native_characters, lang=tgt_lang
)
prefix = "must_c"
suffix = "jsonl.gz"
parts = ["dev", "tst-COMMON", "tst-HE", "train"]
for p in parts:
logging.info(f"Processing {p}")
name = f"en-{tgt_lang}_{p}"
# norm: normalization
# rm: remove punctuation
dst_name = manifest_dir / f"must_c_supervisions_{name}_norm_rm.jsonl.gz"
if dst_name.is_file():
logging.info(f"{dst_name} exists - skipping")
continue
manifests = read_manifests_if_cached(
dataset_parts=name,
output_dir=manifest_dir,
prefix=prefix,
suffix=suffix,
types=("supervisions",),
)
if name not in manifests:
raise RuntimeError(f"Processing {p} failed.")
supervisions = manifests[name]["supervisions"]
supervisions = supervisions.transform_text(normalize_punctuation_lang)
supervisions = supervisions.transform_text(remove_punctuation)
supervisions = supervisions.transform_text(lambda x: x.lower())
supervisions = supervisions.transform_text(remove_non_native_characters_lang)
supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x))
supervisions.to_file(dst_name)
def main():
args = get_args()
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info(vars(args))
assert args.manifest_dir.is_dir(), args.manifest_dir
preprocess_must_c(
manifest_dir=args.manifest_dir,
tgt_lang=args.tgt_lang,
)
if __name__ == "__main__":
main()
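To make the transform chain concrete, a German supervision passes through the five transform_text steps roughly like this (hand-traced):

'Er sagte: „Hallo“!'    # original text
'Er sagte: "Hallo"!'    # normalize_punctuation(lang="de")
'Er sagte Hallo'        # remove_punctuation
'er sagte hallo'        # lower()
'er sagte hallo'        # remove_non_native_characters + space squeezing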

View File

@ -0,0 +1,21 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import re
def remove_non_native_characters(s: str, lang: str):
if lang == "de":
# ä -> ae
# ö -> oe
# ü -> ue
# ß -> ss
s = re.sub("ä", "ae", s)
s = re.sub("ö", "oe", s)
s = re.sub("ü", "ue", s)
s = re.sub("ß", "ss", s)
# keep only a-z and spaces
# note: ' is removed
s = re.sub(r"[^a-z\s]", "", s)
return s
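Example behaviour (the input is expected to be lowercased already, as in preprocess_must_c.py where this runs after the lower() step):

from remove_non_native_characters import remove_non_native_characters

assert remove_non_native_characters("ich heiße xxx", lang="de") == "ich heisse xxx"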

View File

@ -0,0 +1,41 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import re
import string
def remove_punctuation(s: str) -> str:
"""
It implements https://github.com/espnet/espnet/blob/master/utils/remove_punctuation.pl
"""
# Remove punctuation except apostrophe
# s/<space>/spacemark/g; # for scoring
s = re.sub("<space>", "spacemark", s)
# s/'/apostrophe/g;
s = re.sub("'", "apostrophe", s)
# s/[[:punct:]]//g;
s = s.translate(str.maketrans("", "", string.punctuation))
# string punctuation returns the following string
# !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
# See
# https://stackoverflow.com/questions/265960/best-way-to-strip-punctuation-from-a-string
# s/apostrophe/'/g;
s = re.sub("apostrophe", "'", s)
# s/spacemark/<space>/g; # for scoring
s = re.sub("spacemark", "<space>", s)
# remove whitespace
# s/\s+/ /g;
s = re.sub("\s+", " ", s)
# s/^\s+//;
s = re.sub("^\s+", "", s)
# s/\s+$//;
s = re.sub("\s+$", "", s)
return s
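Example: the placeholder trick above protects apostrophes and the literal "<space>" marker while the rest of string.punctuation is dropped:

from remove_punctuation import remove_punctuation

assert remove_punctuation("it's <space> done!") == "it's <space> done"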

View File

@ -0,0 +1,197 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from normalize_punctuation import normalize_punctuation
def test_normalize_punctuation():
# s/\r//g;
s = "a\r\nb\r\n"
n = normalize_punctuation(s, lang="en")
assert "\r" not in n
assert len(s) - 2 == len(n), (len(s), len(n))
# s/\(/ \(/g;
s = "(ab (c"
n = normalize_punctuation(s, lang="en")
assert n == " (ab (c", n
# s/\)/\) /g;
s = "a)b c)"
n = normalize_punctuation(s, lang="en")
assert n == "a) b c) "
# s/ +/ /g;
s = " a b c d "
n = normalize_punctuation(s, lang="en")
assert n == " a b c d "
# s/\) ([\.\!\:\?\;\,])/\)$1/g;
for i in ".!:?;,":
s = f"a) {i}"
n = normalize_punctuation(s, lang="en")
assert n == f"a){i}"
# s/\( /\(/g;
s = "a( b"
n = normalize_punctuation(s, lang="en")
assert n == "a (b", n
# s/ \)/\)/g;
s = "ab ) a"
n = normalize_punctuation(s, lang="en")
assert n == "ab) a", n
# s/(\d) \%/$1\%/g;
s = "1 %a"
n = normalize_punctuation(s, lang="en")
assert n == "1%a", n
# s/ :/:/g;
s = "a :"
n = normalize_punctuation(s, lang="en")
assert n == "a:", n
# s/ ;/;/g;
s = "a ;"
n = normalize_punctuation(s, lang="en")
assert n == "a;", n
# s/\`/\'/g;
s = "`a`"
n = normalize_punctuation(s, lang="en")
assert n == "'a'", n
# s/\'\'/ \" /g;
s = "''a''"
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/„/\"/g;
s = '„a"'
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/“/\"/g;
s = "“a„"
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/”/\"/g;
s = "“a”"
n = normalize_punctuation(s, lang="en")
assert n == '"a"', n
# s/–/-/g;
s = "a–b"
n = normalize_punctuation(s, lang="en")
assert n == "a-b", n
# s/—/ - /g; s/ +/ /g;
s = "a—b"
n = normalize_punctuation(s, lang="en")
assert n == "a - b", n
# s/´/\'/g;
s = "a´b"
n = normalize_punctuation(s, lang="en")
assert n == "a'b", n
# s/([a-z])‘([a-z])/$1\'$2/gi;
# s/([a-z])’([a-z])/$1\'$2/gi;
for i in "‘’":
s = f"a{i}B"
n = normalize_punctuation(s, lang="en")
assert n == "a'B", n
s = f"A{i}B"
n = normalize_punctuation(s, lang="en")
assert n == "A'B", n
s = f"A{i}b"
n = normalize_punctuation(s, lang="en")
assert n == "A'b", n
# s/‘/\'/g;
# s/‚/\'/g;
for i in "‘‚":
s = f"a{i}b"
n = normalize_punctuation(s, lang="en")
assert n == "a'b", n
# s//\"/g;
s = ""
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/''/\"/g;
s = "''"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/´´/\"/g;
s = "´´"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/…/.../g;
s = "…"
n = normalize_punctuation(s, lang="en")
assert n == "...", n
# s/ « / \"/g;
s = "a « b"
n = normalize_punctuation(s, lang="en")
assert n == 'a "b', n
# s/« /\"/g;
s = "a « b"
n = normalize_punctuation(s, lang="en")
assert n == 'a "b', n
# s/«/\"/g;
s = "a«b"
n = normalize_punctuation(s, lang="en")
assert n == 'a"b', n
# s/ » /\" /g;
s = " » "
n = normalize_punctuation(s, lang="en")
assert n == '" ', n
# s/ »/\"/g;
s = " »"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/»/\"/g;
s = "»"
n = normalize_punctuation(s, lang="en")
assert n == '"', n
# s/ \%/\%/g;
s = " %"
n = normalize_punctuation(s, lang="en")
assert n == "%", n
# s/ :/:/g;
s = " :"
n = normalize_punctuation(s, lang="en")
assert n == ":", n
# s/(\d) (\d)/$1.$2/g;
s = "2 3"
n = normalize_punctuation(s, lang="en")
assert n == "2.3", n
# s/(\d) (\d)/$1,$2/g;
s = "2 3"
n = normalize_punctuation(s, lang="de")
assert n == "2,3", n
def main():
test_normalize_punctuation()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from remove_non_native_characters import remove_non_native_characters
def test_remove_non_native_characters():
s = "Ich heiße xxx好的01 fangjun".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "ich heisse xxx fangjun", n
s = "äÄ".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "aeae", n
s = "öÖ".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "oeoe", n
s = "üÜ".lower()
n = remove_non_native_characters(s, lang="de")
assert n == "ueue", n
if __name__ == "__main__":
test_remove_non_native_characters()

View File

@ -0,0 +1,17 @@
#!/usr/bin/env python3
from remove_punctuation import remove_punctuation
def test_remove_punctuation():
s = "a,b'c!#"
n = remove_punctuation(s)
assert n == "ab'c", n
s = " ab " # remove leading and trailing spaces
n = remove_punctuation(s)
assert n == "ab", n
if __name__ == "__main__":
test_remove_punctuation()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py

egs/must_c/ST/prepare.sh Executable file
View File

@ -0,0 +1,173 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=10
stage=0
stop_stage=100
version=v1.0
tgt_lang=de
dl_dir=$PWD/download
must_c_dir=$dl_dir/must-c/$version/en-$tgt_lang/data
# We assume dl_dir (download dir) contains the following
# directories and files.
# - $dl_dir/must-c/$version/en-$tgt_lang/data/{dev,train,tst-COMMON,tst-HE}
#
# Please go to https://ict.fbk.eu/must-c-releases/
# to download and untar the dataset if you have not already done this.
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate
# data/lang_bpe_${tgt_lang}_xxx
# data/lang_bpe_${tgt_lang}_yyy
# if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
500
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ ! -d $must_c_dir ]; then
log "$must_c_dir does not exist"
exit 1
fi
for d in dev train tst-COMMON tst-HE; do
if [ ! -d $must_c_dir/$d ]; then
log "$must_c_dir/$d does not exist!"
exit 1
fi
done
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download musan"
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to $dl_dir/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare must-c $version manifest for target language $tgt_lang"
mkdir -p data/manifests/$version
if [ ! -e data/manifests/$version/.${tgt_lang}.manifests.done ]; then
lhotse prepare must-c \
-j $nj \
--tgt-lang $tgt_lang \
$dl_dir/must-c/$version/ \
data/manifests/$version/
touch data/manifests/$version/.${tgt_lang}.manifests.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Text normalization for $version with target language $tgt_lang"
if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then
./local/preprocess_must_c.py \
--manifest-dir ./data/manifests/$version/ \
--tgt-lang $tgt_lang
touch ./data/manifests/$version/.$tgt_lang.norm.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for $version with target language $tgt_lang"
mkdir -p data/fbank/$version/
if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then
./local/compute_fbank_must_c.py \
--in-dir ./data/manifests/$version/ \
--out-dir ./data/fbank/$version/ \
--tgt-lang $tgt_lang \
--num-jobs $nj
./local/compute_fbank_must_c.py \
--in-dir ./data/manifests/$version/ \
--out-dir ./data/fbank/$version/ \
--tgt-lang $tgt_lang \
--num-jobs $nj \
--perturb-speed 1
touch data/fbank/$version/.$tgt_lang.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang for $version with target language $tgt_lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/
mkdir -p $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/words.txt ]; then
./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt
fi
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
done
fi

egs/must_c/ST/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,154 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from datetime import datetime
from pathlib import Path
import torch
from lhotse import (
CutSet,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
set_audio_duration_mismatch_tolerance,
set_caching_enabled,
)
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-workers",
type=int,
default=20,
help="Number of dataloading workers used for reading the audio.",
)
parser.add_argument(
"--batch-duration",
type=float,
default=600.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--num-splits",
type=int,
required=True,
help="The number of splits of the train subset",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Process pieces starting from this number (inclusive).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop processing pieces until this number (exclusive).",
)
return parser.parse_args()
def compute_fbank_peoples_speech_splits(args):
subsets = ("dirty", "dirty_sa", "clean", "clean_sa")
num_splits = args.num_splits
output_dir = f"data/fbank/peoples_speech_train_split"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
num_digits = 8
start = args.start
stop = args.stop
if stop < start:
stop = num_splits
stop = min(stop, num_splits)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
set_caching_enabled(False)
for partition in subsets:
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
logging.info(f"Processing {partition}: {idx}")
cuts_path = output_dir / f"peoples_speech_cuts_{partition}.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = (
output_dir / f"peoples_speech_cuts_{partition}_raw.{idx}.jsonl.gz"
)
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Splitting cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/peoples_speech_feats_{partition}_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
compute_fbank_peoples_speech_splits(args)
if __name__ == "__main__":
main()
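As invoked from prepare.sh below (stage 6):

./local/compute_fbank_peoples_speech_splits.py \
  --num-workers 32 \
  --batch-duration 600 \
  --start 0 \
  --num-splits 2000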

View File

@ -0,0 +1,93 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the People's Speech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from filter_cuts import filter_cuts
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_peoples_speech_valid_test():
src_dir = Path(f"data/manifests")
output_dir = Path(f"data/fbank")
num_workers = 42
batch_duration = 600
subsets = ("validation", "test")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
for partition in subsets:
cuts_path = output_dir / f"peoples_speech_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{partition} already exists - skipping.")
continue
raw_cuts_path = output_dir / f"peoples_speech_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/peoples_speech_feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_peoples_speech_valid_test()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/filter_cuts.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from pathlib import Path
from typing import Optional
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
return parser.parse_args()
def normalize_text(utt: str) -> str:
utt = re.sub(r"[{0}]+".format("-"), " ", utt)
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
def preprocess_peoples_speech(dataset: Optional[str] = None):
src_dir = Path(f"data/manifests")
output_dir = Path(f"data/fbank")
output_dir.mkdir(exist_ok=True)
if dataset is None:
dataset_parts = (
"validation",
"test",
"dirty",
"dirty_sa",
"clean",
"clean_sa",
)
else:
dataset_parts = dataset.split(" ", -1)
logging.info("Loading manifest, it may takes 8 minutes")
prefix = f"peoples_speech"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
suffix=suffix,
prefix=prefix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
logging.info(f"Normalizing text in {partition}")
i = 0
for sup in m["supervisions"]:
text = str(sup.text)
orig_text = text
sup.text = normalize_text(sup.text)
text = str(sup.text)
if i < 10 and len(orig_text) != len(text):
logging.info(
f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
)
i += 1
# Create long-recording cut manifests.
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
).resample(16000)
# Run data augmentation that needs to be done in the
# time domain.
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
preprocess_peoples_speech(dataset=args.dataset)
logging.info("Done")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py

egs/peoples_speech/ASR/prepare.sh Executable file
View File

@ -0,0 +1,247 @@
#!/usr/bin/env bash
set -eou pipefail
nj=32
stage=-1
stop_stage=100
# Split data/set to a number of pieces
# This is to avoid OOM during feature extraction.
num_per_split=4000
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/peoples_speech
# This directory contains the following files downloaded from
# https://huggingface.co/datasets/MLCommons/peoples_speech
#
# - test
# - train
# - validation
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
500
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/peoples_speech,
# you can create a symlink
#
# ln -sfv /path/to/peoples_speech $dl_dir/peoples_speech
#
if [ ! -d $dl_dir/peoples_speech/train ]; then
git lfs install
git clone https://huggingface.co/datasets/MLCommons/peoples_speech
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare People's Speech manifest"
# We assume that you have downloaded the People's Speech corpus
# to $dl_dir/peoples_speech
mkdir -p data/manifests
if [ ! -e data/manifests/.peoples_speech.done ]; then
lhotse prepare peoples-speech -j $nj $dl_dir/peoples_speech data/manifests
touch data/manifests/.peoples_speech.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Preprocess People's Speech manifest"
mkdir -p data/fbank
if [ ! -e data/fbank/.preprocess_complete ]; then
./local/preprocess_peoples_speech.py
touch data/fbank/.preprocess_complete
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for valid and test subsets of People's Speech"
if [ ! -e data/fbank/.peoples_speech_valid_test.done ]; then
./local/compute_fbank_peoples_speech_valid_test.py
touch data/fbank/.peoples_speech_valid_test.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Split train subset into pieces"
split_dir=data/fbank/peoples_speech_train_split
if [ ! -e $split_dir/.peoples_speech_dirty_split.done ]; then
lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz $split_dir $num_per_split
touch $split_dir/.peoples_speech_dirty_split.done
fi
if [ ! -e $split_dir/.peoples_speech_dirty_sa_split.done ]; then
lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz $split_dir $num_per_split
touch $split_dir/.peoples_speech_dirty_sa_split.done
fi
if [ ! -e $split_dir/.peoples_speech_clean_split.done ]; then
lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz $split_dir $num_per_split
touch $split_dir/.peoples_speech_clean_split.done
fi
if [ ! -e $split_dir/.peoples_speech_clean_sa_split.done ]; then
lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz $split_dir $num_per_split
touch $split_dir/.peoples_speech_clean_sa_split.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Compute features for train subset of People's Speech"
if [ ! -e data/fbank/.peoples_speech_train.done ]; then
./local/compute_fbank_peoples_speech_splits.py \
--num-workers $nj \
--batch-duration 600 \
--start 0 \
--num-splits 2000
touch data/fbank/.peoples_speech_train.done
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
files=(
data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz
data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz
data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz
data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz
)
gunzip -c "${files[@]}" | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
# Normalize whitespace: convert tabs to spaces, then collapse runs of
# spaces (a more robust JSON-aware alternative is sketched after this script)
sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
sed -i 's/  */ /g' $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/words.txt ]; then
cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' > $lang_dir/words.txt
(echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
cat - $lang_dir/words.txt | sort | uniq | awk '
BEGIN {
print "<eps> 0";
}
{
if ($1 == "<s>") {
print "<s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
if ($1 == "</s>") {
print "</s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
}' > $lang_dir/words || exit 1;
mv $lang_dir/words $lang_dir/words.txt
fi
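# The resulting words.txt maps every symbol to a contiguous integer id,
# for example (ids after <eps> follow the sorted order of the input):
#   <eps> 0
#   !SIL 1
#   <SPOKEN_NOISE> 2
#   <UNK> 3
#   ...
#   #0 N+1
#   <s> N+2
#   </s> N+3
# where N is the number of entries before the three specials at the end.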
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
log "Validating $lang_dir/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model
fi
if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi
if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
done
fi
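Note on Stage 8: the awk extraction above keys on the 30th quote-delimited
field of each serialized cut, which silently breaks if the JSON layout ever
changes. A more robust alternative is sketched below, under the assumption
that each manifest line is a lhotse cut carrying a "supervisions" list with
"text" fields; the helper name is hypothetical:

import gzip
import json

def extract_transcripts(manifest_paths, out_path):
    """Collect supervision texts from gzipped lhotse cut manifests."""
    with open(out_path, "w", encoding="utf-8") as out:
        for path in manifest_paths:
            with gzip.open(path, "rt", encoding="utf-8") as f:
                for line in f:
                    cut = json.loads(line)
                    for sup in cut.get("supervisions", []):
                        text = sup.get("text")
                        if text:
                            # Collapse all whitespace runs to single spaces.
                            out.write(" ".join(text.split()) + "\n")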

View File

@@ -0,0 +1 @@
+../../../icefall/shared/

View File

@ -443,9 +443,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -648,13 +645,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -685,7 +676,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -697,7 +687,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,
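The hunks above (and the identical removals in the two files that follow)
drop the manual cur_batch_idx bookkeeping that fast-forwarded the dataloader
after a mid-epoch restart. The likely rationale: lhotse samplers can
checkpoint their own position, which makes the batch-skipping replay loop
redundant. A minimal sketch of that mechanism, assuming train_dl is backed
by a lhotse CutSampler (e.g. DynamicBucketingSampler):

# Save alongside the model checkpoint:
sampler_state = train_dl.sampler.state_dict()

# ... after reloading the checkpoint in a new run ...
# Restore before iterating; the sampler continues from the batch after the
# last one it yielded, so no batches need to be skipped by hand:
train_dl.sampler.load_state_dict(sampler_state)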

View File

@ -513,9 +513,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -726,13 +723,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
# print(batch["supervisions"]) # print(batch["supervisions"])
@ -775,7 +766,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -788,7 +778,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,

View File

@ -566,9 +566,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params: if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"] params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -798,13 +795,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -851,7 +842,6 @@ def train_one_epoch(
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
): ):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx( save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
@ -864,7 +854,6 @@ def train_one_epoch(
scaler=scaler, scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir,
topk=params.keep_last_k, topk=params.keep_last_k,

View File

@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=300, default=30,
help="The number of buckets for the DynamicBucketingSampler" help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).", "(you might want to increase it for larger datasets).",
) )
@ -364,7 +364,7 @@ class WenetSpeechAsrDataModule:
return valid_dl return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats if self.args.on_the_fly_feats
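For context on the first hunk: num_buckets controls how finely
DynamicBucketingSampler groups cuts by duration before batching, and 30
presumably aligns the default with the one used elsewhere in icefall. A
minimal sketch of the sampler's role (file name and max_duration are
illustrative):

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

# Cuts of similar duration land in the same bucket, so batches waste little
# compute on padding; each iteration yields one batch of cuts whose total
# duration stays under max_duration (in seconds).
cuts = CutSet.from_jsonl_lazy("data/fbank/cuts_train.jsonl.gz")
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=200.0,
    num_buckets=30,
    shuffle=True,
)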

View File

@@ -112,8 +112,10 @@ When training with the L subset, the streaming usage:

 import argparse
+import glob
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

@@ -133,7 +135,8 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

@@ -307,6 +310,26 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
+    )
+
     parser.add_argument(
         "--use-shallow-fusion",
         type=str2bool,
@@ -362,6 +385,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     ngram_lm: Optional[NgramLm] = None,
     ngram_lm_scale: float = 1.0,
     LM: Optional[LmScorer] = None,

@@ -402,14 +426,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

+    if params.simulate_streaming:
         feature_lens += params.left_context
         feature = torch.nn.functional.pad(
             feature,
             pad=(0, 0, 0, params.left_context),
             value=LOG_EPS,
         )
-
-    if params.simulate_streaming:
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,

@@ -448,6 +471,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
+            context_graph=context_graph,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
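Conceptually, the context_graph threaded into modified_beam_search above
lets each hypothesis track how far it has matched into any biasing phrase,
earning params.context_score per matched token. A toy illustration of the
idea follows; this is not icefall's implementation, which also subtracts
accumulated bonuses when a partial match fails:

class TrieNode:
    def __init__(self):
        self.children = {}  # token id -> TrieNode

def build_trie(phrases):
    """Build a trie over the token-id sequences of the biasing phrases."""
    root = TrieNode()
    for tokens in phrases:
        node = root
        for t in tokens:
            node = node.children.setdefault(t, TrieNode())
    return root

def advance(node, token, root, context_score):
    """Advance one token; reward staying on a biasing-phrase path."""
    if token in node.children:
        return node.children[token], context_score
    return root, 0.0  # fell off every phrase; restart matching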
@@ -509,7 +533,12 @@ def decode_one_batch(
             ): hyps
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        key = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            key += f"-context-score-{params.context_score}"
+        else:
+            key += "-no-context-words"
+        return {key: hyps}


 def decode_dataset(

@@ -518,6 +547,7 @@ def decode_dataset(
     model: nn.Module,
     lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     ngram_lm: Optional[NgramLm] = None,
     ngram_lm_scale: float = 1.0,
     LM: Optional[LmScorer] = None,

@@ -567,6 +597,7 @@ def decode_dataset(
             lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
+            context_graph=context_graph,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
             LM=LM,
@ -646,6 +677,12 @@ def main():
"modified_beam_search_lm_shallow_fusion", "modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR", "modified_beam_search_LODR",
) )
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
@ -655,6 +692,10 @@ def main():
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += "-no-contexts-words"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -684,11 +725,15 @@ def main():
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
# import pdb; pdb.set_trace()
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"] params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
if params.simulate_streaming: if params.simulate_streaming:
assert ( assert (
params.causal_convolution params.causal_convolution
@@ -816,6 +861,19 @@ def main():
     else:
         decoding_graph = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts_text = []
+            for line in open(params.context_file).readlines():
+                contexts_text.append(line.strip())
+            contexts = graph_compiler.texts_to_ids(contexts_text)
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(contexts)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -833,15 +891,16 @@ def main():
     test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]

-    for test_set, test_dl in zip(test_sets, test_dl):
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
             LM=LM,

Some files were not shown because too many files have changed in this diff.