mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
add timestamps for ctc-decode in aishell
This commit is contained in:
parent
abd9437e6d
commit
1ab9421b23
736
egs/aishell/ASR/zipformer/ctc_decode.py
Executable file
736
egs/aishell/ASR/zipformer/ctc_decode.py
Executable file
@ -0,0 +1,736 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Liyong Guo,
|
||||
# Quandong Wang,
|
||||
# Zengwei Yao,
|
||||
# Zhifeng Han,)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) ctc-greedy-search (with cr-ctc)
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 50 \
|
||||
--avg 24 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-cr-ctc 1 \
|
||||
--use-ctc 1 \
|
||||
--use-transducer 0 \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-greedy-search
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.cut import Cut
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.context_graph import ContextGraph, ContextState
|
||||
from icefall.decode import (
|
||||
ctc_greedy_search,
|
||||
ctc_prefix_beam_search,
|
||||
ctc_prefix_beam_search_attention_decoder_rescoring,
|
||||
ctc_prefix_beam_search_shallow_fussion,
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder_no_ngram,
|
||||
rescore_with_attention_decoder_with_ngram,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
DecodingResults,
|
||||
make_pad_mask,
|
||||
setup_logger,
|
||||
store_transcripts_and_timestamps_withoutref,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
parse_hyp_and_timestamp_ch,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_char",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="ctc-greedy-search",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ilme-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for the internal language model estimation.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.causal:
|
||||
# this seems to cause insertions at the end of the utterance if used with zipformer.
|
||||
pad_len = 30
|
||||
feature_lens += pad_len
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, pad_len),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
x, x_lens = model.encoder_embed(feature, feature_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1:
|
||||
res = ctc_greedy_search(
|
||||
ctc_output=ctc_output,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
return_timestamps = True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
|
||||
hyps, timestamps = parse_hyp_and_timestamp_ch(
|
||||
res=res,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
word_table = lexicon.token_table,
|
||||
# frame_shift_ms=params.frame_shift_ms,
|
||||
)
|
||||
|
||||
key = f"blank_penalty_{params.blank_penalty}"
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
return {"ctc-greedy-search_" + key: (hyps, timestamps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key += f"_beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ilme_scale_{params.ilme_scale}"
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: (hyps, timestamps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}_" + key: (hyps, timestamps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
with_timestamp: bool = False,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
with_timestamp:
|
||||
Whether to decode with timestamp.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains 4 elements:
|
||||
Respectively, they are cut_id, the reference transcript, the predicted result and the decoded_timestamps.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list("".join(text.split())) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
if with_timestamp:
|
||||
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(hyps)
|
||||
for cut_id, hyp_words, ref_text, time_hyp in zip(
|
||||
cut_ids, hyps, texts, timestamps_hyp
|
||||
):
|
||||
this_batch.append((cut_id, ref_text, hyp_words, time_hyp))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
else:
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts_and_timestamps_withoutref(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
result_without_timestamp = [(res[0], res[1], res[2]) for res in results]
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
result_without_timestamp,
|
||||
enable_log=True,
|
||||
compute_CER=True,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
# params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"ctc-greedy-search",
|
||||
) # only support ctc-greedy-search
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
"," not in params.chunk_size
|
||||
), "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"_ilme_scale_{params.ilme_scale}"
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
params.device = device
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
|
||||
def remove_short_utt(c: Cut):
|
||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||
if T <= 0:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
|
||||
)
|
||||
return T > 0
|
||||
|
||||
dev_cuts = aishell.valid_cuts()
|
||||
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
||||
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_cuts = test_cuts.filter(remove_short_utt)
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["dev", "test"]
|
||||
test_dls = [dev_dl, test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
with_timestamp=True,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -64,6 +64,7 @@ from asr_datamodule import AishellAsrDataModule
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset import SpecAugment
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import AsrModel
|
||||
@ -239,6 +240,27 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
chunk left-context frames will be chosen randomly from this list; else not relevant.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-transducer",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="If True, use Transducer head.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ctc",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If True, use CTC head.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-cr-ctc",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If True, use consistency-regularized CTC.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -379,6 +401,27 @@ def get_parser():
|
||||
with this parameter before adding to the final loss.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ctc-loss-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Scale for CTC loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cr-loss-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Scale for consistency-regularization loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--time-mask-ratio",
|
||||
type=float,
|
||||
default=2.5,
|
||||
help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@ -582,8 +625,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
def get_model(params: AttributeDict) -> nn.Module:
|
||||
encoder_embed = get_encoder_embed(params)
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
if params.use_transducer:
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
else:
|
||||
decoder = None
|
||||
joiner = None
|
||||
|
||||
model = AsrModel(
|
||||
encoder_embed=encoder_embed,
|
||||
@ -593,9 +641,27 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
encoder_dim=int(max(params.encoder_dim.split(","))),
|
||||
decoder_dim=params.decoder_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
use_transducer=params.use_transducer,
|
||||
use_ctc=params.use_ctc,
|
||||
)
|
||||
return model
|
||||
|
||||
def get_spec_augment(params: AttributeDict) -> SpecAugment:
|
||||
num_frame_masks = int(10 * params.time_mask_ratio)
|
||||
max_frames_mask_fraction = 0.15 * params.time_mask_ratio
|
||||
logging.info(
|
||||
f"num_frame_masks: {num_frame_masks}, "
|
||||
f"max_frames_mask_fraction: {max_frames_mask_fraction}"
|
||||
)
|
||||
spec_augment = SpecAugment(
|
||||
time_warp_factor=0, # Do time warping in model.py
|
||||
num_frame_masks=num_frame_masks, # default: 10
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15
|
||||
)
|
||||
return spec_augment
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
@ -722,6 +788,7 @@ def compute_loss(
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
spec_augment: Optional[SpecAugment] = None,
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
"""
|
||||
Compute CTC loss given the model and its inputs.
|
||||
@ -738,6 +805,8 @@ def compute_loss(
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
spec_augment:
|
||||
The SpecAugment instance used only when use_cr_ctc is True.
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
@ -757,6 +826,21 @@ def compute_loss(
|
||||
y = graph_compiler.texts_to_ids(texts)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
use_cr_ctc = params.use_cr_ctc
|
||||
use_spec_aug = use_cr_ctc and is_training
|
||||
if use_spec_aug:
|
||||
supervision_intervals = batch["supervisions"]
|
||||
supervision_segments = torch.stack(
|
||||
[
|
||||
supervision_intervals["sequence_idx"],
|
||||
supervision_intervals["start_frame"],
|
||||
supervision_intervals["num_frames"],
|
||||
],
|
||||
dim=1,
|
||||
) # shape: (S, 3)
|
||||
else:
|
||||
supervision_segments = None
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
losses = model(
|
||||
x=feature,
|
||||
@ -765,25 +849,37 @@ def compute_loss(
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
use_cr_ctc=use_cr_ctc,
|
||||
use_spec_aug=use_spec_aug,
|
||||
spec_augment=spec_augment,
|
||||
supervision_segments=supervision_segments,
|
||||
time_warp_factor=params.spec_aug_time_warp_factor,
|
||||
)
|
||||
simple_loss, pruned_loss = losses[:2]
|
||||
simple_loss, pruned_loss, ctc_loss, _, cr_loss = losses[:5]
|
||||
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
# to params.simple_loss scale by warm_step.
|
||||
simple_loss_scale = (
|
||||
s
|
||||
if batch_idx_train >= warm_step
|
||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||
)
|
||||
pruned_loss_scale = (
|
||||
1.0
|
||||
if batch_idx_train >= warm_step
|
||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||
)
|
||||
loss = 0.0
|
||||
|
||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
if params.use_transducer:
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
# to params.simple_loss scale by warm_step.
|
||||
simple_loss_scale = (
|
||||
s
|
||||
if batch_idx_train >= warm_step
|
||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||
)
|
||||
pruned_loss_scale = (
|
||||
1.0
|
||||
if batch_idx_train >= warm_step
|
||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||
)
|
||||
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
if params.use_ctc:
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
if use_cr_ctc:
|
||||
loss += params.cr_loss_scale * cr_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
@ -793,8 +889,13 @@ def compute_loss(
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
if params.use_transducer:
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
if params.use_ctc:
|
||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||
if params.use_cr_ctc:
|
||||
info["cr_loss"] = cr_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -842,6 +943,7 @@ def train_one_epoch(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
spec_augment: Optional[SpecAugment] = None,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -868,6 +970,8 @@ def train_one_epoch(
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
spec_augment:
|
||||
The SpecAugment instance used only when use_cr_ctc is True.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
tb_writer:
|
||||
@ -917,6 +1021,7 @@ def train_one_epoch(
|
||||
graph_compiler=graph_compiler,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
spec_augment=spec_augment,
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
@ -1082,6 +1187,9 @@ def run(rank, world_size, args):
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
if not params.use_transducer:
|
||||
params.ctc_loss_scale = 1.0
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
@ -1090,6 +1198,12 @@ def run(rank, world_size, args):
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
if params.use_cr_ctc:
|
||||
assert not params.enable_spec_aug # we will do spec_augment in model.py
|
||||
spec_augment = get_spec_augment(params)
|
||||
else:
|
||||
spec_augment = None
|
||||
|
||||
assert params.save_every_n >= params.average_period
|
||||
model_avg: Optional[nn.Module] = None
|
||||
if rank == 0:
|
||||
@ -1199,6 +1313,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
graph_compiler=graph_compiler,
|
||||
params=params,
|
||||
spec_augment=spec_augment,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
@ -1226,6 +1341,7 @@ def run(rank, world_size, args):
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
scaler=scaler,
|
||||
spec_augment=spec_augment,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
@ -1292,6 +1408,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
params: AttributeDict,
|
||||
spec_augment: Optional[SpecAugment] = None,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
@ -1309,6 +1426,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
graph_compiler=graph_compiler,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
spec_augment=spec_augment,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.zero_grad()
|
||||
|
@ -26,7 +26,7 @@ import torch
|
||||
from icefall.context_graph import ContextGraph, ContextState
|
||||
from icefall.lm_wrapper import LmScorer
|
||||
from icefall.ngram_lm import NgramLm, NgramLmStateCost
|
||||
from icefall.utils import add_eos, add_sos, get_texts
|
||||
from icefall.utils import DecodingResults, add_eos, add_sos, get_texts
|
||||
|
||||
DEFAULT_LM_SCALE = [
|
||||
0.01,
|
||||
@ -1485,25 +1485,52 @@ def ctc_greedy_search(
|
||||
ctc_output: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
blank_id: int = 0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""CTC greedy search.
|
||||
|
||||
Args:
|
||||
ctc_output: (batch, seq_len, vocab_size)
|
||||
encoder_out_lens: (batch,)
|
||||
Returns:
|
||||
List[List[int]]: greedy search result
|
||||
Union[List[List[int]], DecodingResults]: greedy search result
|
||||
"""
|
||||
batch = ctc_output.shape[0]
|
||||
index = ctc_output.argmax(dim=-1) # (batch, seq_len)
|
||||
hyps = [
|
||||
torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch)
|
||||
]
|
||||
|
||||
hyps = [h[h != blank_id].tolist() for h in hyps]
|
||||
|
||||
return hyps
|
||||
|
||||
|
||||
hyps = [[] for _ in range(batch)]
|
||||
# timestamps[n][i] is the frame index after subsampling
|
||||
# on which hyp[n][i] is decoded
|
||||
timestamps = [[] for _ in range(batch)]
|
||||
# scores[n][i] is the logits on which hyp[n][i] is decoded
|
||||
scores = [[] for _ in range(batch)]
|
||||
|
||||
for i in range(batch):
|
||||
cur_pos = -1
|
||||
last = -1
|
||||
next_other = torch.where(index[i,0 : encoder_out_lens[i]] != last)[0]
|
||||
flag = False # whether decoding words
|
||||
while next_other.size(0) > 0:
|
||||
cur_pos = cur_pos + 1 + next_other[0].item()
|
||||
last = index[i,cur_pos].item()
|
||||
if flag: # last word decoding finished
|
||||
timestamps[i][-1] = timestamps[i][-1] + (cur_pos - 1, )
|
||||
if last != blank_id:
|
||||
hyps[i].append(last)
|
||||
timestamps[i].append((cur_pos,))
|
||||
scores[i].append(ctc_output[i, cur_pos, last].item())
|
||||
flag = True # Decoding words
|
||||
else:
|
||||
flag = False
|
||||
|
||||
next_other = torch.where(index[i,cur_pos + 1 : encoder_out_lens[i]] != last)[0]
|
||||
if flag:
|
||||
timestamps[i][-1] = timestamps[i][-1] + (encoder_out_lens[i].item() - 1, )
|
||||
|
||||
if not return_timestamps:
|
||||
return hyps
|
||||
else:
|
||||
return DecodingResults(hyps = hyps, timestamps = timestamps, scores = scores)
|
||||
|
||||
@dataclass
|
||||
class Hypothesis:
|
||||
|
@ -360,7 +360,7 @@ class KeywordResult:
|
||||
class DecodingResults:
|
||||
# timestamps[i][k] contains the frame number on which tokens[i][k]
|
||||
# is decoded
|
||||
timestamps: List[List[int]]
|
||||
timestamps: List[List[Union[int, Tuple[int, int]]]]
|
||||
|
||||
# hyps[i] is the recognition results, i.e., word IDs or token IDs
|
||||
# for the i-th utterance with fast_beam_search_nbest_LG.
|
||||
@ -583,6 +583,40 @@ def store_transcripts_and_timestamps(
|
||||
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
|
||||
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
||||
|
||||
def store_transcripts_and_timestamps_withoutref(
|
||||
filename: Pathlike,
|
||||
texts: Iterable[Tuple[str, List[str], List[str], List[Tuple[float]]]],
|
||||
) -> None:
|
||||
"""Save predicted results and reference transcripts as well as their timestamps
|
||||
to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
File to save the results to.
|
||||
texts:
|
||||
An iterable of tuples. The first element is the cur_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w", encoding="utf8") as f:
|
||||
for cut_id, ref, hyp, time_hyp in texts:
|
||||
print(f"{cut_id}:\tref={ref}", file=f)
|
||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||
|
||||
if len(time_hyp) > 0:
|
||||
if isinstance(time_hyp[0], tuple):
|
||||
# each element is <start, end> pair
|
||||
s = (
|
||||
"["
|
||||
+ ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp])
|
||||
+ "]"
|
||||
)
|
||||
else:
|
||||
# each element is a float number
|
||||
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
|
||||
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
@ -1839,6 +1873,30 @@ def convert_timestamp(
|
||||
|
||||
return time
|
||||
|
||||
def convert_timestamp_duration(
|
||||
frames: List[Tuple[int, int]],
|
||||
subsampling_factor: int,
|
||||
frame_shift_ms: float = 10,
|
||||
) -> List[Tuple[float, float]]:
|
||||
"""Convert frame numbers to time (in seconds) given subsampling factor
|
||||
and frame shift (in milliseconds).
|
||||
|
||||
Args:
|
||||
frames:
|
||||
A list of frame numbers after subsampling.
|
||||
subsampling_factor:
|
||||
The subsampling factor of the model.
|
||||
frame_shift_ms:
|
||||
Frame shift in milliseconds between two contiguous frames.
|
||||
Return:
|
||||
Return the time in seconds corresponding to each given frame.
|
||||
"""
|
||||
frame_shift = frame_shift_ms / 1000.0
|
||||
time = []
|
||||
for f_start, f_end in frames:
|
||||
time.append((round(f_start * subsampling_factor * frame_shift, ndigits=3), round(f_end * subsampling_factor * frame_shift, ndigits=3)))
|
||||
|
||||
return time
|
||||
|
||||
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
||||
"""
|
||||
@ -1924,6 +1982,43 @@ def parse_hyp_and_timestamp(
|
||||
|
||||
return hyps, timestamps
|
||||
|
||||
def parse_hyp_and_timestamp_ch(
|
||||
res: DecodingResults,
|
||||
subsampling_factor: int,
|
||||
word_table: k2.SymbolTable,
|
||||
frame_shift_ms: float = 10,
|
||||
) -> Tuple[List[List[str]], List[List[float]]]:
|
||||
"""Parse hypothesis and timestamps of Chinese characters.
|
||||
|
||||
Args:
|
||||
res:
|
||||
A DecodingResults object.
|
||||
subsampling_factor:
|
||||
The integer subsampling factor.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
frame_shift_ms:
|
||||
The float frame shift used for feature extraction.
|
||||
|
||||
Returns:
|
||||
Return a list of hypothesis and timestamp.
|
||||
"""
|
||||
hyps = []
|
||||
timestamps = []
|
||||
|
||||
N = len(res.hyps)
|
||||
assert len(res.timestamps) == N, (len(res.timestamps), N)
|
||||
assert word_table is not None
|
||||
|
||||
for i in range(N):
|
||||
time = convert_timestamp_duration(res.timestamps[i], subsampling_factor, frame_shift_ms)
|
||||
words_decoded = [word_table[idx] for idx in res.hyps[i]]
|
||||
hyps.append(words_decoded)
|
||||
timestamps.append(time)
|
||||
assert len(time) == len(words_decoded), (len(time), len(words_decoded))
|
||||
|
||||
return hyps, timestamps
|
||||
|
||||
|
||||
# `is_module_available` is copied from
|
||||
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
|
||||
|
Loading…
x
Reference in New Issue
Block a user