Finalizing the code:

- apply some CodeRabbit suggestions
- remove `word_table`, `decoding_graph` from the aligner API (unused)
- improve consistency of variable names (confidences)
- update docstrings
Karel Vesely 2025-09-15 13:58:51 +02:00
parent d5ff66c56d
commit 0fdba34a70


@@ -17,8 +17,8 @@
 # limitations under the License.
 """
-Batch aligning with CTC model (it can be Tranducer + CTC).
-It works with both causal an non-causal models.
+Batch aligning with a CTC model (it can be Transducer + CTC).
+It works with both causal and non-causal models.
 Streaming is disabled, or simulated by attention masks
 (see: --chunk-size --left-context-frames).
 Whole utterance processed by 1 forward() call.
@@ -44,9 +44,8 @@ import logging
 import math
 from collections import defaultdict
 from pathlib import Path, PurePath
-from typing import Dict, List, Optional, Tuple
-import k2
+from typing import Dict, List, Tuple
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -129,7 +128,7 @@ def get_parser():
         "--res-dir-suffix",
         type=str,
         default="",
-        help="Suffix to where alignments are stored",
+        help="Suffix to the directory where alignments are stored.",
     )

     parser.add_argument(
@@ -144,8 +143,9 @@ def get_parser():
         type=str,
         nargs="+",
         default=[],
-        help="List of tokens to ignore when computing confidence scores "
-        "(e.g., punctuation marks)",
+        help="List of BPE tokens to ignore when computing confidence scores "
+        "(e.g., punctuation marks). Each token is a separate arg: "
+        "`--ignore-tokens 'tok1' 'tok2' ...`",
     )

     parser.add_argument(
@@ -169,7 +169,8 @@ def get_parser():
         "dataset_manifests",
         type=str,
         nargs="+",
-        help="CutSet manifests to be aligned (CurSet with features and transcripts)",
+        help="CutSet manifests to be aligned (CutSet with features and transcripts). "
+        "Each CutSet is a separate arg: `manifest1 manifest2 ...`",
     )

     add_model_arguments(parser)
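For orientation, a minimal sketch of how the two `nargs="+"` arguments parse; the argument names come from this diff, while the manifest paths and token strings are placeholders:

    import argparse

    # Reduced parser: only the options touched by this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ignore-tokens", type=str, nargs="+", default=[])
    parser.add_argument("dataset_manifests", type=str, nargs="+")

    # Positionals first, so the greedy --ignore-tokens list does not swallow them.
    args = parser.parse_args(
        ["cuts_a.jsonl.gz", "cuts_b.jsonl.gz", "--ignore-tokens", ".", ","]
    )
    print(args.dataset_manifests)  # ['cuts_a.jsonl.gz', 'cuts_b.jsonl.gz']
    print(args.ignore_tokens)      # ['.', ',']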
@@ -183,8 +184,6 @@ def align_one_batch(
     sp: spm.SentencePieceProcessor,
     ignored_tokens: set[int],
     batch: dict,
-    word_table: Optional[k2.SymbolTable] = None,
-    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Align one batch and return the result in a dict. The dict has the
     following format:
@@ -208,15 +207,6 @@ def align_one_batch(
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.

-      UNUSED_PART, CAN BE USED LATER FOR ALIGNING TO A DECODING_GRAPH:
-      word_table [UNUSED]:
-        The word symbol table.
-      decoding_graph [UNUSED]:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.

     Returns:
       Return the alignment result. See above description for the format of
       the returned dict.
@@ -275,7 +265,7 @@ def align_one_batch(
                 targets=targets[ii, : target_lengths[ii]].unsqueeze(dim=0),
                 input_lengths=encoder_out_lens[ii].unsqueeze(dim=0),
                 target_lengths=target_lengths[ii].unsqueeze(dim=0),
-                blank=0,
+                blank=params.blank_id,
             )

             # per-token time, score
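The `blank=0` → `blank=params.blank_id` fix matters whenever the BPE model does not map `<blk>` to id 0. A self-contained sketch of the per-utterance call, assuming the loss here is `torch.nn.functional.ctc_loss`; shapes and data below are illustrative, not from the recipe:

    import torch
    import torch.nn.functional as F

    T, N, C, S = 50, 1, 500, 10      # frames, batch, vocab size, target length
    blank_id = 0                     # in the recipe this comes from params.blank_id
    log_probs = torch.randn(T, N, C).log_softmax(dim=-1)
    targets = torch.randint(1, C, (N, S))            # labels, blank id excluded
    encoder_out_lens = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.full((N,), S, dtype=torch.long)

    ii = 0  # one utterance of the batch, as in the loop above
    loss_ii = F.ctc_loss(
        log_probs[:, ii : ii + 1, :],                        # (T, 1, C)
        targets[ii, : target_lengths[ii]].unsqueeze(dim=0),  # (1, S)
        encoder_out_lens[ii].unsqueeze(dim=0),
        target_lengths[ii].unsqueeze(dim=0),
        blank=blank_id,  # must match the model's <blk> id, not a hardcoded 0
        reduction="none",
    )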
@@ -300,27 +290,27 @@ def align_one_batch(
             nonblank_q10 = float(torch.quantile(nonblank_scores, 0.10))
             nonblank_q20 = float(torch.quantile(nonblank_scores, 0.20))
             nonblank_q30 = float(torch.quantile(nonblank_scores, 0.30))
-            nonblank_mean = float(nonblank_scores.mean())
+            mean_frame_conf = float(nonblank_scores.mean())
         else:
             nonblank_min = -1.0
             nonblank_q05 = -1.0
             nonblank_q10 = -1.0
             nonblank_q20 = -1.0
             nonblank_q30 = -1.0
-            nonblank_mean = -1.0
+            mean_frame_conf = -1.0

         if num_scores > 0:
-            confidence = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4
+            q0_20_conf = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4
         else:
-            confidence = 1.0  # default score for short utts
+            q0_20_conf = 1.0  # default, no frames

         hyps.append(
             {
                 "token_spans": token_spans,
                 "mean_token_conf": mean_token_conf,
-                "confidence": confidence,
+                "q0_20_conf": q0_20_conf,
                 "num_scores": num_scores,
-                "nonblank_mean": nonblank_mean,
+                "mean_frame_conf": mean_frame_conf,
                 "nonblank_min": nonblank_min,
                 "nonblank_q05": nonblank_q05,
                 "nonblank_q10": nonblank_q10,
@@ -337,8 +327,6 @@ def align_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    word_table: Optional[k2.SymbolTable] = None,
-    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -351,18 +339,11 @@ def align_dataset(
         The neural model.
       sp:
         The BPE model.
-      word_table:
-        The word symbol table.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
-      Return a dict, whose key may be "greedy_search" if greedy search
-      is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Return a dict whose key is "ctc_align" (the alignment method).
+      Its value is a list of ternary tuples, each holding a) the
+      utterance_key, b) the reference transcript, and c) a dictionary
+      with alignment results (token spans, confidences, etc.).
     """

     num_cuts = 0
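Per the updated docstring, a consumer of the `align_dataset` output would iterate like this; the dict literal stands in for a real call, and the field values are made up:

    results = {
        "ctc_align": [
            (
                "utt-0001",     # utterance_key
                "hello world",  # reference transcript
                {"mean_token_conf": 0.93, "q0_20_conf": 0.81, "token_spans": []},
            ),
        ]
    }

    for utterance_key, ref_text, ali in results["ctc_align"]:
        print(utterance_key, ref_text, ali["mean_token_conf"], ali["q0_20_conf"])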
@@ -387,8 +368,6 @@ def align_dataset(
             model=model,
             sp=sp,
             ignored_tokens=ignored_tokens_ints,
-            decoding_graph=decoding_graph,
-            word_table=word_table,
             batch=batch,
         )
@@ -408,6 +387,7 @@ def align_dataset(
         batch_str = f"{batch_idx}/{num_batches}"
         logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

     return results
@@ -453,21 +433,25 @@ def save_alignment_output(
             "(nonblank_min,q05,q10,q20,q30) (num_scores,num_tokens)",
             file=fd,
         )  # header

-        for key, ref_text, ali in results:
+        for utterance_key, ref_text, ali in results:
             mean_token_conf = ali["mean_token_conf"]
-            mean_frame_conf = ali["nonblank_mean"]
-            q0_20_conf = ali["confidence"]
+            mean_frame_conf = ali["mean_frame_conf"]
+            q0_20_conf = ali["q0_20_conf"]
             min_ = ali["nonblank_min"]
             q05 = ali["nonblank_q05"]
             q10 = ali["nonblank_q10"]
             q20 = ali["nonblank_q20"]
             q30 = ali["nonblank_q30"]
             num_scores = ali[
                 "num_scores"
             ]  # scores used to compute `mean_frame_conf`
             num_tokens = len(ali["token_spans"])  # tokens in ref transcript

             print(
-                f"{key} {mean_token_conf:.4f} {mean_frame_conf:.4f} "
+                f"{utterance_key} {mean_token_conf:.4f} {mean_frame_conf:.4f} "
                 f"{q0_20_conf:.4f} "
                 f"({min_:.4f},{q05:.4f},{q10:.4f},{q20:.4f},{q30:.4f}) "
                 f"({num_scores},{num_tokens})",
@@ -530,7 +514,7 @@ def main():
     # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
+    params.unk_id = sp.piece_to_id("<unk>")  # unknown character, not an OOV
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
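These ids come straight from the SentencePiece model, which is why the hardcoded `blank=0` above was fragile. A minimal sketch of the lookups, assuming a trained BPE model at a placeholder path:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file="data/lang_bpe_500/bpe.model")  # placeholder

    blank_id = sp.piece_to_id("<blk>")  # CTC blank, defined at BPE training time
    unk_id = sp.piece_to_id("<unk>")    # unknown character, not an OOV word
    vocab_size = sp.get_piece_size()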
@@ -645,8 +629,6 @@ def main():
         params=params,
         model=model,
         sp=sp,
-        word_table=None,
-        decoding_graph=None,
     )

     save_alignment_output(