Merge 4858e2b0367a7ebaf20641743d91a745777aca63 into abd9437e6d5419a497707748eb935e50976c3b7b

Commit 2c7dcd65f2, authored by Yifan Yang on 2025-06-27 11:33:11 +00:00 and committed by GitHub.
3 changed files with 513 additions and 117 deletions

File: pruned_transducer_stateless2/asr_datamodule.py

@@ -388,16 +388,14 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info(f"About to get train_{self.args.subset} cuts")
-        path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
+        path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
         cuts_train = CutSet.from_jsonl_lazy(path)
         return cuts_train

     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
-        )
+        cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
         if self.args.small_dev:
             return cuts_valid.subset(first=1000)
         else:
@@ -406,6 +404,4 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
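For context, a minimal sketch of how these renamed manifests are consumed with lhotse's lazy loaders (illustrative; the data/fbank location and the XL subset are assumptions, not part of this commit):

    from pathlib import Path

    from lhotse import CutSet, load_manifest_lazy

    manifest_dir = Path("data/fbank")  # hypothetical manifest location
    # Neither call materializes the full manifest; cuts are read on demand.
    cuts_train = CutSet.from_jsonl_lazy(manifest_dir / "cuts_XL.jsonl.gz")
    cuts_dev = load_manifest_lazy(manifest_dir / "cuts_DEV.jsonl.gz")
    print(next(iter(cuts_train)).id)  # pulls just the first cut from disk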

File: pruned_transducer_stateless2/decode.py

@@ -24,8 +24,7 @@ Usage:
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method greedy_search
-(2) beam search (not recommended)
+(2) beam search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
@@ -33,7 +32,6 @@ Usage:
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
 (3) modified beam search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
@@ -42,17 +40,60 @@ Usage:
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
-(4) fast beam search (one best)
+(4) fast beam search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
-    --beam 4 \
-    --max-contexts 4 \
-    --max-states 8
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+(5) fast beam search (nbest)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(7) fast beam search (with LG)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+(8) fast beam search (nbest with LG)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
 """
@@ -69,6 +110,9 @@ import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
 from beam_search import (
     beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
@@ -83,6 +127,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import UniqLexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -128,7 +173,7 @@ def get_parser():
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to load averaged model. Currently it only supports "
         "using --epoch. If True, it would decode with the averaged model "
         "over the epoch range from `epoch-avg` (excluded) to `epoch`."
@@ -146,10 +191,17 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default=None,
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_phone",
+        help="The lang dir containing the word table and the LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -159,6 +211,20 @@ def get_parser():
         - beam_search
         - modified_beam_search
         - fast_beam_search
+        - fast_beam_search_LG
+        - fast_beam_search_nbest
+        - fast_beam_search_nbest_oracle
+        - fast_beam_search_nbest_LG
+        """,
+    )
+
+    parser.add_argument(
+        "--metrics",
+        type=str,
+        default="WER",
+        help="""Possible values are:
+        - WER
+        - PER
         """,
     )
@@ -174,27 +240,45 @@ def get_parser():
     parser.add_argument(
         "--beam",
         type=float,
-        default=4,
+        default=20.0,
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
-        Used only when --decoding-method is fast_beam_search""",
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_LG or
+        fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
     )

     parser.add_argument(
         "--max-contexts",
         type=int,
-        default=4,
+        default=8,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
+        fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
+        """,
     )

     parser.add_argument(
         "--max-states",
         type=int,
-        default=8,
+        default=64,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
+        fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
+        """,
     )

     parser.add_argument(
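To make the `cutoff = max-score - beam` rule in the --beam help concrete, a minimal sketch (illustrative numbers; the actual pruning happens inside k2's frame-synchronous search, not in Python):

    scores = [-1.2, -3.5, -25.8]  # hypothetical state scores at one frame
    beam = 20.0
    cutoff = max(scores) - beam   # -21.2
    survivors = [s for s in scores if s >= cutoff]  # [-1.2, -3.5]

A larger --beam therefore keeps more states alive, which is why the new default of 20.0 pairs with the larger --max-contexts and --max-states defaults.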
@@ -203,6 +287,7 @@ def get_parser():
         default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
+
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -211,6 +296,24 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
     return parser
@@ -229,7 +332,9 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -249,13 +354,19 @@ def decode_one_batch(
         The neural model.
       sp:
         The BPE model.
+      pl:
+        The phone lexicon.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      word_table:
+        The word symbol table.
       decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle
+        and fast_beam_search_nbest_LG.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -273,7 +384,10 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []

-    if params.decoding_method == "fast_beam_search":
+    if (
+        params.decoding_method == "fast_beam_search"
+        or params.decoding_method == "fast_beam_search_LG"
+    ):
         hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
@@ -283,6 +397,58 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
+        if params.decoding_method == "fast_beam_search":
+            if sp is not None:
+                for hyp in sp.decode(hyp_tokens):
+                    hyps.append(hyp.split())
+            else:
+                for hyp in hyp_tokens:
+                    hyps.append([str(i) for i in hyp])
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
@@ -291,8 +457,12 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        if sp is not None:
+            for hyp in sp.decode(hyp_tokens):
+                hyps.append(hyp.split())
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([str(i) for i in hyp])
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -300,8 +470,12 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        if sp is not None:
+            for hyp in sp.decode(hyp_tokens):
+                hyps.append(hyp.split())
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([str(i) for i in hyp])
     else:
         batch_size = encoder_out.size(0)
@@ -325,18 +499,24 @@ def decode_one_batch(
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            if sp is not None:
+                hyps.append(sp.decode(hyp).split())
+            else:
+                hyps.append([str(i) for i in hyp])

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
-    elif params.decoding_method == "fast_beam_search":
-        return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
-        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+        if "LG" in params.decoding_method:
+            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+        return {key: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
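A worked example of the result key assembled above, using the new defaults and fast_beam_search_nbest_LG (values are illustrative):

    key = "beam_20.0_"
    key += "max_contexts_8_"
    key += "max_states_64"
    key += "_num_paths_200_"       # "nbest" is in the method name
    key += "nbest_scale_0.5"
    key += "_ngram_lm_scale_0.01"  # "LG" is in the method name
    assert key == (
        "beam_20.0_max_contexts_8_max_states_64"
        "_num_paths_200_nbest_scale_0.5_ngram_lm_scale_0.01"
    )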
@@ -346,6 +526,8 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -359,9 +541,15 @@ def decode_dataset(
         The neural model.
       sp:
         The BPE model.
+      pl:
+        The phone lexicon.
+      word_table:
+        The word symbol table.
       decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
-        only when --decoding_method is fast_beam_search.
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle
+        and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -376,29 +564,82 @@ def decode_dataset(
     except TypeError:
         num_batches = "?"

-    log_interval = 20
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
-        texts = batch["supervisions"]["text"]
-        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-
-        hyps_dict = decode_one_batch(
-            params=params,
-            model=model,
-            sp=sp,
-            decoding_graph=decoding_graph,
-            batch=batch,
-        )
-
-        for name, hyps in hyps_dict.items():
-            this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
-
-            results[name].extend(this_batch)
+        if sp is not None:
+            texts = batch["supervisions"]["text"]
+            cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+            hyps_dict = decode_one_batch(
+                params=params,
+                model=model,
+                sp=sp,
+                pl=pl,
+                word_table=word_table,
+                decoding_graph=decoding_graph,
+                batch=batch,
+            )
+
+            for name, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((cut_id, ref_words, hyp_words))
+
+                results[name].extend(this_batch)
+        else:
+            if params.metrics == "WER":
+                texts = batch["supervisions"]["text"]
+                cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+                hyps_dict = decode_one_batch(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    pl=pl,
+                    word_table=word_table,
+                    decoding_graph=decoding_graph,
+                    batch=batch,
+                )
+
+                for name, hyps in hyps_dict.items():
+                    this_batch = []
+                    assert len(hyps) == len(texts)
+                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                        ref_words = ref_text.split()
+                        this_batch.append((cut_id, ref_words, hyp_words))
+
+                    results[name].extend(this_batch)
+            elif params.metrics == "PER":
+                texts = batch["supervisions"]["text"]
+                token_ids = pl.texts_to_token_ids(texts).tolist()
+                cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+                hyps_dict = decode_one_batch(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    pl=pl,
+                    word_table=word_table,
+                    decoding_graph=decoding_graph,
+                    batch=batch,
+                )
+
+                for name, hyps in hyps_dict.items():
+                    this_batch = []
+                    assert len(hyps) == len(token_ids)
+                    for cut_id, hyp_id, ref_token_id in zip(cut_ids, hyps, token_ids):
+                        ref_token_id = [str(i) for i in ref_token_id]
+                        this_batch.append((cut_id, ref_token_id, hyp_id))
+
+                    results[name].extend(this_batch)

         num_cuts += len(texts)
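The PER branch above feeds phone-id strings into the same write_error_stats used for WER, so the underlying measure is token-level edit distance. A self-contained sketch with hypothetical sequences:

    def edit_distance(ref, hyp):
        # Classic single-row dynamic-programming Levenshtein distance.
        d = list(range(len(hyp) + 1))
        for i, r in enumerate(ref, 1):
            prev, d[0] = d[0], i
            for j, h in enumerate(hyp, 1):
                # d[j] is the cell above, d[j-1] the cell to the left,
                # prev the diagonal cell from the previous row.
                prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
        return d[-1]

    ref = ["13", "2", "41"]  # reference phone ids, as strings
    hyp = ["13", "41"]       # hypothesis
    per = edit_distance(ref, hyp) / len(ref)  # 1/3, one deletion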
@@ -414,38 +655,73 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
-    test_set_wers = dict()
-    for key, results in results_dict.items():
-        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        results = post_processing(results)
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
-
-        # The following prints out WERs, per-word error statistics and aligned
-        # ref/hyp pairs.
-        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        with open(errs_filename, "w") as f:
-            wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
-            )
-            test_set_wers[key] = wer
-
-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
-
-    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
-    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
-        note = ""
-    logging.info(s)
+    if params.metrics == "WER":
+        test_set_wers = dict()
+        for key, results in results_dict.items():
+            recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+            results = post_processing(results)
+            results = sorted(results)
+            store_transcripts(filename=recog_path, texts=results)
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+            # The following prints out WERs, per-word error statistics and aligned
+            # ref/hyp pairs.
+            errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+            with open(errs_filename, "w") as f:
+                wer = write_error_stats(
+                    f, f"{test_set_name}-{key}", results, enable_log=True
+                )
+                test_set_wers[key] = wer
+
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+        errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+        with open(errs_info, "w") as f:
+            print("settings\tWER", file=f)
+            for key, val in test_set_wers:
+                print("{}\t{}".format(key, val), file=f)
+
+        s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+        note = "\tbest for {}".format(test_set_name)
+        for key, val in test_set_wers:
+            s += "{}\t{}{}\n".format(key, val, note)
+            note = ""
+        logging.info(s)
+    elif params.metrics == "PER":
+        test_set_pers = dict()
+        for key, results in results_dict.items():
+            recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+            results = post_processing(results)
+            results = sorted(results)
+            store_transcripts(filename=recog_path, texts=results)
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+            # The following prints out PERs, per-phone error statistics and aligned
+            # ref/hyp pairs.
+            errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+            with open(errs_filename, "w") as f:
+                per = write_error_stats(
+                    f, f"{test_set_name}-{key}", results, enable_log=True
+                )
+                test_set_pers[key] = per
+
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        test_set_pers = sorted(test_set_pers.items(), key=lambda x: x[1])
+        errs_info = params.res_dir / f"per-summary-{test_set_name}-{params.suffix}.txt"
+        with open(errs_info, "w") as f:
+            print("settings\tPER", file=f)
+            for key, val in test_set_pers:
+                print("{}\t{}".format(key, val), file=f)
+
+        s = "\nFor {}, PER of different settings are:\n".format(test_set_name)
+        note = "\tbest for {}".format(test_set_name)
+        for key, val in test_set_pers:
+            s += "{}\t{}{}\n".format(key, val, note)
+            note = ""
+        logging.info(s)

 @torch.no_grad()
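The WER and PER branches above are identical except for the metric label and the summary filename. A sketch of how they could share one code path (a hypothetical helper, not part of this commit; it reuses post_processing, store_transcripts and write_error_stats from the surrounding file):

    def _save_results(metric, params, test_set_name, results_dict):
        """Shared body for save_results; `metric` is "WER" or "PER"."""
        test_set_scores = dict()
        for key, results in results_dict.items():
            recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
            results = sorted(post_processing(results))
            store_transcripts(filename=recog_path, texts=results)
            logging.info(f"The transcripts are stored in {recog_path}")

            errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
            with open(errs_filename, "w") as f:
                test_set_scores[key] = write_error_stats(
                    f, f"{test_set_name}-{key}", results, enable_log=True
                )
            logging.info(f"Wrote detailed error stats to {errs_filename}")

        summary = (
            params.res_dir
            / f"{metric.lower()}-summary-{test_set_name}-{params.suffix}.txt"
        )
        with open(summary, "w") as f:
            print(f"settings\t{metric}", file=f)
            for key, val in sorted(test_set_scores.items(), key=lambda x: x[1]):
                print(f"{key}\t{val}", file=f)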
@@ -458,12 +734,32 @@ def main():
     params = get_params()
     params.update(vars(args))

-    assert params.decoding_method in (
-        "greedy_search",
-        "beam_search",
-        "fast_beam_search",
-        "modified_beam_search",
-    )
+    if params.bpe_model is not None:
+        assert params.decoding_method in (
+            "greedy_search",
+            "beam_search",
+            "fast_beam_search",
+            "fast_beam_search_LG",
+            "fast_beam_search_nbest",
+            "fast_beam_search_nbest_LG",
+            "fast_beam_search_nbest_oracle",
+            "modified_beam_search",
+        )
+    else:
+        if params.metrics == "PER":
+            assert params.decoding_method in (
+                "greedy_search",
+                "beam_search",
+                "modified_beam_search",
+                "fast_beam_search",
+            ), "Decoding method without L or LG must use PER"
+        elif params.metrics == "WER":
+            assert params.decoding_method in (
+                "fast_beam_search_LG",
+                "fast_beam_search_nbest_LG",
+            ), "Decoding method with L or LG must use WER"

     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
@@ -475,8 +771,13 @@ def main():
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+        if "LG" in params.decoding_method:
+            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -493,13 +794,19 @@ def main():
     logging.info(f"Device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> and <unk> are defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    if params.bpe_model is not None:
+        sp = spm.SentencePieceProcessor()
+        sp.load(params.bpe_model)
+
+        # <blk> and <unk> are defined in local/train_bpe_model.py
+        params.blank_id = sp.piece_to_id("<blk>")
+        params.unk_id = sp.piece_to_id("<unk>")
+        params.vocab_size = sp.get_piece_size()
+        pl = None
+    else:
+        pl = UniqLexicon(params.lang_dir)
+        params.blank_id = 0
+        params.vocab_size = max(pl.tokens) + 1
+        sp = None

     logging.info(params)
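The phone branch pins blank to id 0 and sizes the vocabulary from the largest lexicon token id; a small sketch of the arithmetic (hypothetical ids):

    tokens = [1, 2, 3, 41]         # hypothetical phone ids from the lexicon
    blank_id = 0                   # blank is pinned to id 0
    vocab_size = max(tokens) + 1   # 42 slots cover ids 0..41, blank included

This also mirrors train.py below, which sets `pl = None` on the BPE path so that `pl` is always defined when passed onward.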
@@ -586,10 +893,24 @@ def main():
     model.to(device)
     model.eval()

-    if params.decoding_method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    if "fast_beam_search" in params.decoding_method:
+        if (
+            params.decoding_method == "fast_beam_search_LG"
+            or params.decoding_method == "fast_beam_search_nbest_LG"
+        ):
+            word_table = pl.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
+        word_table = None

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -613,6 +934,8 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            pl=pl,
+            word_table=word_table,
             decoding_graph=decoding_graph,
         )
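Putting the new flags together, a plausible phone-lexicon decoding run scored with PER might look like the following (illustrative values; --bpe-model is left unset so the lexicon path is taken, and greedy_search satisfies the PER assertion above):

    ./pruned_transducer_stateless2/decode.py \
        --epoch 28 \
        --avg 15 \
        --exp-dir ./pruned_transducer_stateless2/exp \
        --max-duration 600 \
        --lang-dir data/lang_phone \
        --metrics PER \
        --decoding-method greedy_search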

File: pruned_transducer_stateless2/train.py

@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
-# Copyright    2021       Xiaomi Corp.  (authors: Fangjun Kuang,
-#                                                 Wei Kang
-#                                                 Mingshuang Luo)
+# Copyright    2021-2023  Xiaomi Corp.  (authors: Fangjun Kuang,
+#                                                 Wei Kang,
+#                                                 Mingshuang Luo,
+#                                                 Yifan Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -19,13 +20,15 @@
 """
 Usage:

+(1) bpe
 export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

 ./pruned_transducer_stateless2/train.py \
     --world-size 8 \
     --num-epochs 30 \
-    --start-epoch 0 \
+    --start-epoch 1 \
     --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
     --max-duration 120

 # For mixed precision training:
@@ -33,11 +36,37 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 ./pruned_transducer_stateless2/train.py \
     --world-size 8 \
     --num-epochs 30 \
-    --start-epoch 0 \
-    --use_fp16 1 \
+    --start-epoch 1 \
+    --use-fp16 1 \
     --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
     --max-duration 200

+(2) phone
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless2/train.py \
+    --world-size 8 \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
+    --lang-type phone \
+    --context-size 4 \
+    --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless2/train.py \
+    --world-size 8 \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --use-fp16 1 \
+    --exp-dir pruned_transducer_stateless2/exp \
+    --subset XL \
+    --lang-type phone \
+    --context-size 4 \
+    --max-duration 750
 """
@@ -77,6 +106,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import UniqLexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -119,8 +149,8 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=1,
         help="""Resume training from this epoch.
-        If larger than 1, it will load checkpoint from
+        If it is larger than 1, it will load checkpoint from
         exp-dir/epoch-{start_epoch-1}.pt
         """,
     )
@@ -144,6 +174,13 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--lang-type",
+        type=str,
+        default="bpe",
+        help="Either bpe or phone",
+    )
+
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -151,6 +188,13 @@ def get_parser():
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_phone",
+        help="The lang dir containing the lexicon",
+    )
+
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -231,7 +275,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=20000,
         help="""Save checkpoint after processing this number of batches
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -244,7 +288,7 @@ def get_parser():
     parser.add_argument(
         "--keep-last-k",
         type=int,
-        default=30,
+        default=20,
         help="""Only keep this number of checkpoints on disk.
         For instance, if it is 3, there are only 3 checkpoints
         in the exp-dir with filenames `checkpoint-xxx.pt`.
@@ -261,7 +305,7 @@ def get_parser():
         in which each floating-point parameter is the average of all the
         parameters from the start of training. Each time we take the average,
         we do: `model_avg = model * (average_period / batch_idx_train) +
         model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
         """,
     )
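To make the averaged-model update quoted in the help string concrete, a numeric sketch (hypothetical step counts):

    batch_idx_train = 1000
    average_period = 100
    w_cur = average_period / batch_idx_train                      # 0.1
    w_avg = (batch_idx_train - average_period) / batch_idx_train  # 0.9
    # model_avg = model * w_cur + model_avg * w_avg  (per parameter)
    assert abs(w_cur + w_avg - 1.0) < 1e-12  # a convex combination

This is the update that the new periodic update_averaged_model call in train_one_epoch (see below) applies, and it pairs with decode.py's --use-averaged-model default flipping to True.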
@@ -519,6 +563,7 @@ def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     batch: dict,
     is_training: bool,
     warmup: float = 1.0,
@@ -551,8 +596,12 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)

     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+
+    if sp is not None:
+        y = sp.encode(texts, out_type=int)
+        y = k2.RaggedTensor(y).to(device)
+    else:
+        y = pl.texts_to_token_ids(texts).to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
@@ -592,6 +641,7 @@ def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -605,6 +655,7 @@ def compute_validation_loss(
             params=params,
             model=model,
             sp=sp,
+            pl=pl,
             batch=batch,
             is_training=False,
         )
@@ -628,6 +679,7 @@ def train_one_epoch(
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -679,8 +731,8 @@ def train_one_epoch(
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
-                model_avg=model_avg,
                 sp=sp,
+                pl=pl,
                 batch=batch,
                 is_training=True,
                 warmup=(params.batch_idx_train / params.model_warm_step),
@@ -699,6 +751,17 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 30:
             return

+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
         if (
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
@@ -746,6 +809,7 @@ def train_one_epoch(
                 params=params,
                 model=model,
                 sp=sp,
+                pl=pl,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -795,12 +859,21 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    if params.lang_type == "bpe":
+        logging.info("Using bpe model")
+        sp = spm.SentencePieceProcessor()
+        sp.load(params.bpe_model)
+
+        # <blk> is defined in local/train_bpe_model.py
+        params.blank_id = sp.piece_to_id("<blk>")
+        params.vocab_size = sp.get_piece_size()
+        pl = None
+    elif params.lang_type == "phone":
+        logging.info("Using phone lexicon")
+        pl = UniqLexicon(params.lang_dir)
+        params.blank_id = 0
+        params.vocab_size = max(pl.tokens) + 1
+        sp = None

     logging.info(params)
@@ -870,6 +943,7 @@ def run(rank, world_size, args):
             train_dl=train_dl,
             optimizer=optimizer,
             sp=sp,
+            pl=pl,
             params=params,
         )
@@ -895,6 +969,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             scheduler=scheduler,
             sp=sp,
+            pl=pl,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -930,6 +1005,7 @@ def scan_pessimistic_batches_for_oom(
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
+    pl: UniqLexicon,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -949,6 +1025,7 @@ def scan_pessimistic_batches_for_oom(
                 params=params,
                 model=model,
                 sp=sp,
+                pl=pl,
                 batch=batch,
                 is_training=True,
                 warmup=0.0,