From 77357ebb06de6f3e6f4e7051774e0d0540319d1d Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Mon, 8 Sep 2025 17:31:49 +0200 Subject: [PATCH] zipformer/ctc_align.py - tool for forced-alignment with CTC model - provides timeline, computes per-token and per-utterance acoustic confidences - based on torchaudio `forced_align()` - confidences are computed in several ways other modifications: - LibriSpeechAsrDataModel extended with `::load_manifest()` to allow passing-in cutset from CLI. - update @custom_fwd @custom_bwd in scaling.py - streaming_decode.py update errs/recogs/log filenames '-' <-> '_' --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 8 + egs/librispeech/ASR/zipformer/ctc_align.py | 661 ++++++++++++++++++ egs/librispeech/ASR/zipformer/scaling.py | 10 +- .../ASR/zipformer/streaming_decode.py | 22 +- 4 files changed, 685 insertions(+), 16 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/ctc_align.py diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 283252a46..d2f6db833 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -402,6 +402,14 @@ class LibriSpeechAsrDataModule: ) return test_dl + @lru_cache() + def load_manifest(self, manifest_filename: str) -> CutSet: + """ + Load the 'manifest' specified by an argument. + """ + logging.info(f"About to get '{manifest_filename}' cuts") + return load_manifest_lazy(manifest_filename) + @lru_cache() def train_clean_5_cuts(self) -> CutSet: logging.info("mini_librispeech: About to get train-clean-5 cuts") diff --git a/egs/librispeech/ASR/zipformer/ctc_align.py b/egs/librispeech/ASR/zipformer/ctc_align.py new file mode 100755 index 000000000..b68d9d589 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/ctc_align.py @@ -0,0 +1,661 @@ +#!/usr/bin/env python3 +# +# Copyright 2025 Brno University of Technology (Author: Karel Vesely) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Batch aligning with CTC model (it can be Tranducer + CTC). +It works with both causal an non-causal models. +Streaming is disabled, or simulated by attention masks +(see: --chunk-size --left-context-frames). +Whole utterance processed by 1 forward() call. + +Note: model averaging is present. With `--epoch 10 --avg 3`, +the epochs 8-10 are taken for averaging. Model averaging +is smoothing the CTC posteriors to some extent. + +Usage: +(1) torchaudio forced_align() +./zipformer/ctc_align.py \ + --epoch 10 \ + --avg 3 \ + --exp-dir ./zipformer/exp \ + --max-duration 300 \ + --decoding-method ctc_align + +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path, PurePath +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule as AsrDataModule +from lhotse import set_caching_enabled +from torchaudio.functional import ( + forced_align, + merge_tokens, +) +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + str2bool, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--res-dir-suffix", + type=str, + default="", + help="Suffix to where alignments are stored", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--ignored-tokens", + type=str, + nargs="+", + default=[], + help="", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc_align", + choices=[ + "ctc_align", + ], + help=""" Decoding method for doing the forced alignment.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "dataset_manifests", + type=str, + nargs="+", + help="""Manifests of test-sets to be evaluated""", + ) + + add_model_arguments(parser) + + return parser + + +def align_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + ignored_tokens: set[int], + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Align one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for alignment. + For now, just "ctc_alignment" is used. + - value: It contains the alignment result: (labels, log_probs). + `len(value)` equals to batch size. `value[i]` is the alignment + result for the i-th utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + ignored_tokens: + Set of int token-codes to be ignored for calculation of confidence. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + UNUSED_PART, CAN BE USED LATER FOR ALIGNING TO A DECODING_GRAPH: + + word_table [UNUSED]: + The word symbol table. + decoding_graph [UNUSED]: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + + Returns: + Return the alignment result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + batch_size = feature.shape[0] + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + hyps = [] + + # tokenize the transcripts: + text_encoded = sp.encode(supervisions["text"]) + + # lengths + num_tokens = [len(te) for te in text_encoded] + max_tokens = max(num_tokens) + + # convert to padded np.array: + targets = np.array( + [ + np.pad(seq, (0, max_tokens - len(seq)), "constant", constant_values=-1) + for seq in text_encoded + ] + ) + + # convert to tensor: + targets = torch.tensor(targets, dtype=torch.int32, device=device) + target_lengths = torch.tensor(num_tokens, dtype=torch.int32, device=device) + + # torchaudio2.4.0+ + # The batch dimension for log_probs must be 1 at the current version: + # https://github.com/pytorch/audio/blob/main/src/libtorchaudio/forced_align/gpu/compute.cu#L277 + for ii in range(batch_size): + labels, log_probs = forced_align( + log_probs=ctc_output[ii, : encoder_out_lens[ii]].unsqueeze(dim=0), + targets=targets[ii, : target_lengths[ii]].unsqueeze(dim=0), + input_lengths=encoder_out_lens[ii].unsqueeze(dim=0), + target_lengths=target_lengths[ii].unsqueeze(dim=0), + blank=0, + ) + + # per-token time, score + token_spans = merge_tokens(labels[0], log_probs[0].exp()) + # int -> token + for s in token_spans: + s.token = sp.id_to_piece(s.token) + # mean conf. from the per-token scores + mean_token_conf = np.mean([token_span.score for token_span in token_spans]) + + # confidences + ignore_mask = labels == 0 + for tok in ignored_tokens: + ignore_mask += labels == tok + + nonblank_scores = log_probs[~ignore_mask].exp() + num_scores = nonblank_scores.shape[0] + + if num_scores > 0: + nonblank_min = float(nonblank_scores.min()) + nonblank_q05 = float(torch.quantile(nonblank_scores, 0.05)) + nonblank_q10 = float(torch.quantile(nonblank_scores, 0.10)) + nonblank_q20 = float(torch.quantile(nonblank_scores, 0.20)) + nonblank_q30 = float(torch.quantile(nonblank_scores, 0.30)) + nonblank_mean = float(nonblank_scores.mean()) + else: + nonblank_min = -1.0 + nonblank_q05 = -1.0 + nonblank_q10 = -1.0 + nonblank_q20 = -1.0 + nonblank_q30 = -1.0 + nonblank_mean = -1.0 + + if num_scores > 0: + confidence = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4 + else: + confidence = 1.0 # default score for short utts + + hyps.append( + { + "token_spans": token_spans, + "mean_token_conf": mean_token_conf, + "confidence": confidence, + "num_scores": num_scores, + "nonblank_mean": nonblank_mean, + "nonblank_min": nonblank_min, + "nonblank_q05": nonblank_q05, + "nonblank_q10": nonblank_q10, + "nonblank_q20": nonblank_q20, + "nonblank_q30": nonblank_q30, + } + ) + + return {"ctc_align": hyps} + + +def align_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + ignored_tokens = params.ignored_tokens + ["", ""] + ignored_tokens_ints = [sp.piece_to_id(token) for token in ignored_tokens] + + logging.info(f"ignored tokens {ignored_tokens}") + logging.info(f"ignored int codes {ignored_tokens_ints}") + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = align_one_batch( + params=params, + model=model, + sp=sp, + ignored_tokens=ignored_tokens_ints, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, alignments in hyps_dict.items(): + this_batch = [] + assert len(alignments) == len(texts) + for cut_id, alignments, ref_text in zip(cut_ids, alignments, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, alignments)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + log_interval = 100 + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_alignment_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save the token alignments and per-utterance confidences. + """ + + for key, results in results_dict.items(): + + alignments_filename = params.res_dir / f"alignments-{test_set_name}.txt" + + time_step = 0.04 + + with open(alignments_filename, "w", encoding="utf8") as fd: + for key, ref_text, ali in results: + for token_span in ali["token_spans"]: + + t_beg = token_span.start * time_step + t_end = token_span.end * time_step + t_dur = t_end - t_beg + token = token_span.token + score = token_span.score + + # CTM format : (wav_name, ch, t_beg, t_dur, token, score) + print( + f"{key} A {t_beg:.2f} {t_dur:.2f} {token} {score:.6f}", file=fd + ) + + logging.info(f"The alignments are stored in `{alignments_filename}`") + + # --------------------------- + + confidences_filename = params.res_dir / f"confidences-{test_set_name}.txt" + + with open(confidences_filename, "w", encoding="utf8") as fd: + print( + "utterance_key mean_token_conf mean_frame_conf q0-20_conf " + "(nonblank_min,q05,q10,q20,q30) (num_scores,num_tokens)", + file=fd, + ) # header + for key, ref_text, ali in results: + mean_token_conf = ali["mean_token_conf"] + mean_frame_conf = ali["nonblank_mean"] + q0_20_conf = ali["confidence"] + min_ = ali["nonblank_min"] + q05 = ali["nonblank_q05"] + q10 = ali["nonblank_q10"] + q20 = ali["nonblank_q20"] + q30 = ali["nonblank_q30"] + num_scores = ali[ + "num_scores" + ] # scores used to compute `mean_frame_conf` + num_tokens = len(ali["token_spans"]) # tokens in ref transcript + print( + f"{key} {mean_token_conf:.4f} {mean_frame_conf:.4f} " + f"{q0_20_conf:.4f} " + f"({min_:.4f},{q05:.4f},{q10:.4f},{q20:.4f},{q30:.4f}) " + f"({num_scores},{num_tokens})", + file=fd, + ) + + logging.info(f"The confidences are stored in `{confidences_filename}`") + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ("ctc_align",) + assert params.enable_spec_aug is False + assert params.use_ctc is True + + params.res_dir = params.exp_dir / (params.decoding_method + params.res_dir_suffix) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + params.suffix += f"_{params.decoding_method}" + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-align-{params.suffix}") + logging.info("Forced-alignment started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + asr_datamodule = AsrDataModule(args) + + # create array of dataloaders (one per test-set) + testset_labels = [] + testset_dataloaders = [] + for testset_manifest in args.dataset_manifests: + label = PurePath(testset_manifest).name # basename + label = label.replace(".jsonl.gz", "") + + test_cuts = asr_datamodule.load_manifest(testset_manifest) + test_dataloader = asr_datamodule.test_dataloaders(test_cuts) + + testset_labels.append(label) + testset_dataloaders.append(test_dataloader) + + # align + for test_set, test_dl in zip(testset_labels, testset_dataloaders): + results_dict = align_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=None, + decoding_graph=None, + ) + + save_alignment_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 22aa1b1ca..5994f01bf 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -24,7 +24,7 @@ import k2 import torch import torch.nn as nn from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd +from torch.amp import custom_bwd, custom_fwd from icefall.utils import torch_autocast @@ -1306,7 +1306,7 @@ class MulForDropout3(torch.autograd.Function): # returns (x * y * alpha) where alpha is a float and y doesn't require # grad and is zero-or-one. @staticmethod - @custom_fwd + @custom_fwd(device_type='cuda') def forward(ctx, x, y, alpha): assert not y.requires_grad ans = x * y * alpha @@ -1315,7 +1315,7 @@ class MulForDropout3(torch.autograd.Function): return ans @staticmethod - @custom_bwd + @custom_bwd(device_type='cuda') def backward(ctx, ans_grad): (ans,) = ctx.saved_tensors x_grad = ctx.alpha * ans_grad * (ans != 0) @@ -1512,7 +1512,7 @@ def SwooshRForward(x: Tensor): class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod - @custom_fwd + @custom_fwd(device_type='cuda') def forward( ctx, x: Tensor, @@ -1551,7 +1551,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): return x @staticmethod - @custom_bwd + @custom_bwd(device_type='cuda') def backward(ctx, ans_grad: Tensor): saved = ctx.saved_tensors (x, weight, bias, dropout_mask) = saved diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index ebcafbf87..60e6d0fa8 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -641,15 +641,15 @@ def decode_dataset( del decode_streams[i] if params.decoding_method == "greedy_search": - key = "greedy_search" + key = "greedy-search" elif params.decoding_method == "fast_beam_search": key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" + f"beam-{params.beam}_" + f"max-contexts-{params.max_contexts}_" + f"max-states-{params.max_states}" ) elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" + key = f"num-active-paths-{params.num_active_paths}" else: raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -665,7 +665,7 @@ def save_asr_output( """ for key, results in results_dict.items(): recogs_filename = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}_{key}_{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recogs_filename, texts=results) @@ -685,11 +685,11 @@ def save_wer_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}_{key}_{params.suffix}.txt" ) with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - fd, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}_{key}", results, enable_log=True ) test_set_wers[key] = wer @@ -698,7 +698,7 @@ def save_wer_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) wer_filename = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary_{test_set_name}_{key}_{params.suffix}.txt" ) with open(wer_filename, "w", encoding="utf8") as fd: print("settings\tWER", file=fd) @@ -729,9 +729,9 @@ def main(): params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" assert params.causal, params.causal assert "," not in params.chunk_size, "chunk_size should be one value in decoding."