Merge 4858e2b0367a7ebaf20641743d91a745777aca63 into abd9437e6d5419a497707748eb935e50976c3b7b

Yifan Yang · 2025-06-27 11:33:11 +00:00 · committed by GitHub
commit 2c7dcd65f2
3 changed files with 513 additions and 117 deletions

egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py

@@ -388,16 +388,14 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info(f"About to get train_{self.args.subset} cuts")
-        path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
+        path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
         cuts_train = CutSet.from_jsonl_lazy(path)
         return cuts_train

     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
-        )
+        cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
         if self.args.small_dev:
             return cuts_valid.subset(first=1000)
         else:
@@ -406,6 +404,4 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
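For context, the renamed manifests above are opened lazily via lhotse. A minimal sketch of that access pattern (not part of the commit; the data/fbank location is a hypothetical stand-in for --manifest-dir):

from pathlib import Path

from lhotse import CutSet, load_manifest_lazy

manifest_dir = Path("data/fbank")  # hypothetical manifest directory

# Lazy readers stream the jsonl.gz manifests instead of loading them whole,
# which keeps memory flat even for the XL subset.
cuts_train = CutSet.from_jsonl_lazy(manifest_dir / "cuts_XL.jsonl.gz")
cuts_valid = load_manifest_lazy(manifest_dir / "cuts_DEV.jsonl.gz")
print(next(iter(cuts_train)).id)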

egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py

@@ -24,8 +24,7 @@ Usage:
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method greedy_search
-
-(2) beam search
+(2) beam search (not recommended)
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
@@ -33,7 +32,6 @@ Usage:
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
-
 (3) modified beam search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
@@ -42,17 +40,60 @@ Usage:
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
-
-(4) fast beam search
+(4) fast beam search (one best)
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
-    --beam 4 \
-    --max-contexts 4 \
-    --max-states 8
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(8) fast beam search (nbest with LG)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
 """
@@ -69,6 +110,9 @@ import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
 from beam_search import (
     beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
@@ -83,6 +127,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import UniqLexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -128,7 +173,7 @@ def get_parser():
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to load averaged model. Currently it only supports "
         "using --epoch. If True, it would decode with the averaged model "
         "over the epoch range from `epoch-avg` (excluded) to `epoch`."
@@ -146,10 +191,17 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default=None,
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_phone",
+        help="The lang dir contains word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -159,6 +211,20 @@ def get_parser():
           - beam_search
           - modified_beam_search
           - fast_beam_search
+          - fast_beam_search_LG
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        """,
+    )
+
+    parser.add_argument(
+        "--metrics",
+        type=str,
+        default="WER",
+        help="""Possible values are:
+          - WER
+          - PER
         """,
     )
@@ -174,27 +240,45 @@ def get_parser():
     parser.add_argument(
         "--beam",
         type=float,
-        default=4,
+        default=20.0,
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
-        Used only when --decoding-method is fast_beam_search""",
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_LG or
+        fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
     )

     parser.add_argument(
         "--max-contexts",
         type=int,
-        default=4,
+        default=8,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
+        fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
+        """,
     )

     parser.add_argument(
         "--max-states",
         type=int,
-        default=8,
+        default=64,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
+        fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
+        """,
     )

     parser.add_argument(
@@ -203,6 +287,7 @@ def get_parser():
         default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
+
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -211,6 +296,24 @@
         Used only when --decoding_method is greedy_search""",
     )

+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
     return parser
@@ -229,7 +332,9 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -249,13 +354,19 @@
         The neural model.
       sp:
         The BPE model.
+      pl:
+        The phone lexicon.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle
+        and fast_beam_search_nbest_LG.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -273,7 +384,10 @@
     encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []

-    if params.decoding_method == "fast_beam_search":
+    if (
+        params.decoding_method == "fast_beam_search"
+        or params.decoding_method == "fast_beam_search_LG"
+    ):
         hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
@@ -283,6 +397,58 @@
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
+        if params.decoding_method == "fast_beam_search":
+            if sp is not None:
+                for hyp in sp.decode(hyp_tokens):
+                    hyps.append(hyp.split())
+            else:
+                for hyp in hyp_tokens:
+                    hyps.append([str(i) for i in hyp])
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
@@ -291,8 +457,12 @@
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
+        if sp is not None:
             for hyp in sp.decode(hyp_tokens):
                 hyps.append(hyp.split())
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([str(i) for i in hyp])
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -300,8 +470,12 @@
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
+        if sp is not None:
             for hyp in sp.decode(hyp_tokens):
                 hyps.append(hyp.split())
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([str(i) for i in hyp])
     else:
         batch_size = encoder_out.size(0)
@@ -325,18 +499,24 @@
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
+            if sp is not None:
                 hyps.append(sp.decode(hyp).split())
+            else:
+                hyps.append([str(i) for i in hyp])

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
-    elif params.decoding_method == "fast_beam_search":
-        return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
-        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+        if "LG" in params.decoding_method:
+            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+        return {key: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -346,6 +526,8 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -359,9 +541,15 @@
         The neural model.
       sp:
         The BPE model.
+      pl:
+        The phone lexicon.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle
+        and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -376,10 +564,14 @@
     except TypeError:
         num_batches = "?"

-    if params.decoding_method == "greedy_search":
-        log_interval = 50
-    else:
-        log_interval = 20
+    log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        if sp is not None:
             texts = batch["supervisions"]["text"]
             cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
@@ -387,6 +579,8 @@
                 params=params,
                 model=model,
                 sp=sp,
+                pl=pl,
+                word_table=word_table,
                 decoding_graph=decoding_graph,
                 batch=batch,
             )
@@ -399,6 +593,53 @@
                     this_batch.append((cut_id, ref_words, hyp_words))

                 results[name].extend(this_batch)
+        else:
+            if params.metrics == "WER":
+                texts = batch["supervisions"]["text"]
+                cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+                hyps_dict = decode_one_batch(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    pl=pl,
+                    word_table=word_table,
+                    decoding_graph=decoding_graph,
+                    batch=batch,
+                )
+
+                for name, hyps in hyps_dict.items():
+                    this_batch = []
+                    assert len(hyps) == len(texts)
+                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                        ref_words = ref_text.split()
+                        this_batch.append((cut_id, ref_words, hyp_words))
+
+                    results[name].extend(this_batch)
+            elif params.metrics == "PER":
+                texts = batch["supervisions"]["text"]
+                token_ids = pl.texts_to_token_ids(texts).tolist()
+                cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+                hyps_dict = decode_one_batch(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    pl=pl,
+                    word_table=word_table,
+                    decoding_graph=decoding_graph,
+                    batch=batch,
+                )
+
+                for name, hyps in hyps_dict.items():
+                    this_batch = []
+                    assert len(hyps) == len(token_ids)
+                    for cut_id, hyp_id, ref_token_id in zip(cut_ids, hyps, token_ids):
+                        ref_token_id = [str(i) for i in ref_token_id]
+                        this_batch.append((cut_id, ref_token_id, hyp_id))
+
+                    results[name].extend(this_batch)

         num_cuts += len(texts)
@@ -414,6 +655,7 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
+    if params.metrics == "WER":
         test_set_wers = dict()
         for key, results in results_dict.items():
             recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
@@ -447,6 +689,40 @@
                 note = ""
             logging.info(s)
+    elif params.metrics == "PER":
+        test_set_pers = dict()
+        for key, results in results_dict.items():
+            recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+            results = post_processing(results)
+            results = sorted(results)
+            store_transcripts(filename=recog_path, texts=results)
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+            # The following prints out PERs, per-phone error statistics and aligned
+            # ref/hyp pairs.
+            errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+            with open(errs_filename, "w") as f:
+                per = write_error_stats(
+                    f, f"{test_set_name}-{key}", results, enable_log=True
+                )
+                test_set_pers[key] = per
+
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        test_set_pers = sorted(test_set_pers.items(), key=lambda x: x[1])
+        errs_info = params.res_dir / f"per-summary-{test_set_name}-{params.suffix}.txt"
+        with open(errs_info, "w") as f:
+            print("settings\tPER", file=f)
+            for key, val in test_set_pers:
+                print("{}\t{}".format(key, val), file=f)
+
+        s = "\nFor {}, PER of different settings are:\n".format(test_set_name)
+        note = "\tbest for {}".format(test_set_name)
+        for key, val in test_set_pers:
+            s += "{}\t{}{}\n".format(key, val, note)
+            note = ""
+        logging.info(s)
@torch.no_grad() @torch.no_grad()
def main(): def main():
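WER and PER above are the same computation over different symbol inventories: write_error_stats aligns each (ref, hyp) pair and normalizes the edit distance, so words in give WER and phone-id strings in give PER. A self-contained sketch of the core metric (illustrative only; not icefall's implementation):

def error_rate(ref: list, hyp: list) -> float:
    # Levenshtein distance between ref and hyp, normalized by len(ref),
    # computed with a single rolling DP row.
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
    return d[len(hyp)] / len(ref)

print(error_rate("a b c".split(), "a x c".split()))  # 0.333..., one substitution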
@@ -458,12 +734,32 @@ def main():
     params = get_params()
     params.update(vars(args))

+    if params.bpe_model is not None:
         assert params.decoding_method in (
             "greedy_search",
             "beam_search",
             "fast_beam_search",
+            "fast_beam_search_LG",
+            "fast_beam_search_nbest",
+            "fast_beam_search_nbest_LG",
+            "fast_beam_search_nbest_oracle",
             "modified_beam_search",
         )
+    else:
+        if params.metrics == "PER":
+            assert params.decoding_method in (
+                "greedy_search",
+                "beam_search",
+                "modified_beam_search",
+                "fast_beam_search",
+            ), "Decoding method without L or LG must use PER"
+        elif params.metrics == "WER":
+            assert params.decoding_method in (
+                "fast_beam_search_LG",
+                "fast_beam_search_nbest_LG",
+            ), "Decoding method with L or LG must use WER"
+
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
@@ -475,8 +771,13 @@
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+        if "LG" in params.decoding_method:
+            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -493,13 +794,19 @@
     logging.info(f"Device: {device}")

+    if params.bpe_model is not None:
         sp = spm.SentencePieceProcessor()
         sp.load(params.bpe_model)

         # <blk> and <unk> are defined in local/train_bpe_model.py
         params.blank_id = sp.piece_to_id("<blk>")
         params.unk_id = sp.piece_to_id("<unk>")
         params.vocab_size = sp.get_piece_size()
+        pl = None
+    else:
+        pl = UniqLexicon(params.lang_dir)
+        params.blank_id = 0
+        params.vocab_size = max(pl.tokens) + 1
+        sp = None

     logging.info(params)
@@ -586,10 +893,24 @@
     model.to(device)
     model.eval()

-    if params.decoding_method == "fast_beam_search":
+    if "fast_beam_search" in params.decoding_method:
+        if (
+            params.decoding_method == "fast_beam_search_LG"
+            or params.decoding_method == "fast_beam_search_nbest_LG"
+        ):
+            word_table = pl.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
             decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
+        word_table = None

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -613,6 +934,8 @@
             params=params,
             model=model,
             sp=sp,
+            pl=pl,
+            word_table=word_table,
             decoding_graph=decoding_graph,
         )
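The decode.py docstring above still shows only the BPE invocations. With the new flags, a phone-based PER run would presumably omit --bpe-model (its default is now None) and look something like the following; the exact flag combination is an assumption pieced together from the new arguments, in the same style as the usage examples above:

./pruned_transducer_stateless2/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless2/exp \
    --max-duration 600 \
    --lang-dir data/lang_phone \
    --metrics PER \
    --decoding-method greedy_search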

egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py

@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-#                              Wei Kang
-#                              Mingshuang Luo)
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                            Wei Kang,
+#                                            Mingshuang Luo,
+#                                            Yifan Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -19,13 +20,15 @@
 """
 Usage:

+(1) bpe
 export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

 ./pruned_transducer_stateless2/train.py \
     --world-size 8 \
     --num-epochs 30 \
-    --start-epoch 0 \
+    --start-epoch 1 \
     --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
     --max-duration 120

 # For mixed precision training:
@@ -33,11 +36,37 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 ./pruned_transducer_stateless2/train.py \
     --world-size 8 \
     --num-epochs 30 \
-    --start-epoch 0 \
-    --use_fp16 1 \
+    --start-epoch 1 \
+    --use-fp16 1 \
     --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
     --max-duration 200

+(2) phone
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless2/train.py \
+    --world-size 8 \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
+    --lang-type phone \
+    --context-size 4 \
+    --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless2/train.py \
+    --world-size 8 \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --use-fp16 1 \
+    --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
+    --lang-type phone \
+    --context-size 4 \
+    --max-duration 750
 """
@@ -77,6 +106,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import UniqLexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -119,8 +149,8 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=1,
         help="""Resume training from this epoch.
-        If larger than 1, it will load checkpoint from
+        If it is larger than 1, it will load checkpoint from
         exp-dir/epoch-{start_epoch-1}.pt
         """,
     )
@@ -144,6 +174,13 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--lang-type",
+        type=str,
+        default="bpe",
+        help="Either bpe or phone",
+    )
+
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -151,6 +188,13 @@
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_phone",
+        help="The lang dir containing the lexicon",
+    )
+
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -231,7 +275,7 @@
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=20000,
         help="""Save checkpoint after processing this number of batches
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -244,7 +288,7 @@
     parser.add_argument(
         "--keep-last-k",
         type=int,
-        default=30,
+        default=20,
         help="""Only keep this number of checkpoints on disk.
         For instance, if it is 3, there are only 3 checkpoints
         in the exp-dir with filenames `checkpoint-xxx.pt`.
@@ -519,6 +563,7 @@
     params: AttributeDict,
     model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     batch: dict,
     is_training: bool,
     warmup: float = 1.0,
@@ -551,8 +596,12 @@
     feature_lens = supervisions["num_frames"].to(device)

     texts = batch["supervisions"]["text"]
+    if sp is not None:
         y = sp.encode(texts, out_type=int)
         y = k2.RaggedTensor(y).to(device)
+    else:
+        y = pl.texts_to_token_ids(texts).to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
@@ -592,6 +641,7 @@
     params: AttributeDict,
     model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -605,6 +655,7 @@
             params=params,
             model=model,
             sp=sp,
+            pl=pl,
             batch=batch,
             is_training=False,
         )
@@ -628,6 +679,7 @@
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -679,8 +731,8 @@
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
-                model_avg=model_avg,
                 sp=sp,
+                pl=pl,
                 batch=batch,
                 is_training=True,
                 warmup=(params.batch_idx_train / params.model_warm_step),
@@ -699,6 +751,17 @@
         if params.print_diagnostics and batch_idx == 30:
             return

+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
         if (
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
@@ -746,6 +809,7 @@
                 params=params,
                 model=model,
                 sp=sp,
+                pl=pl,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -795,12 +859,21 @@
     device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

+    if params.lang_type == "bpe":
+        logging.info("Using bpe model")
         sp = spm.SentencePieceProcessor()
         sp.load(params.bpe_model)

         # <blk> is defined in local/train_bpe_model.py
         params.blank_id = sp.piece_to_id("<blk>")
         params.vocab_size = sp.get_piece_size()
+        pl = None
+    elif params.lang_type == "phone":
+        logging.info("Using phone lexicon")
+        pl = UniqLexicon(params.lang_dir)
+        params.blank_id = 0
+        params.vocab_size = max(pl.tokens) + 1
+        sp = None

     logging.info(params)
@@ -870,6 +943,7 @@
             train_dl=train_dl,
             optimizer=optimizer,
             sp=sp,
+            pl=pl,
             params=params,
         )
@@ -895,6 +969,7 @@
             optimizer=optimizer,
             scheduler=scheduler,
             sp=sp,
+            pl=pl,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -930,6 +1005,7 @@
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -949,6 +1025,7 @@
                 params=params,
                 model=model,
                 sp=sp,
+                pl=pl,
                 batch=batch,
                 is_training=True,
                 warmup=0.0,
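Across train.py the sp/pl pair threads a single decision through every call site: how supervision text becomes token ids. A condensed sketch of that dispatch (hypothetical helper; in the commit the logic lives inline in compute_loss above):

import k2

def texts_to_ragged(texts, sp=None, pl=None, device="cpu"):
    # Exactly one of sp (BPE model) / pl (phone lexicon) is set,
    # depending on --lang-type.
    if sp is not None:
        y = sp.encode(texts, out_type=int)  # list of lists of int
        return k2.RaggedTensor(y).to(device)
    # UniqLexicon maps each transcript to a unique phone-id sequence.
    return pl.texts_to_token_ids(texts).to(device)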