Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-09 05:55:26 +00:00
Finalizing the code:
- adding some coderabbit suggestions
- removing `word_table`, `decoding_graph` from aligner API (unused)
- improved consistency of variable names (confidences)
- updated docstrings
This commit is contained in:
parent: d5ff66c56d
commit: 0fdba34a70
@@ -17,8 +17,8 @@
 # limitations under the License.
 
 """
-Batch aligning with CTC model (it can be Tranducer + CTC).
-It works with both causal an non-causal models.
+Batch aligning with a CTC model (it can be Transducer + CTC).
+It works with both causal and non-causal models.
 Streaming is disabled, or simulated by attention masks
 (see: --chunk-size --left-context-frames).
 Whole utterance processed by 1 forward() call.
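The docstring above says streaming can be simulated by attention masks via --chunk-size and --left-context-frames. As a rough illustration of that idea only (the helper below is hypothetical, not the model's actual masking code), a block-wise mask restricting each frame to its own chunk plus a bounded left context might look like:

# Illustrative sketch: a block-wise attention mask simulating streaming.
# `chunk_size` / `left_context_frames` mirror the CLI flags; the logic
# here is an assumption for this note, not code from the script.
import torch

def make_chunked_mask(
    num_frames: int, chunk_size: int, left_context_frames: int
) -> torch.Tensor:
    """True = attention allowed; future chunks are masked out."""
    idx = torch.arange(num_frames)
    chunk_of = idx // chunk_size
    # frame i may look at frame j if j's chunk is not in the future ...
    same_or_past_chunk = chunk_of.unsqueeze(1) >= chunk_of.unsqueeze(0)
    # ... and j is not too far in the past
    within_left = (idx.unsqueeze(1) - idx.unsqueeze(0)) <= left_context_frames
    return same_or_past_chunk & within_left

mask = make_chunked_mask(num_frames=8, chunk_size=2, left_context_frames=4)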
@@ -44,9 +44,8 @@ import logging
 import math
 from collections import defaultdict
 from pathlib import Path, PurePath
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 
-import k2
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -129,7 +128,7 @@ def get_parser():
         "--res-dir-suffix",
         type=str,
         default="",
-        help="Suffix to where alignments are stored",
+        help="Suffix to the directory where alignments are stored.",
     )
 
     parser.add_argument(
@@ -144,8 +143,9 @@ def get_parser():
         type=str,
         nargs="+",
         default=[],
-        help="List of tokens to ignore when computing confidence scores "
-        "(e.g., punctuation marks)",
+        help="List of BPE tokens to ignore when computing confidence scores "
+        "(e.g., punctuation marks). Each token is a separate arg: "
+        "`--ignore-tokens 'tok1' 'tok2' ...`",
     )
 
     parser.add_argument(
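Since `align_one_batch()` takes `ignored_tokens: set[int]` while this flag collects strings, the tokens presumably get mapped to BPE ids before alignment (the script passes an `ignored_tokens_ints` set). A minimal sketch of such a conversion; the helper name is hypothetical, while `sp.piece_to_id` and `sp.unk_id` are standard sentencepiece calls:

# Hypothetical helper: map `--ignore-tokens` strings to BPE ids.
import logging
import sentencepiece as spm

def tokens_to_ids(sp: spm.SentencePieceProcessor, tokens: list) -> set:
    ids = set()
    for tok in tokens:
        tok_id = sp.piece_to_id(tok)
        if tok_id == sp.unk_id():  # piece_to_id() maps unknown pieces to <unk>
            logging.warning(f"ignore-token {tok!r} not in BPE vocab, skipping")
            continue
        ids.add(tok_id)
    return ids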
@@ -169,7 +169,8 @@ def get_parser():
         "dataset_manifests",
         type=str,
         nargs="+",
-        help="CutSet manifests to be aligned (CurSet with features and transcripts)",
+        help="CutSet manifests to be aligned (CutSet with features and transcripts). "
+        "Each CutSet is a separate arg: `manifest1 manifest2 ...`",
     )
 
     add_model_arguments(parser)
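For reference, positional `dataset_manifests` arguments like these would typically be read with lhotse's manifest loader before batching; a sketch under that assumption (the function and loop are illustrative, not the script's code):

# Sketch: reading each CutSet manifest passed on the command line.
from lhotse import CutSet, load_manifest_lazy

def load_cut_sets(manifest_paths: list) -> list:
    cut_sets = []
    for path in manifest_paths:
        cuts: CutSet = load_manifest_lazy(path)  # lazy: features stay on disk
        cut_sets.append(cuts)
    return cut_sets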
@@ -183,8 +184,6 @@ def align_one_batch(
     sp: spm.SentencePieceProcessor,
     ignored_tokens: set[int],
     batch: dict,
-    word_table: Optional[k2.SymbolTable] = None,
-    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Align one batch and return the result in a dict. The dict has the
     following format:
@@ -208,15 +207,6 @@ def align_one_batch(
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-
-    UNUSED_PART, CAN BE USED LATER FOR ALIGNING TO A DECODING_GRAPH:
-
-      word_table [UNUSED]:
-        The word symbol table.
-      decoding_graph [UNUSED]:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
 
     Returns:
       Return the alignment result. See above description for the format of
       the returned dict.
@@ -275,7 +265,7 @@ def align_one_batch(
             targets=targets[ii, : target_lengths[ii]].unsqueeze(dim=0),
             input_lengths=encoder_out_lens[ii].unsqueeze(dim=0),
             target_lengths=target_lengths[ii].unsqueeze(dim=0),
-            blank=0,
+            blank=params.blank_id,
        )
 
        # per-token time, score
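The `blank=0` to `blank=params.blank_id` change matters whenever the model's blank id is not 0: hard-coding 0 would score the wrong symbol as blank. A self-contained toy version of this per-utterance CTC scoring pattern (shapes and values are made up for illustration):

# Toy example of per-utterance CTC scoring with an explicit blank id.
import torch

T, C = 50, 32                 # frames, vocab size (toy values)
blank_id = 0                  # in the script this comes from params.blank_id
log_probs = torch.randn(T, 1, C).log_softmax(dim=-1)   # (T, N=1, C)
targets = torch.randint(1, C, (1, 12))                 # one target sequence
loss = torch.nn.functional.ctc_loss(
    log_probs,
    targets,
    input_lengths=torch.tensor([T]),
    target_lengths=torch.tensor([12]),
    blank=blank_id,
    reduction="none",         # keep the per-utterance loss
)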
@@ -300,27 +290,27 @@ def align_one_batch(
             nonblank_q10 = float(torch.quantile(nonblank_scores, 0.10))
             nonblank_q20 = float(torch.quantile(nonblank_scores, 0.20))
             nonblank_q30 = float(torch.quantile(nonblank_scores, 0.30))
-            nonblank_mean = float(nonblank_scores.mean())
+            mean_frame_conf = float(nonblank_scores.mean())
         else:
             nonblank_min = -1.0
             nonblank_q05 = -1.0
             nonblank_q10 = -1.0
             nonblank_q20 = -1.0
             nonblank_q30 = -1.0
-            nonblank_mean = -1.0
+            mean_frame_conf = -1.0
 
         if num_scores > 0:
-            confidence = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4
+            q0_20_conf = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4
         else:
-            confidence = 1.0  # default score for short utts
+            q0_20_conf = 1.0  # default, no frames
 
         hyps.append(
             {
                 "token_spans": token_spans,
                 "mean_token_conf": mean_token_conf,
-                "confidence": confidence,
+                "q0_20_conf": q0_20_conf,
                 "num_scores": num_scores,
-                "nonblank_mean": nonblank_mean,
+                "mean_frame_conf": mean_frame_conf,
                 "nonblank_min": nonblank_min,
                 "nonblank_q05": nonblank_q05,
                 "nonblank_q10": nonblank_q10,
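The renames make the two statistics self-describing: `mean_frame_conf` is the mean of the non-blank frame confidences, and `q0_20_conf` averages the low tail (the minimum plus the 5/10/20 % quantiles), a deliberately pessimistic utterance-level score. Recomputing both on toy values:

# Toy recomputation of the renamed confidence statistics.
import torch

nonblank_scores = torch.tensor([0.91, 0.85, 0.99, 0.42, 0.77])  # made-up values
nonblank_min = float(nonblank_scores.min())
nonblank_q05 = float(torch.quantile(nonblank_scores, 0.05))
nonblank_q10 = float(torch.quantile(nonblank_scores, 0.10))
nonblank_q20 = float(torch.quantile(nonblank_scores, 0.20))
mean_frame_conf = float(nonblank_scores.mean())
# low-tail average; one bad stretch of frames drags it down:
q0_20_conf = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4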
@@ -337,8 +327,6 @@ def align_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    word_table: Optional[k2.SymbolTable] = None,
-    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -351,18 +339,11 @@ def align_dataset(
         The neural model.
       sp:
         The BPE model.
-      word_table:
-        The word symbol table.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
-      Return a dict, whose key may be "greedy_search" if greedy search
-      is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Return a dict, whose key is "ctc_align" (the alignment method).
+      Its value is a list of tuples. Each tuple is ternary, and it holds
+      a) the utterance_key, b) the reference transcript and c) a dictionary
+      with alignment results (token spans, confidences, etc.).
     """
     num_cuts = 0
 
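Given the documented return shape, a consumer of `align_dataset()` might unpack it like this (a sketch: the field names come from this diff, the summarizing function itself is hypothetical):

# Sketch: unpacking the {"ctc_align": [(key, ref, ali), ...]} result.
from typing import Dict, List, Tuple

def summarize_alignments(results: Dict[str, List[Tuple[str, str, dict]]]) -> None:
    for utterance_key, ref_text, ali in results["ctc_align"]:
        print(
            utterance_key,
            f'{ali["mean_token_conf"]:.4f}',
            f'{ali["mean_frame_conf"]:.4f}',
            f'{ali["q0_20_conf"]:.4f}',
        )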
@@ -387,8 +368,6 @@ def align_dataset(
             model=model,
             sp=sp,
             ignored_tokens=ignored_tokens_ints,
-            decoding_graph=decoding_graph,
-            word_table=word_table,
             batch=batch,
         )
 
@@ -408,6 +387,7 @@ def align_dataset(
             batch_str = f"{batch_idx}/{num_batches}"
 
             logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
     return results
 
 
@@ -453,21 +433,25 @@ def save_alignment_output(
             "(nonblank_min,q05,q10,q20,q30) (num_scores,num_tokens)",
             file=fd,
         )  # header
-        for key, ref_text, ali in results:
+
+        for utterance_key, ref_text, ali in results:
             mean_token_conf = ali["mean_token_conf"]
-            mean_frame_conf = ali["nonblank_mean"]
-            q0_20_conf = ali["confidence"]
+            mean_frame_conf = ali["mean_frame_conf"]
+            q0_20_conf = ali["q0_20_conf"]
             min_ = ali["nonblank_min"]
             q05 = ali["nonblank_q05"]
             q10 = ali["nonblank_q10"]
             q20 = ali["nonblank_q20"]
             q30 = ali["nonblank_q30"]
 
+            num_scores = ali[
+                "num_scores"
+            ]  # scores used to compute `mean_frame_conf`
 
             num_tokens = len(ali["token_spans"])  # tokens in ref transcript
 
             print(
-                f"{key} {mean_token_conf:.4f} {mean_frame_conf:.4f} "
+                f"{utterance_key} {mean_token_conf:.4f} {mean_frame_conf:.4f} "
                 f"{q0_20_conf:.4f} "
                 f"({min_:.4f},{q05:.4f},{q10:.4f},{q20:.4f},{q30:.4f}) "
                 f"({num_scores},{num_tokens})",
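The emitted file is whitespace-separated with the header shown above. Reading a non-header data line back could look like this (a hedged sketch; no such parser exists in the script):

# Hypothetical reader for one data line of the alignment output file.
def parse_alignment_line(line: str) -> dict:
    key, tok_conf, frame_conf, q0_20, quantiles, counts = line.split()
    min_, q05, q10, q20, q30 = (float(x) for x in quantiles.strip("()").split(","))
    num_scores, num_tokens = (int(x) for x in counts.strip("()").split(","))
    return {
        "utterance_key": key,
        "mean_token_conf": float(tok_conf),
        "mean_frame_conf": float(frame_conf),
        "q0_20_conf": float(q0_20),
        "nonblank_min": min_,
        "nonblank_q05": q05,
        "nonblank_q10": q10,
        "nonblank_q20": q20,
        "nonblank_q30": q30,
        "num_scores": num_scores,
        "num_tokens": num_tokens,
    }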
@@ -530,7 +514,7 @@ def main():
 
     # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
+    params.unk_id = sp.piece_to_id("<unk>")  # unknown character, not an OOV
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -645,8 +629,6 @@ def main():
         params=params,
         model=model,
         sp=sp,
-        word_table=None,
-        decoding_graph=None,
     )
 
     save_alignment_output(