diff --git a/egs/librispeech/ASR/zipformer/ctc_align.py b/egs/librispeech/ASR/zipformer/ctc_align.py index e3dfe1fa9..fff05146f 100755 --- a/egs/librispeech/ASR/zipformer/ctc_align.py +++ b/egs/librispeech/ASR/zipformer/ctc_align.py @@ -17,8 +17,8 @@ # limitations under the License. """ -Batch aligning with CTC model (it can be Tranducer + CTC). -It works with both causal an non-causal models. +Batch aligning with a CTC model (it can be Transducer + CTC). +It works with both causal and non-causal models. Streaming is disabled, or simulated by attention masks (see: --chunk-size --left-context-frames). Whole utterance processed by 1 forward() call. @@ -44,9 +44,8 @@ import logging import math from collections import defaultdict from pathlib import Path, PurePath -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple -import k2 import numpy as np import sentencepiece as spm import torch @@ -129,7 +128,7 @@ def get_parser(): "--res-dir-suffix", type=str, default="", - help="Suffix to where alignments are stored", + help="Suffix to the directory, where alignments are stored.", ) parser.add_argument( @@ -144,8 +143,9 @@ def get_parser(): type=str, nargs="+", default=[], - help="List of tokens to ignore when computing confidence scores " - "(e.g., punctuation marks)", + help="List of BPE tokens to ignore when computing confidence scores " + "(e.g., punctuation marks). Each token is a separate arg : " + "`--ignore-tokens 'tok1' 'tok2' ...`", ) parser.add_argument( @@ -169,7 +169,8 @@ def get_parser(): "dataset_manifests", type=str, nargs="+", - help="CutSet manifests to be aligned (CurSet with features and transcripts)", + help="CutSet manifests to be aligned (CutSet with features and transcripts). 
" + "Each CutSet as a separate arg : `manifest1 manifest2 ...`", ) add_model_arguments(parser) @@ -183,8 +184,6 @@ def align_one_batch( sp: spm.SentencePieceProcessor, ignored_tokens: set[int], batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Align one batch and return the result in a dict. The dict has the following format: @@ -208,15 +207,6 @@ def align_one_batch( `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. - UNUSED_PART, CAN BE USED LATER FOR ALIGNING TO A DECODING_GRAPH: - - word_table [UNUSED]: - The word symbol table. - decoding_graph [UNUSED]: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: Return the alignment result. See above description for the format of the returned dict. 
@@ -275,7 +265,7 @@ def align_one_batch( targets=targets[ii, : target_lengths[ii]].unsqueeze(dim=0), input_lengths=encoder_out_lens[ii].unsqueeze(dim=0), target_lengths=target_lengths[ii].unsqueeze(dim=0), - blank=0, + blank=params.blank_id, ) # per-token time, score @@ -300,27 +290,27 @@ def align_one_batch( nonblank_q10 = float(torch.quantile(nonblank_scores, 0.10)) nonblank_q20 = float(torch.quantile(nonblank_scores, 0.20)) nonblank_q30 = float(torch.quantile(nonblank_scores, 0.30)) - nonblank_mean = float(nonblank_scores.mean()) + mean_frame_conf = float(nonblank_scores.mean()) else: nonblank_min = -1.0 nonblank_q05 = -1.0 nonblank_q10 = -1.0 nonblank_q20 = -1.0 nonblank_q30 = -1.0 - nonblank_mean = -1.0 + mean_frame_conf = -1.0 if num_scores > 0: - confidence = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4 + q0_20_conf = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4 else: - confidence = 1.0 # default score for short utts + q0_20_conf = 1.0 # default, no frames hyps.append( { "token_spans": token_spans, "mean_token_conf": mean_token_conf, - "confidence": confidence, + "q0_20_conf": q0_20_conf, "num_scores": num_scores, - "nonblank_mean": nonblank_mean, + "mean_frame_conf": mean_frame_conf, "nonblank_min": nonblank_min, "nonblank_q05": nonblank_q05, "nonblank_q10": nonblank_q10, @@ -337,8 +327,6 @@ def align_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -351,18 +339,11 @@ def align_dataset( The neural model. sp: The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + Return a dict, whose key is "ctc_align" (alignment method). + Its value is a list of tuples. Each tuple has three elements: + a) the utterance_key, b) the reference transcript and c) a dictionary + with alignment results (token spans, confidences, etc). """ num_cuts = 0 @@ -387,8 +368,6 @@ def align_dataset( model=model, sp=sp, ignored_tokens=ignored_tokens_ints, - decoding_graph=decoding_graph, - word_table=word_table, batch=batch, ) @@ -408,6 +387,7 @@ def align_dataset( batch_str = f"{batch_idx}/{num_batches}" logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results @@ -453,21 +433,25 @@ def save_alignment_output( "(nonblank_min,q05,q10,q20,q30) (num_scores,num_tokens)", file=fd, ) # header - for key, ref_text, ali in results: + + for utterance_key, ref_text, ali in results: mean_token_conf = ali["mean_token_conf"] - mean_frame_conf = ali["nonblank_mean"] - q0_20_conf = ali["confidence"] + mean_frame_conf = ali["mean_frame_conf"] + q0_20_conf = ali["q0_20_conf"] min_ = ali["nonblank_min"] q05 = ali["nonblank_q05"] q10 = ali["nonblank_q10"] q20 = ali["nonblank_q20"] q30 = ali["nonblank_q30"] + num_scores = ali[ "num_scores" ] # scores used to compute `mean_frame_conf` + num_tokens = len(ali["token_spans"]) # tokens in ref transcript + print( - f"{key} {mean_token_conf:.4f} {mean_frame_conf:.4f} " + f"{utterance_key} {mean_token_conf:.4f} {mean_frame_conf:.4f} " f"{q0_20_conf:.4f} " f"({min_:.4f},{q05:.4f},{q10:.4f},{q20:.4f},{q30:.4f}) " f"({num_scores},{num_tokens})", @@ -530,7 +514,7 @@ def main(): # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") + params.unk_id = 
sp.piece_to_id("") # unknown character, not an OOV params.vocab_size = sp.get_piece_size() logging.info(params) @@ -645,8 +629,6 @@ def main(): params=params, model=model, sp=sp, - word_table=None, - decoding_graph=None, ) save_alignment_output(