zipformer/ctc_align.py
- tool for forced alignment with a CTC model
  - provides a timeline; computes per-token and per-utterance acoustic confidences
  - based on torchaudio `forced_align()`
  - confidences are computed in several ways

other modifications:
- LibriSpeechAsrDataModule extended with `::load_manifest()` to allow passing in a cutset from the CLI
- update @custom_fwd / @custom_bwd in scaling.py (torch.cuda.amp -> torch.amp)
- streaming_decode.py: update errs/recogs/log filenames ('-' <-> '_')
This commit is contained in:
parent 0c7ce5256f
commit 77357ebb06
@@ -402,6 +402,14 @@ class LibriSpeechAsrDataModule:
         )
         return test_dl
 
+    @lru_cache()
+    def load_manifest(self, manifest_filename: str) -> CutSet:
+        """
+        Load the 'manifest' specified by an argument.
+        """
+        logging.info(f"About to get '{manifest_filename}' cuts")
+        return load_manifest_lazy(manifest_filename)
+
     @lru_cache()
     def train_clean_5_cuts(self) -> CutSet:
         logging.info("mini_librispeech: About to get train-clean-5 cuts")
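The new `load_manifest()` is what lets an arbitrary cutset be passed in from the CLI, e.g. (sketch, hypothetical path): `cuts = asr_datamodule.load_manifest("data/fbank/test_cuts.jsonl.gz")`, as done in ctc_align.py below.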
egs/librispeech/ASR/zipformer/ctc_align.py (new executable file, 661 lines)

@@ -0,0 +1,661 @@
#!/usr/bin/env python3
#
# Copyright 2025 Brno University of Technology (Author: Karel Vesely)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Batch forced alignment with a CTC model (which can be a Transducer + CTC model).
It works with both causal and non-causal models.
Streaming is disabled, or simulated by attention masks
(see: --chunk-size --left-context-frames).
The whole utterance is processed by a single forward() call.

Note: model averaging is supported. With `--epoch 10 --avg 3`,
epochs 8-10 are averaged. Model averaging smooths
the CTC posteriors to some extent.

Usage:
(1) torchaudio forced_align()
./zipformer/ctc_align.py \
    --epoch 10 \
    --avg 3 \
    --exp-dir ./zipformer/exp \
    --max-duration 300 \
    --decoding-method ctc_align \
    /path/to/test_cuts.jsonl.gz
"""


import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path, PurePath
from typing import Dict, List, Optional, Tuple

import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule as AsrDataModule
from lhotse import set_caching_enabled
from torchaudio.functional import (
    forced_align,
    merge_tokens,
)
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import (
    AttributeDict,
    setup_logger,
    str2bool,
)

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it decodes with the model averaged "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch numbers `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--res-dir-suffix",
        type=str,
        default="",
        help="Suffix appended to the name of the directory where alignments are stored",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--ignored-tokens",
        type=str,
        nargs="+",
        default=[],
        help="Tokens to be excluded from the confidence computation "
        "(in addition to '<sos/eos>' and '<unk>')",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="ctc_align",
        choices=[
            "ctc_align",
        ],
        help="""Decoding method for doing the forced alignment.""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )

    parser.add_argument(
        "dataset_manifests",
        type=str,
        nargs="+",
        help="""Manifests of the test sets to be aligned""",
    )

    add_model_arguments(parser)

    return parser


def align_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    ignored_tokens: List[int],
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[dict]]:
    """Align one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for alignment.
      For now, just "ctc_align" is used.
    - value: It contains the alignment results. `len(value)` equals the
      batch size. `value[i]` is a dict with the alignment result for the
      i-th utterance in the given batch (token_spans + confidence summary).

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      ignored_tokens:
        Int token-codes to be ignored in the confidence computation.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.

      Unused for now; may be used later for aligning to a decoding graph:

      word_table [UNUSED]:
        The word symbol table.
      decoding_graph [UNUSED]:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.

    Returns:
      Return the alignment results. See the above description for the format
      of the returned dict.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    batch_size = feature.shape[0]

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        pad_len = 30
        feature_lens += pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )

    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

    hyps = []

    # tokenize the transcripts:
    text_encoded = sp.encode(supervisions["text"])

    # lengths
    num_tokens = [len(te) for te in text_encoded]
    max_tokens = max(num_tokens)

    # convert to a padded np.array:
    targets = np.array(
        [
            np.pad(seq, (0, max_tokens - len(seq)), "constant", constant_values=-1)
            for seq in text_encoded
        ]
    )

    # convert to tensors:
    targets = torch.tensor(targets, dtype=torch.int32, device=device)
    target_lengths = torch.tensor(num_tokens, dtype=torch.int32, device=device)

    # torchaudio 2.4.0+
    # The batch dimension of log_probs must be 1 in the current version,
    # hence the per-utterance loop:
    # https://github.com/pytorch/audio/blob/main/src/libtorchaudio/forced_align/gpu/compute.cu#L277
    for ii in range(batch_size):
        labels, log_probs = forced_align(
            log_probs=ctc_output[ii, : encoder_out_lens[ii]].unsqueeze(dim=0),
            targets=targets[ii, : target_lengths[ii]].unsqueeze(dim=0),
            input_lengths=encoder_out_lens[ii].unsqueeze(dim=0),
            target_lengths=target_lengths[ii].unsqueeze(dim=0),
            blank=0,
        )

        # per-token time and score
        token_spans = merge_tokens(labels[0], log_probs[0].exp())
        # int -> token
        for s in token_spans:
            s.token = sp.id_to_piece(s.token)
        # mean confidence from the per-token scores
        mean_token_conf = np.mean([token_span.score for token_span in token_spans])

        # confidences
        ignore_mask = labels == 0
        for tok in ignored_tokens:
            ignore_mask |= labels == tok

        nonblank_scores = log_probs[~ignore_mask].exp()
        num_scores = nonblank_scores.shape[0]

        if num_scores > 0:
            nonblank_min = float(nonblank_scores.min())
            nonblank_q05 = float(torch.quantile(nonblank_scores, 0.05))
            nonblank_q10 = float(torch.quantile(nonblank_scores, 0.10))
            nonblank_q20 = float(torch.quantile(nonblank_scores, 0.20))
            nonblank_q30 = float(torch.quantile(nonblank_scores, 0.30))
            nonblank_mean = float(nonblank_scores.mean())
        else:
            nonblank_min = -1.0
            nonblank_q05 = -1.0
            nonblank_q10 = -1.0
            nonblank_q20 = -1.0
            nonblank_q30 = -1.0
            nonblank_mean = -1.0

        if num_scores > 0:
            confidence = (nonblank_min + nonblank_q05 + nonblank_q10 + nonblank_q20) / 4
        else:
            confidence = 1.0  # default score for short utterances

        hyps.append(
            {
                "token_spans": token_spans,
                "mean_token_conf": mean_token_conf,
                "confidence": confidence,
                "num_scores": num_scores,
                "nonblank_mean": nonblank_mean,
                "nonblank_min": nonblank_min,
                "nonblank_q05": nonblank_q05,
                "nonblank_q10": nonblank_q10,
                "nonblank_q20": nonblank_q20,
                "nonblank_q30": nonblank_q30,
            }
        )

    return {"ctc_align": hyps}
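
# Illustration (not part of the original file): how the combined confidence
# above behaves on hypothetical non-blank frame posteriors, i.e. values such
# as `log_probs[~ignore_mask].exp()` would produce. Averaging the minimum
# with the 5/10/20% quantiles penalizes a few poorly aligned frames much
# more strongly than the plain mean does:
#
#   >>> scores = torch.tensor([0.91, 0.75, 0.98, 0.40, 0.88, 0.95])
#   >>> q = lambda p: float(torch.quantile(scores, p))
#   >>> (float(scores.min()) + q(0.05) + q(0.10) + q(0.20)) / 4
#   0.553  (approx.; the plain mean would be approx. 0.812)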


def align_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], dict]]]:
    """Align the dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to align.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key is the alignment setting (for now, just
      "ctc_align"). Its value is a list of tuples, one per utterance.
      Each tuple contains three elements: the cut id, the reference words,
      and the alignment result dict.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    ignored_tokens = params.ignored_tokens + ["<sos/eos>", "<unk>"]
    ignored_tokens_ints = [sp.piece_to_id(token) for token in ignored_tokens]

    logging.info(f"ignored tokens {ignored_tokens}")
    logging.info(f"ignored int codes {ignored_tokens_ints}")

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = align_one_batch(
            params=params,
            model=model,
            sp=sp,
            ignored_tokens=ignored_tokens_ints,
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
        )

        for name, alignments in hyps_dict.items():
            this_batch = []
            assert len(alignments) == len(texts)
            for cut_id, alignment, ref_text in zip(cut_ids, alignments, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, alignment))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        log_interval = 100
        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed so far: {num_cuts}")
    return results
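
# Illustration (hypothetical cut ids): the returned dict has one entry per
# alignment setting, e.g.:
#   {"ctc_align": [("cut-0001", ["HELLO", "WORLD"],
#                   {"token_spans": [...], "confidence": 0.85, ...}), ...]}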


def save_alignment_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], dict]]],
):
    """
    Save the token alignments and per-utterance confidences.
    """

    for key, results in results_dict.items():

        alignments_filename = params.res_dir / f"alignments-{test_set_name}.txt"

        time_step = 0.04  # frame shift of the CTC output, in seconds

        with open(alignments_filename, "w", encoding="utf8") as fd:
            for cut_id, ref_words, ali in results:
                for token_span in ali["token_spans"]:

                    t_beg = token_span.start * time_step
                    t_end = token_span.end * time_step
                    t_dur = t_end - t_beg
                    token = token_span.token
                    score = token_span.score

                    # CTM format: (wav_name, ch, t_beg, t_dur, token, score)
                    print(
                        f"{cut_id} A {t_beg:.2f} {t_dur:.2f} {token} {score:.6f}",
                        file=fd,
                    )

        logging.info(f"The alignments are stored in `{alignments_filename}`")

        # ---------------------------

        confidences_filename = params.res_dir / f"confidences-{test_set_name}.txt"

        with open(confidences_filename, "w", encoding="utf8") as fd:
            print(
                "utterance_key mean_token_conf mean_frame_conf q0-20_conf "
                "(nonblank_min,q05,q10,q20,q30) (num_scores,num_tokens)",
                file=fd,
            )  # header
            for cut_id, ref_words, ali in results:
                mean_token_conf = ali["mean_token_conf"]
                mean_frame_conf = ali["nonblank_mean"]
                q0_20_conf = ali["confidence"]
                min_ = ali["nonblank_min"]
                q05 = ali["nonblank_q05"]
                q10 = ali["nonblank_q10"]
                q20 = ali["nonblank_q20"]
                q30 = ali["nonblank_q30"]
                # scores used to compute `mean_frame_conf`
                num_scores = ali["num_scores"]
                # tokens in the reference transcript
                num_tokens = len(ali["token_spans"])
                print(
                    f"{cut_id} {mean_token_conf:.4f} {mean_frame_conf:.4f} "
                    f"{q0_20_conf:.4f} "
                    f"({min_:.4f},{q05:.4f},{q10:.4f},{q20:.4f},{q30:.4f}) "
                    f"({num_scores},{num_tokens})",
                    file=fd,
                )

        logging.info(f"The confidences are stored in `{confidences_filename}`")
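
# Illustration (hypothetical utterance key and values): a line of the
# CTM-style alignment file written above could look like
#   cut-0001 A 0.32 0.08 ▁HELLO 0.991234
# and the matching row of the confidences file like
#   cut-0001 0.9321 0.9013 0.8542 (0.7011,0.7423,0.7902,0.8455,0.8821) (42,11)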


@torch.no_grad()
def main():
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    # enable AudioCache
    set_caching_enabled(True)  # lhotse

    assert params.decoding_method in ("ctc_align",)
    assert params.enable_spec_aug is False
    assert params.use_ctc is True

    params.res_dir = params.exp_dir / (params.decoding_method + params.res_dir_suffix)

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"_chunk-{params.chunk_size}"
        params.suffix += f"_left-context-{params.left_context_frames}"

    params.suffix += f"_{params.decoding_method}"

    if params.use_averaged_model:
        params.suffix += "_use-averaged-model"

    setup_logger(f"{params.res_dir}/log-align-{params.suffix}")
    logging.info("Forced-alignment started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display the alignment results
    args.return_cuts = True
    asr_datamodule = AsrDataModule(args)

    # create a list of dataloaders (one per test set)
    testset_labels = []
    testset_dataloaders = []
    for testset_manifest in args.dataset_manifests:
        label = PurePath(testset_manifest).name  # basename
        label = label.replace(".jsonl.gz", "")

        test_cuts = asr_datamodule.load_manifest(testset_manifest)
        test_dataloader = asr_datamodule.test_dataloaders(test_cuts)

        testset_labels.append(label)
        testset_dataloaders.append(test_dataloader)

    # align
    for test_set, test_dl in zip(testset_labels, testset_dataloaders):
        results_dict = align_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=None,
            decoding_graph=None,
        )

        save_alignment_output(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
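
Note (illustrative): with the usage command from the module docstring and a
hypothetical manifest data/fbank/test_cuts.jsonl.gz, the script writes two
files per test set into ./zipformer/exp/ctc_align/: alignments-test_cuts.txt
(the CTM-style token timeline) and confidences-test_cuts.txt (the
per-utterance confidence summary).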

zipformer/scaling.py

@@ -24,7 +24,7 @@ import k2
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.amp import custom_bwd, custom_fwd
 
 from icefall.utils import torch_autocast
 
@@ -1306,7 +1306,7 @@ class MulForDropout3(torch.autograd.Function):
     # returns (x * y * alpha) where alpha is a float and y doesn't require
     # grad and is zero-or-one.
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, x, y, alpha):
         assert not y.requires_grad
         ans = x * y * alpha
@@ -1315,7 +1315,7 @@ class MulForDropout3(torch.autograd.Function):
         return ans
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, ans_grad):
         (ans,) = ctx.saved_tensors
         x_grad = ctx.alpha * ans_grad * (ans != 0)
@@ -1512,7 +1512,7 @@ def SwooshRForward(x: Tensor):
 
 class ActivationDropoutAndLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(
         ctx,
         x: Tensor,
@@ -1551,7 +1551,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
         return x
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, ans_grad: Tensor):
         saved = ctx.saved_tensors
         (x, weight, bias, dropout_mask) = saved
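
Background for the scaling.py change above: PyTorch 2.4 deprecated
torch.cuda.amp.custom_fwd/custom_bwd in favor of the torch.amp variants,
which take an explicit device_type. A minimal sketch of the new decorator
usage (illustrative toy class, not from this repo):

import torch
from torch.amp import custom_bwd, custom_fwd


class TimesTwo(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # forward runs under CUDA autocast rules
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd(device_type="cuda")  # grads are cast back to the forward dtype
    def backward(ctx, grad_output):
        return grad_output * 2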

zipformer/streaming_decode.py

@@ -641,15 +641,15 @@ def decode_dataset(
         del decode_streams[i]
 
     if params.decoding_method == "greedy_search":
-        key = "greedy_search"
+        key = "greedy-search"
     elif params.decoding_method == "fast_beam_search":
         key = (
-            f"beam_{params.beam}_"
-            f"max_contexts_{params.max_contexts}_"
-            f"max_states_{params.max_states}"
+            f"beam-{params.beam}_"
+            f"max-contexts-{params.max_contexts}_"
+            f"max-states-{params.max_states}"
         )
     elif params.decoding_method == "modified_beam_search":
-        key = f"num_active_paths_{params.num_active_paths}"
+        key = f"num-active-paths-{params.num_active_paths}"
     else:
         raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
     return {key: decode_results}

@@ -665,7 +665,7 @@ def save_asr_output(
     """
     for key, results in results_dict.items():
         recogs_filename = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}_{key}_{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recogs_filename, texts=results)

@@ -685,11 +685,11 @@ def save_wer_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}_{key}_{params.suffix}.txt"
         )
         with open(errs_filename, "w", encoding="utf8") as fd:
             wer = write_error_stats(
-                fd, f"{test_set_name}-{key}", results, enable_log=True
+                fd, f"{test_set_name}_{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
 
@@ -698,7 +698,7 @@ def save_wer_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
 
     wer_filename = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary_{test_set_name}_{key}_{params.suffix}.txt"
     )
     with open(wer_filename, "w", encoding="utf8") as fd:
         print("settings\tWER", file=fd)

@@ -729,9 +729,9 @@ def main():
     params.res_dir = params.exp_dir / "streaming" / params.decoding_method
 
     if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
     else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
 
     assert params.causal, params.causal
     assert "," not in params.chunk_size, "chunk_size should be one value in decoding."