Finalizing the code:

- applying some CodeRabbit suggestions
- removing `word_table`, `decoding_graph` from the aligner API (unused)
- improving consistency of variable names (confidences)
- updating docstrings
Karel Vesely 2025-09-15 13:58:51 +02:00
parent d5ff66c56d
commit 0fdba34a70


@@ -17,8 +17,8 @@
# limitations under the License.
"""
Batch aligning with CTC model (it can be Tranducer + CTC).
It works with both causal an non-causal models.
Batch aligning with a CTC model (it can be Transducer + CTC).
It works with both causal and non-causal models.
Streaming is disabled, or simulated by attention masks
(see: --chunk-size --left-context-frames).
Whole utterance processed by 1 forward() call.
@@ -44,9 +44,8 @@ import logging
import math
from collections import defaultdict
from pathlib import Path, PurePath
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
@@ -129,7 +128,7 @@ def get_parser():
"--res-dir-suffix",
type=str,
default="",
help="Suffix to where alignments are stored",
help="Suffix to the directory, where alignments are stored.",
)
parser.add_argument(
@@ -144,8 +143,9 @@ def get_parser():
type=str,
nargs="+",
default=[],
help="List of tokens to ignore when computing confidence scores "
"(e.g., punctuation marks)",
help="List of BPE tokens to ignore when computing confidence scores "
"(e.g., punctuation marks). Each token is a separate arg : "
"`--ignore-tokens 'tok1' 'tok2' ...`",
)
parser.add_argument(
@@ -169,7 +169,8 @@ def get_parser():
"dataset_manifests",
type=str,
nargs="+",
help="CutSet manifests to be aligned (CurSet with features and transcripts)",
help="CutSet manifests to be aligned (CutSet with features and transcripts). "
"Each CutSet as a separate arg : `manifest1 mainfest2 ...`",
)
add_model_arguments(parser)
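For context on the `--ignore-tokens` option above: the aligner functions take an `ignored_tokens: set[int]`, so the BPE token strings from the command line presumably get mapped to ids before the confidence computation. A minimal sketch of that conversion with sentencepiece; the helper name is an illustration, not part of this commit:

```python
# Hypothetical helper: map --ignore-tokens strings to BPE ids for the
# `ignored_tokens: set[int]` argument used by align_one_batch().
import sentencepiece as spm


def ignore_tokens_to_ids(sp: spm.SentencePieceProcessor, tokens: list[str]) -> set[int]:
    ids: set[int] = set()
    for tok in tokens:
        tok_id = sp.piece_to_id(tok)
        if tok_id == sp.unk_id():
            # Piece is not in the BPE vocabulary: skip it rather than
            # accidentally ignoring every <unk> frame.
            continue
        ids.add(tok_id)
    return ids
```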
@@ -183,8 +184,6 @@ def align_one_batch(
sp: spm.SentencePieceProcessor,
ignored_tokens: set[int],
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Align one batch and return the result in a dict. The dict has the
following format:
@@ -208,15 +207,6 @@ def align_one_batch(
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
UNUSED_PART, CAN BE USED LATER FOR ALIGNING TO A DECODING_GRAPH:
word_table [UNUSED]:
The word symbol table.
decoding_graph [UNUSED]:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the alignment result. See above description for the format of
the returned dict.
@@ -275,7 +265,7 @@ def align_one_batch(
targets=targets[ii, : target_lengths[ii]].unsqueeze(dim=0),
input_lengths=encoder_out_lens[ii].unsqueeze(dim=0),
target_lengths=target_lengths[ii].unsqueeze(dim=0),
blank=0,
blank=params.blank_id,
)
# per-token time, score
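The per-utterance call above (keyword arguments `targets`, `input_lengths`, `target_lengths`, `blank`) matches the signature of torchaudio's CTC forced alignment; the diff does not show the function name, so the following is only a sketch under that assumption. torchaudio also provides `merge_tokens`, which would collapse the per-frame labels into per-token spans like the `token_spans` stored a few lines below.

```python
# Sketch assuming torchaudio.functional.forced_align; the function name is
# not visible in the diff, only matching keyword arguments are.
import torch
from torchaudio.functional import forced_align

vocab_size = 500
log_probs = torch.randn(1, 50, vocab_size).log_softmax(dim=-1)  # (batch=1, frames, vocab)
targets = torch.tensor([[7, 23, 42]])                           # BPE ids of the reference
input_lengths = torch.tensor([50])
target_lengths = torch.tensor([3])

labels, scores = forced_align(
    log_probs,
    targets,
    input_lengths=input_lengths,
    target_lengths=target_lengths,
    blank=0,  # the commit passes params.blank_id here instead of a hard-coded 0
)
# `labels` holds one id per frame (blank or token), `scores` the per-frame
# log-probabilities from which the non-blank frame confidences are taken.
```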
@@ -300,27 +290,27 @@ def align_one_batch(
nonblank_q10 = float(torch.quantile(nonblank_scores, 0.10))
nonblank_q20 = float(torch.quantile(nonblank_scores, 0.20))
nonblank_q30 = float(torch.quantile(nonblank_scores, 0.30))
nonblank_mean = float(nonblank_scores.mean())
mean_frame_conf = float(nonblank_scores.mean())
else:
nonblank_min = -1.0
nonblank_q05 = -1.0
nonblank_q10 = -1.0
nonblank_q20 = -1.0
nonblank_q30 = -1.0
nonblank_mean = -1.0
mean_frame_conf = -1.0
if num_scores > 0:
confidence = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4
q0_20_conf = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4
else:
confidence = 1.0 # default score for short utts
q0_20_conf = 1.0 # default, no frames
hyps.append(
{
"token_spans": token_spans,
"mean_token_conf": mean_token_conf,
"confidence": confidence,
"q0_20_conf": q0_20_conf,
"num_scores": num_scores,
"nonblank_mean": nonblank_mean,
"mean_frame_conf": mean_frame_conf,
"nonblank_min": nonblank_min,
"nonblank_q05": nonblank_q05,
"nonblank_q10": nonblank_q10,
@@ -337,8 +327,6 @@ def align_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -351,18 +339,11 @@ def align_dataset(
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
Return a dict, whose key is "ctc_align" (the alignment method).
Its value is a list of tuples. Each tuple has three elements:
a) the utterance_key, b) the reference transcript, and c) a dictionary
with alignment results (token spans, confidences, etc.).
"""
num_cuts = 0
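Given the Returns description above, downstream code would unpack the result roughly as follows; the example entry is hand-made for illustration (the inner keys are taken from the dict built in align_one_batch(), the values are not produced by the aligner):

```python
# Hand-made example of the align_dataset() output structure described above.
from typing import Dict, List, Tuple

results: Dict[str, List[Tuple[str, str, dict]]] = {
    "ctc_align": [
        (
            "utt-0001",             # a) utterance key
            "HELLO WORLD",          # b) reference transcript
            {                       # c) alignment results
                "token_spans": [],  # per-token spans (left empty here)
                "mean_token_conf": 0.97,
                "mean_frame_conf": 0.93,
                "q0_20_conf": 0.88,
            },
        ),
    ],
}

for utterance_key, ref_text, ali in results["ctc_align"]:
    print(utterance_key, ref_text, f"{ali['q0_20_conf']:.4f}")
```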
@@ -387,8 +368,6 @@ def align_dataset(
model=model,
sp=sp,
ignored_tokens=ignored_tokens_ints,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
@@ -408,6 +387,7 @@ def align_dataset(
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -453,21 +433,25 @@ def save_alignment_output(
"(nonblank_min,q05,q10,q20,q30) (num_scores,num_tokens)",
file=fd,
) # header
for key, ref_text, ali in results:
for utterance_key, ref_text, ali in results:
mean_token_conf = ali["mean_token_conf"]
mean_frame_conf = ali["nonblank_mean"]
q0_20_conf = ali["confidence"]
mean_frame_conf = ali["mean_frame_conf"]
q0_20_conf = ali["q0_20_conf"]
min_ = ali["nonblank_min"]
q05 = ali["nonblank_q05"]
q10 = ali["nonblank_q10"]
q20 = ali["nonblank_q20"]
q30 = ali["nonblank_q30"]
num_scores = ali[
"num_scores"
] # scores used to compute `mean_frame_conf`
num_tokens = len(ali["token_spans"]) # tokens in ref transcript
print(
f"{key} {mean_token_conf:.4f} {mean_frame_conf:.4f} "
f"{utterance_key} {mean_token_conf:.4f} {mean_frame_conf:.4f} "
f"{q0_20_conf:.4f} "
f"({min_:.4f},{q05:.4f},{q10:.4f},{q20:.4f},{q30:.4f}) "
f"({num_scores},{num_tokens})",
@@ -530,7 +514,7 @@ def main():
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.unk_id = sp.piece_to_id("<unk>") # unknown character, not an OOV
params.vocab_size = sp.get_piece_size()
logging.info(params)
@@ -645,8 +629,6 @@ def main():
params=params,
model=model,
sp=sp,
word_table=None,
decoding_graph=None,
)
save_alignment_output(