# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                    Mingshuang Luo,
#                                                    Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import collections
import json
import logging
import os
import pathlib
import random
import re
import subprocess
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import k2
import k2.version
import kaldialign
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from packaging import version
from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter

from icefall.checkpoint import average_checkpoints

Pathlike = Union[str, Path]

TORCH_VERSION = version.parse(torch.__version__)


def create_grad_scaler(device="cuda", **kwargs):
    """
    Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
    Accepts all kwargs like: enabled, init_scale, growth_factor, etc.

    It avoids the following warning:

    /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
    `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
    `torch.amp.GradScaler('cuda', args...)` instead.
    """
    if TORCH_VERSION >= version.parse("2.3.0"):
        from torch.amp import GradScaler

        return GradScaler(device=device, **kwargs)
    else:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            return torch.cuda.amp.GradScaler(**kwargs)


@contextmanager
def torch_autocast(device_type="cuda", **kwargs):
    """
    Fixes the following warning:

    /icefall/egs/librispeech/ASR/zipformer/model.py:323:
    FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
    Please use `torch.amp.autocast('cuda', args...)` instead.
      with torch.cuda.amp.autocast(enabled=False):
    """
    if TORCH_VERSION >= version.parse("2.3.0"):
        # Use the new unified API
        with torch.amp.autocast(device_type=device_type, **kwargs):
            yield
    else:
        # Suppress the deprecation warning and use the old
        # CUDA-specific autocast
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            with torch.cuda.amp.autocast(**kwargs):
                yield


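# A minimal AMP usage sketch for the two wrappers above (assumes a `model`,
# `features`, and an `optimizer` on a CUDA device; illustration only):
#
#     scaler = create_grad_scaler(enabled=True)
#     with torch_autocast(enabled=True):
#         loss = model(features).sum()
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()
#     optimizer.zero_grad()

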
# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
# Fixed: https://github.com/pytorch/pytorch/pull/49853
# The fix was included in v1.9.0
# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
def is_jit_tracing():
    if torch.jit.is_scripting():
        return False
    elif torch.jit.is_tracing():
        return True
    return False


@contextmanager
def get_executor():
    # We'll either return a process pool or a distributed worker pool.
    # Note that this has to be a context manager because we might use multiple
    # context managers ("with" clauses) inside, and this way everything will
    # free up the resources at the right time.
    try:
        # If this is executed on the CLSP grid, we will try to use the
        # Grid Engine to distribute the tasks.
        # Other clusters can also benefit from that, provided a
        # cluster-specific wrapper.
        # (see https://github.com/pzelasko/plz for reference)
        #
        # The following must be installed:
        # $ pip install dask distributed
        # $ pip install git+https://github.com/pzelasko/plz
        name = subprocess.check_output("hostname -f", shell=True, text=True)
        if name.strip().endswith(".clsp.jhu.edu"):
            import plz
            from distributed import Client

            with plz.setup_cluster() as cluster:
                cluster.scale(80)
                yield Client(cluster)
            return
    except Exception:
        pass
    # No need to return anything - compute_and_store_features
    # will just instantiate the pool itself.
    yield None


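# A minimal usage sketch (`process_one` and `items` are hypothetical; `ex` is
# either a dask distributed Client or None, depending on the cluster):
#
#     with get_executor() as ex:
#         if ex is None:
#             results = [process_one(x) for x in items]
#         else:
#             futures = [ex.submit(process_one, x) for x in items]
#             results = [fut.result() for fut in futures]

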
def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and the user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


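# A minimal argparse sketch using str2bool (the flag name is illustrative):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--use-fp16", type=str2bool, default=False)
#     args = parser.parse_args(["--use-fp16", "yes"])
#     assert args.use_fp16 is True

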
def setup_logger(
    log_filename: Pathlike,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Set up the logger and the log level.

    Args:
      log_filename:
        The filename to save the log.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical"
      use_console:
        True to also print logs to console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        log_filename = f"{log_filename}-{date_time}"

    os.makedirs(os.path.dirname(log_filename), exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
        force=True,
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)


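# A minimal sketch (the path is hypothetical): each run gets a timestamped
# log file, plus a per-rank suffix when torch.distributed is initialized:
#
#     setup_logger(log_filename="exp/log/train", log_level="info")
#     logging.info("training started")

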
class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")

    def __str__(self, indent: int = 2):
        tmp = {}
        for k, v in self.items():
            # PosixPath is not JSON serializable
            if isinstance(v, (pathlib.Path, torch.device, torch.dtype)):
                v = str(v)
            tmp[k] = v
        return json.dumps(tmp, indent=indent, sort_keys=True)


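# A minimal usage sketch: keys are accessible both as items and as attributes,
# and __str__ gives sorted, JSON-formatted output:
#
#     params = AttributeDict({"exp_dir": Path("exp"), "lr": 0.045})
#     params.num_epochs = 30
#     assert params["lr"] == params.lr
#     print(params)  # JSON, with exp_dir serialized as a string

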
def encode_supervisions(
    supervisions: dict,
    subsampling_factor: int,
    token_ids: Optional[List[List[int]]] = None,
) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]:
    """
    Encodes Lhotse's ``batch["supervisions"]`` dict into a pair: a torch
    Tensor, and a list of transcription strings or token indexes.

    The supervision tensor has shape ``(batch_size, 3)``.
    Its second dimension contains information about sequence index [0],
    start frames [1] and num frames [2].

    The batch items might become re-ordered during this operation -- the
    returned tensor and list of strings are guaranteed to be consistent with
    each other.
    """
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.div(
                supervisions["start_frame"],
                subsampling_factor,
                rounding_mode="floor",
            ),
            torch.div(
                supervisions["num_frames"],
                subsampling_factor,
                rounding_mode="floor",
            ),
        ),
        1,
    ).to(torch.int32)

    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    if token_ids is None:
        texts = supervisions["text"]
        res = [texts[idx] for idx in indices]
    else:
        res = [token_ids[idx] for idx in indices]

    return supervision_segments, res


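# A minimal synthetic sketch (values are made up for illustration): note that
# segments come back sorted by duration, longest first:
#
#     supervisions = {
#         "sequence_idx": torch.tensor([0, 1]),
#         "start_frame": torch.tensor([0, 0]),
#         "num_frames": torch.tensor([80, 120]),
#         "text": ["short utt", "longer utt"],
#     }
#     segments, texts = encode_supervisions(supervisions, subsampling_factor=4)
#     # segments == tensor([[1, 0, 30], [0, 0, 20]], dtype=torch.int32)
#     # texts == ["longer utt", "short utt"]

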
def get_texts(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
    """Extract the texts (as word IDs) from the best-path FSAs.
    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.
    Returns:
      Returns a list of lists of int, containing the label sequences we
      decoded, or a k2.RaggedTensor if return_ragged is True.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = aux_labels.remove_values_leq(0)

    assert aux_labels.num_axes == 2
    if return_ragged:
        return aux_labels
    else:
        return aux_labels.tolist()


def encode_supervisions_otc(
    supervisions: dict,
    subsampling_factor: int,
    token_ids: Optional[List[List[int]]] = None,
) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]], List[str], List[str]]:
    """
    Encodes Lhotse's ``batch["supervisions"]`` dict into a torch Tensor,
    a list of transcription strings or token indexes, a list of sorted cut
    IDs, and a list of sorted verbatim texts.

    The supervision tensor has shape ``(batch_size, 3)``.
    Its second dimension contains information about sequence index [0],
    start frames [1] and num frames [2].

    The batch items might become re-ordered during this operation -- the
    returned tensor and lists are guaranteed to be consistent with
    each other.
    """
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.div(
                supervisions["start_frame"],
                subsampling_factor,
                rounding_mode="floor",
            ),
            torch.div(
                supervisions["num_frames"],
                subsampling_factor,
                rounding_mode="floor",
            ),
        ),
        1,
    ).to(torch.int32)

    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    ids = []
    verbatim_texts = []
    sorted_ids = []
    sorted_verbatim_texts = []

    for cut in supervisions["cut"]:
        id = cut.id
        if hasattr(cut.supervisions[0], "verbatim_text"):
            verbatim_text = cut.supervisions[0].verbatim_text
        else:
            verbatim_text = ""
        ids.append(id)
        verbatim_texts.append(verbatim_text)

    for index in indices.tolist():
        sorted_ids.append(ids[index])
        sorted_verbatim_texts.append(verbatim_texts[index])

    if token_ids is None:
        texts = supervisions["text"]
        res = [texts[idx] for idx in indices]
    else:
        res = [token_ids[idx] for idx in indices]

    return supervision_segments, res, sorted_ids, sorted_verbatim_texts


@dataclass
class KeywordResult:
    # timestamps[k] contains the frame number on which tokens[k]
    # is decoded
    timestamps: List[int]

    # hyps is the keyword, i.e., word IDs or token IDs
    hyps: List[int]

    # The triggered phrase
    phrase: str


@dataclass
class DecodingResults:
    # timestamps[i][k] contains the frame number on which tokens[i][k]
    # is decoded
    timestamps: List[List[int]]

    # hyps[i] is the recognition results, i.e., word IDs or token IDs
    # for the i-th utterance with fast_beam_search_nbest_LG.
    hyps: Union[List[List[int]], k2.RaggedTensor]

    # scores[i][k] contains the log-prob of tokens[i][k]
    scores: Optional[List[List[float]]] = None


def get_texts_with_timestamp(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> DecodingResults:
    """Extract the texts (as word IDs) and timestamps (as frame indexes)
    from the best-path FSAs.
    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.
    Returns:
      Returns a DecodingResults object containing the decoded label
      sequences and the frame indexes on which they were emitted.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        all_aux_shape = (
            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
        )
        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        all_aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = all_aux_labels.remove_values_leq(0)

    assert aux_labels.num_axes == 2

    timestamps = []
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        for p in range(all_aux_labels.dim0):
            time = []
            for i, arc in enumerate(all_aux_labels[p].tolist()):
                if len(arc) == 1 and arc[0] > 0:
                    time.append(i)
            timestamps.append(time)
    else:
        for labels in all_aux_labels.tolist():
            time = [i for i, v in enumerate(labels) if v > 0]
            timestamps.append(time)

    return DecodingResults(
        timestamps=timestamps,
        hyps=aux_labels if return_ragged else aux_labels.tolist(),
    )


def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
    """Extract labels or aux_labels from the best-path FSAs.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      kind:
        Possible values are: "labels" and "aux_labels". Caution: When it is
        "labels", the resulting alignments contain repeats.
    Returns:
      Returns a list of lists of int, containing the token sequences we
      decoded. For `ans[i]`, its length equals the number of frames
      after subsampling of the i-th utterance in the batch.

    Example:
      When `kind` is `labels`, one possible alignment example is (with
      repeats)::

        c c c blk a a blk blk t t t blk blk

      If `kind` is `aux_labels`, the above example changes to::

        c blk blk blk a blk blk blk t blk blk blk blk

    """
    assert kind in ("labels", "aux_labels")
    # arcs.shape() has axes [fsa][state][arc]; we remove the "state" axis here
    token_shape = best_paths.arcs.shape().remove_axis(1)
    # token_shape has axes [fsa][arc]
    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
    tokens = tokens.remove_values_eq(-1)
    return tokens.tolist()


def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Save alignments to a file.

    Args:
      alignments:
        A dict containing alignments. Keys of the dict are utterances and
        values are the corresponding framewise alignments after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      filename:
        Path to save the alignments.
    Returns:
      Return None.
    """
    ali_dict = {
        "subsampling_factor": subsampling_factor,
        "alignments": alignments,
    }
    torch.save(ali_dict, filename)


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Load alignments from a file.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling_factor used to compute
          the alignments.
        - alignments: A dict containing utterances and their corresponding
          framewise alignment, after subsampling.
    """
    ali_dict = torch.load(filename, weights_only=False)
    subsampling_factor = ali_dict["subsampling_factor"]
    alignments = ali_dict["alignments"]
    return subsampling_factor, alignments


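# A round-trip sketch (the path and values are illustrative only):
#
#     save_alignments({"utt-1": [0, 0, 5, 5, 9]}, subsampling_factor=4,
#                     filename="exp/ali.pt")
#     factor, ali = load_alignments("exp/ali.pt")
#     assert factor == 4 and ali["utt-1"] == [0, 0, 5, 5, 9]

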
def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
    """Save predicted results and reference transcripts to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        If it is a multi-talker ASR system, the ref and hyp may also be lists of
        strings.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as f:
        for cut_id, ref, hyp in texts:
            if char_level:
                ref = list("".join(ref))
                hyp = list("".join(hyp))
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)


def store_transcripts_and_timestamps(
    filename: Pathlike,
    texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
) -> None:
    """Save predicted results and reference transcripts as well as their timestamps
    to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript, the third is the predicted result, and the
        fourth and fifth are the timestamps of the reference and the
        hypothesis, respectively.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as f:
        for cut_id, ref, hyp, time_ref, time_hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)

            if len(time_ref) > 0:
                if isinstance(time_ref[0], tuple):
                    # each element is a <start, end> pair
                    s = (
                        "["
                        + ", ".join(["(%.3f, %.3f)" % (i, j) for (i, j) in time_ref])
                        + "]"
                    )
                else:
                    # each element is a float number
                    s = "[" + ", ".join(["%.3f" % i for i in time_ref]) + "]"
                print(f"{cut_id}:\ttimestamp_ref={s}", file=f)

            if len(time_hyp) > 0:
                if isinstance(time_hyp[0], tuple):
                    # each element is a <start, end> pair
                    s = (
                        "["
                        + ", ".join(["(%.3f, %.3f)" % (i, j) for (i, j) in time_hyp])
                        + "]"
                    )
                else:
                    # each element is a float number
                    s = "[" + ", ".join(["%.3f" % i for i in time_hyp]) + "]"
                print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str]],
    enable_log: bool = True,
    compute_CER: bool = False,
    sclite_mode: bool = False,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      Return the total word error rate (WER) as a float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    if compute_CER:
        for i, res in enumerate(results):
            cut_id, ref, hyp = res
            ref = list("".join(ref))
            hyp = list("".join(hyp))
            results[i] = (cut_id, ref, hyp)

    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)


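# A minimal sketch with toy data (refs/hyps are lists of words per utterance):
#
#     results = [
#         ("utt-1", "the cat sat".split(), "the cat sat".split()),
#         ("utt-2", "a dog ran".split(), "a dog runs".split()),
#     ]
#     with open("errs.txt", "w") as f:
#         wer = write_error_stats(f, "test-set", results)
#     # wer == 16.67  (1 substitution over 6 reference words)

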
def write_error_stats_with_timestamps(
    f: TextIO,
    test_set_name: str,
    results: List[
        Tuple[
            str,
            List[str],
            List[str],
            List[Union[float, Tuple[float, float]]],
            List[Union[float, Tuple[float, float]]],
        ]
    ],
    enable_log: bool = True,
    with_end_time: bool = False,
) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]:
    """Write statistics based on predicted results and reference transcripts
    as well as their timestamps.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
      with_end_time:
        Whether to use end timestamps.

    Returns:
      Return the total word error rate, the mean symbol delay and the
      delay variance.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    # Compute mean alignment delay on the correct words
    all_delay = []
    for cut_id, ref, hyp, time_ref, time_hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        has_time = len(time_ref) > 0 and len(time_hyp) > 0
        if has_time:
            # pointer to timestamp_hyp
            p_hyp = 0
            # pointer to timestamp_ref
            p_ref = 0
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
                if has_time:
                    p_hyp += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
                if has_time:
                    p_ref += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
                if has_time:
                    p_hyp += 1
                    p_ref += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
                if has_time:
                    if with_end_time:
                        all_delay.append(
                            (
                                time_hyp[p_hyp][0] - time_ref[p_ref][0],
                                time_hyp[p_hyp][1] - time_ref[p_ref][1],
                            )
                        )
                    else:
                        all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
                    p_hyp += 1
                    p_ref += 1
        if has_time:
            assert p_hyp == len(hyp), (p_hyp, len(hyp))
            assert p_ref == len(ref), (p_ref, len(ref))

    ref_len = sum([len(r) for _, r, _, _, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len))

    if with_end_time:
        mean_delay = (float("inf"), float("inf"))
        var_delay = (float("inf"), float("inf"))
    else:
        mean_delay = float("inf")
        var_delay = float("inf")
    num_delay = len(all_delay)
    if num_delay > 0:
        if with_end_time:
            all_delay_start = [i[0] for i in all_delay]
            mean_delay_start = sum(all_delay_start) / num_delay
            var_delay_start = (
                sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay
            )

            all_delay_end = [i[1] for i in all_delay]
            mean_delay_end = sum(all_delay_end) / num_delay
            var_delay_end = (
                sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay
            )

            mean_delay = (
                float("%.3f" % mean_delay_start),
                float("%.3f" % mean_delay_end),
            )
            var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end))
        else:
            mean_delay = sum(all_delay) / num_delay
            var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
            mean_delay = float("%.3f" % mean_delay)
            var_delay = float("%.3f" % var_delay)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )
        logging.info(
            f"[{test_set_name}] %symbol-delay mean (s): "
            f"{mean_delay}, variance: {var_delay} "  # noqa
            f"computed on {num_delay} correct words"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
    for cut_id, ref, hyp, _, _ in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    # Note: when `with_end_time` is True, mean_delay and var_delay are
    # (start, end) tuples, so they must not be cast to float here.
    return tot_err_rate, mean_delay, var_delay


def write_surt_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str]],
    enable_log: bool = True,
    num_channels: int = 2,
) -> float:
    """Write statistics based on predicted results and reference transcripts for SURT
    multi-talker ASR systems. The difference between this and `write_error_stats`
    is that this function finds the optimal speaker-agnostic WER using the ``meeteval``
    toolkit.

    Args:
      f: File to write the statistics to.
      test_set_name: Name of the test set.
      results: List of tuples containing the utterance ID, the reference
        transcripts, and the predicted transcripts.
      enable_log: Whether to enable logging.
      num_channels: Number of output channels/branches. Defaults to 2.
    Returns:
      Return the total word error rate (WER) as a float.
    """
    from meeteval.wer import wer

    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)
    ref_lens: List[int] = []

    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for cut_id, ref, hyp in results:
        # First compute the optimal assignment of references to output channels
        orc_wer = wer.orc_word_error_rate(ref, hyp)
        assignment = orc_wer.assignment
        refs = [[] for _ in range(num_channels)]
        # Assign references to channels
        for i, ref_text in zip(assignment, ref):
            refs[i] += ref_text.split()
        hyps = [hyp_text.split() for hyp_text in hyp]
        # Now compute the WER for each channel
        for ref_c, hyp_c in zip(refs, hyps):
            ref_lens.append(len(ref_c))
            ali = kaldialign.align(ref_c, hyp_c, ERR)
            for ref_word, hyp_word in ali:
                if ref_word == ERR:
                    ins[hyp_word] += 1
                    words[hyp_word][3] += 1
                elif hyp_word == ERR:
                    dels[ref_word] += 1
                    words[ref_word][4] += 1
                elif hyp_word != ref_word:
                    subs[(ref_word, hyp_word)] += 1
                    words[ref_word][1] += 1
                    words[hyp_word][2] += 1
                else:
                    words[ref_word][0] += 1
                    num_corr += 1
            combine_successive_errors = True
            if combine_successive_errors:
                ali = [[[x], [y]] for x, y in ali]
                for i in range(len(ali) - 1):
                    if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                        ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                        ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                        ali[i] = [[], []]
                ali = [
                    [
                        list(filter(lambda a: a != ERR, x)),
                        list(filter(lambda a: a != ERR, y)),
                    ]
                    for x, y in ali
                ]
                ali = list(filter(lambda x: x != [[], []], ali))
                ali = [
                    [
                        ERR if x == [] else " ".join(x),
                        ERR if y == [] else " ".join(y),
                    ]
                    for x, y in ali
                ]

            print(
                f"{cut_id}:\t"
                + " ".join(
                    (
                        (
                            ref_word
                            if ref_word == hyp_word
                            else f"({ref_word}->{hyp_word})"
                        )
                        for ref_word, hyp_word in ali
                    )
                ),
                file=f,
            )
    ref_len = sum(ref_lens)
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

    print(f"%WER = {tot_err_rate}", file=f)
    return float(tot_err_rate)


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class will play a role as metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            # The check `v - v == 0` filters out inf and nan values.
            if v - v == 0:
                ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans_frames = ""
        ans_utterances = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            if "utt_" not in k:
                ans_frames += str(k) + "=" + str(norm_value) + ", "
            else:
                ans_utterances += str(k) + "=" + str(norm_value)
                if k == "utt_duration":
                    ans_utterances += " frames, "
                elif k == "utt_pad_proportion":
                    ans_utterances += ", "
                else:
                    raise ValueError(f"Unexpected key: {k}")
        frames = "%.2f" % self["frames"]
        ans_frames += "over " + str(frames) + " frames. "
        if ans_utterances != "":
            utterances = "%.2f" % self["utterances"]
            ans_utterances += "over " + str(utterances) + " utterances."

        return ans_frames + ans_utterances

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self["frames"] if "frames" in self else 1
        num_utterances = self["utterances"] if "utterances" in self else 1
        ans = []
        for k, v in self.items():
            if k == "frames" or k == "utterances":
                continue
            norm_value = (
                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
            )
            ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Reduce using torch.distributed, which I believe ensures that
        all processes get the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
          tb_writer: a TensorBoard writer
          prefix: a prefix for the name of the loss, e.g. "train/valid_",
            or "train/current_"
          batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)


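# A minimal accumulation sketch (toy numbers): values are summed across
# batches and normalized per frame/utterance when reported:
#
#     info = MetricsTracker()
#     info["frames"] = 100
#     info["ctc_loss"] = 25.0
#     total = info + info          # frames=200, ctc_loss=50.0
#     print(total.norm_items())    # [('ctc_loss', 0.25)]

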
def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
    """Prepend a value to the beginning of each sublist or append a value
    to the end of each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      value:
        The value to prepend or append.
      direction:
        It can be either "left" or "right". If it is "left", we
        prepend the value to the beginning of each sublist;
        if it is "right", we append the value to the end of each
        sublist.

    Returns:
      Return a new ragged tensor, whose sublists either start with
      or end with the given value.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> concat(a, value=0, direction="left")
    [ [ 0 1 3 ] [ 0 5 ] ]
    >>> concat(a, value=0, direction="right")
    [ [ 1 3 0 ] [ 5 0 ] ]

    """
    dtype = ragged.dtype
    device = ragged.device

    assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
    pad_values = torch.full(
        size=(ragged.tot_size(0), 1),
        fill_value=value,
        device=device,
        dtype=dtype,
    )
    pad = k2.RaggedTensor(pad_values)

    if direction == "left":
        ans = k2.ragged.cat([pad, ragged], axis=1)
    elif direction == "right":
        ans = k2.ragged.cat([ragged, pad], axis=1)
    else:
        raise ValueError(
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right"'
        )
    return ans


def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
    """Add SOS to each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      sos_id:
        The ID of the SOS symbol.

    Returns:
      Return a new ragged tensor, where each sublist starts with SOS.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> add_sos(a, sos_id=0)
    [ [ 0 1 3 ] [ 0 5 ] ]

    """
    return concat(ragged, sos_id, direction="left")


def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
    """Add EOS to each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      eos_id:
        The ID of the EOS symbol.

    Returns:
      Return a new ragged tensor, where each sublist ends with EOS.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> add_eos(a, eos_id=0)
    [ [ 1 3 0 ] [ 5 0 ] ]

    """
    return concat(ragged, eos_id, direction="right")


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    return expanded_lengths >= lengths.unsqueeze(-1)


# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size;
    this is for a streaming encoder.
    Args:
      size (int): size of mask
      chunk_size (int): size of chunk
      num_left_chunks (int): number of left chunks
        <0: use full chunk
        >=0: use num_left_chunks
      device (torch.device): "cpu" or "cuda" or torch.Tensor.device
    Returns:
      torch.Tensor: mask
    Examples:
      >>> subsequent_chunk_mask(4, 2)
      [[1, 1, 0, 0],
       [1, 1, 0, 0],
       [1, 1, 1, 1],
       [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret


def l1_norm(x):
    return torch.sum(torch.abs(x))


def l2_norm(x):
    return torch.sum(torch.pow(x, 2))


def linf_norm(x):
    return torch.max(torch.abs(x))


def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
    """
    Compute the norms of the model's parameters.

    :param model: a torch.nn.Module instance
    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
    :return: a dict mapping from parameter's name to its norm.
    """
    with torch.no_grad():
        norms = {}
        for name, param in model.named_parameters():
            if norm == "l1":
                val = l1_norm(param)
            elif norm == "l2":
                val = l2_norm(param)
            elif norm == "linf":
                val = linf_norm(param)
            else:
                raise ValueError(f"Unknown norm type: {norm}")
            norms[name] = val.item()
        return norms


def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
    """
    Compute the norms of the gradients for each of model's parameters.

    :param model: a torch.nn.Module instance
    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
    :return: a dict mapping from parameter's name to its gradient's norm.
    """
    with torch.no_grad():
        norms = {}
        for name, param in model.named_parameters():
            if norm == "l1":
                val = l1_norm(param.grad)
            elif norm == "l2":
                val = l2_norm(param.grad)
            elif norm == "linf":
                val = linf_norm(param.grad)
            else:
                raise ValueError(f"Unknown norm type: {norm}")
            norms[name] = val.item()
        return norms


def get_parameter_groups_with_lrs(
    model: nn.Module,
    lr: float,
    include_names: bool = False,
    freeze_modules: List[str] = [],
) -> List[dict]:
    """
    This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
    named-parameters; we can, if needed, create a version without the names).

    It provides a way to specify learning-rate scales inside the module, so that if
    any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
    scale the LR of any parameters inside that module or its submodules. Note: you
    can set module parameters outside the __init__ function, e.g.:
      >>> a = nn.Linear(10, 10)
      >>> a.lr_scale = 0.5

    Returns: a list of dicts, of the following form:
      if include_names == False:
        [ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
          { 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
          ... ]
      if include_names == True:
        [ { 'named_params': [ (name1, tensor1), (name2, tensor2), ... ], 'lr': 0.01 },
          { 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
          ... ]

    """
    named_modules = list(model.named_modules())

    # flat_lr_scale just contains the lr_scale explicitly specified
    # for each prefix of the name, e.g. 'encoder.layers.3'; these need
    # to be multiplied for all prefixes of the name of any given parameter.
    flat_lr_scale = defaultdict(lambda: 1.0)
    names = []
    for name, m in model.named_modules():
        names.append(name)
        if hasattr(m, "lr_scale"):
            flat_lr_scale[name] = m.lr_scale

    # lr_to_params is a dict from learning rate (floating point) to: if
    # include_names == True, a list of (name, parameter) for that learning rate;
    # otherwise a list of parameters for that learning rate.
    lr_to_params = defaultdict(list)

    for name, parameter in model.named_parameters():
        split_name = name.split(".")
        # caution: as a special case, if the name is '', split_name will be [ '' ].
        prefix = split_name[0]
        if prefix == "module":  # DDP
            module_name = split_name[1]
            if module_name in freeze_modules:
                logging.info(f"Remove {name} from parameters")
                continue
        else:
            if prefix in freeze_modules:
                logging.info(f"Remove {name} from parameters")
                continue
        cur_lr = lr * flat_lr_scale[prefix]
        if prefix != "":
            cur_lr *= flat_lr_scale[""]
        for part in split_name[1:]:
            prefix = ".".join([prefix, part])
            cur_lr *= flat_lr_scale[prefix]
        lr_to_params[cur_lr].append((name, parameter) if include_names else parameter)

    if include_names:
        return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()]
    else:
        return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()]


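# A minimal sketch with a plain torch optimizer; with include_names=False the
# returned dicts are standard param groups (the lr_scale value is illustrative):
#
#     model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
#     model[0].lr_scale = 0.5  # parameters under model[0] get 0.5 * lr
#     groups = get_parameter_groups_with_lrs(model, lr=0.01)
#     optimizer = torch.optim.AdamW(groups)

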
def optim_step_and_measure_param_change(
    model: nn.Module,
    old_parameters: Dict[str, nn.parameter.Parameter],
) -> Dict[str, float]:
    r"""
    Measure the "relative change in parameters per minibatch."
    It is understood as a ratio between the L2 norm of the difference between
    original and updated parameters, and the L2 norm of the original parameter.
    It is given by the formula:

    .. math::

        \begin{aligned}
            \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
        \end{aligned}

    This function is supposed to be used as follows:

    .. code-block:: python

        old_parameters = {
            n: p.detach().clone() for n, p in model.named_parameters()
        }

        optimizer.step()

        deltas = optim_step_and_measure_param_change(model, old_parameters)

    Args:
      model: A torch.nn.Module instance.
      old_parameters:
        A Dict of named_parameters before optimizer.step().

    Return:
      A Dict containing the relative change for each parameter.
    """
    relative_change = {}
    with torch.no_grad():
        for n, p_new in model.named_parameters():
            p_orig = old_parameters[n]
            delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
            relative_change[n] = delta.item()
    return relative_change


def load_averaged_model(
    model_dir: str,
    model: torch.nn.Module,
    epoch: int,
    avg: int,
    device: torch.device,
):
    """
    Load a model which is the average of all checkpoints

    :param model_dir: a str of the experiment directory
    :param model: a torch.nn.Module instance
    :param epoch: the last epoch to load from
    :param avg: how many models to average from
    :param device: move model to this device
    :return: the averaged model
    """
    # start cannot be negative
    start = max(epoch - avg + 1, 0)
    filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]

    logging.info(f"averaging {filenames}")
    model.to(device)
    model.load_state_dict(average_checkpoints(filenames, device=device))

    return model


def text_to_pinyin(
    txt: str, mode: str = "full_with_tone", errors: str = "default"
) -> List[str]:
    """
    Convert a Chinese text (which might contain some Latin characters) to a
    pinyin sequence.

    Args:
      txt:
        The input Chinese text.
      mode:
        The style of the output pinyin, should be:
          full_with_tone : zhōng guó
          full_no_tone : zhong guo
          partial_with_tone : zh ōng g uó
          partial_no_tone : zh ong g uo
      errors:
        How to handle the (Latin) characters that have no pinyin.
          default : output the same as input.
          split : split into single characters (i.e., letters)

    Return:
      Return a list of str.

    Examples:
      txt: 想吃KFC
      output: ['xiǎng', 'chī', 'KFC']          # mode=full_with_tone; errors=default
      output: ['xiǎng', 'chī', 'K', 'F', 'C']  # mode=full_with_tone; errors=split
      output: ['xiang', 'chi', 'KFC']          # mode=full_no_tone; errors=default
      output: ['xiang', 'chi', 'K', 'F', 'C']  # mode=full_no_tone; errors=split
      output: ['x', 'iǎng', 'ch', 'ī', 'KFC']  # mode=partial_with_tone; errors=default
      output: ['x', 'iang', 'ch', 'i', 'KFC']  # mode=partial_no_tone; errors=default
    """

    assert mode in (
        "full_with_tone",
        "full_no_tone",
        "partial_no_tone",
        "partial_with_tone",
    ), mode

    assert errors in ("default", "split"), errors

    txt = txt.strip()
    res = []
    if "full" in mode:
        if errors == "default":
            py = pinyin(txt) if mode == "full_with_tone" else lazy_pinyin(txt)
        else:
            py = (
                pinyin(txt, errors=lambda x: list(x))
                if mode == "full_with_tone"
                else lazy_pinyin(txt, errors=lambda x: list(x))
            )
        res = [x[0] for x in py] if mode == "full_with_tone" else py
    else:
        if errors == "default":
            py = pinyin(txt) if mode == "partial_with_tone" else lazy_pinyin(txt)
        else:
            py = (
                pinyin(txt, errors=lambda x: list(x))
                if mode == "partial_with_tone"
                else lazy_pinyin(txt, errors=lambda x: list(x))
            )
        py = [x[0] for x in py] if mode == "partial_with_tone" else py
        for x in py:
            initial = to_initials(x, strict=False)
            final = (
                to_finals(x, strict=False)
                if mode == "partial_no_tone"
                else to_finals_tone(x, strict=False)
            )
            if initial == "" and final == "":
                res.append(x)
            else:
                if initial != "":
                    res.append(initial)
                if final != "":
                    res.append(final)
    return res


def tokenize_by_bpe_model(
    sp: spm.SentencePieceProcessor,
    txt: str,
) -> str:
    """
    Tokenize text with a bpe model. This function is from
    https://github1s.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py#L322-L342.
    Args:
      sp: spm.SentencePieceProcessor.
      txt: str

    Return:
      A new string which includes chars and bpes.
    """
    tokens = []
    # The CJK (China Japan Korea) Unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    pattern = re.compile(r"([\u4e00-\u9fff])")
    # Example:
    #   txt   = "你好 IT'S OKAY 的"
    #   chars = ["你", "好", " IT'S OKAY ", "的"]
    chars = pattern.split(txt.upper())
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        # ch_or_w is a single CJK character (i.e., "你"), do nothing.
        if pattern.fullmatch(ch_or_w) is not None:
            tokens.append(ch_or_w)
        # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "),
        # encode ch_or_w using the bpe model.
        else:
            for p in sp.encode_as_pieces(ch_or_w):
                tokens.append(p)
    txt_with_bpe = "/".join(tokens)

    return txt_with_bpe


def tokenize_by_CJK_char(line: str) -> str:
    """
    Tokenize a line of text with CJK chars.

    Note: All return characters will be upper case.

    Example:
      input = "你好世界是 hello world 的中文"
      output = "你 好 世 界 是 HELLO WORLD 的 中 文"

    Args:
      line:
        The input text.

    Return:
      A new string tokenized by CJK chars.
    """
    # The CJK ranges are from
    # https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
    pattern = re.compile(
        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
    )
    chars = pattern.split(line.strip().upper())
    return " ".join([w.strip() for w in chars if w.strip()])


def tokenize_by_ja_char(line: str) -> str:
|
|
"""
|
|
Tokenize a line of text with Japanese characters.
|
|
|
|
Note: All non-Japanese characters will be upper case.
|
|
|
|
Example:
|
|
input = "こんにちは世界は hello world の日本語"
|
|
output = "こ ん に ち は 世 界 は HELLO WORLD の 日 本 語"
|
|
|
|
Args:
|
|
line:
|
|
The input text.
|
|
|
|
Return:
|
|
A new string tokenized by Japanese characters.
|
|
"""
|
|
pattern = re.compile(r"([\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF])")
|
|
chars = pattern.split(line.strip())
|
|
return " ".join(
|
|
[w.strip().upper() if not pattern.match(w) else w for w in chars if w.strip()]
|
|
)


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
    sp: spm.SentencePieceProcessor,
) -> None:
    """Display the batch statistics and save the batch to disk.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`.
      sp:
        The BPE model.
    """
    from lhotse.utils import uuid4

    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
    logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)

    supervisions = batch["supervisions"]
    features = batch["inputs"]

    logging.info(f"features shape: {features.shape}")

    y = sp.encode(supervisions["text"], out_type=int)
    num_tokens = sum(len(i) for i in y)
    logging.info(f"num tokens: {num_tokens}")


def convert_timestamp(
    frames: List[int],
    subsampling_factor: int,
    frame_shift_ms: float = 10,
) -> List[float]:
    """Convert frame numbers to time (in seconds) given the subsampling factor
    and the frame shift (in milliseconds).

    Args:
      frames:
        A list of frame numbers after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.

    Return:
      Return the time in seconds corresponding to each given frame.
    """
    frame_shift = frame_shift_ms / 1000.0
    time = []
    for f in frames:
        time.append(round(f * subsampling_factor * frame_shift, ndigits=3))

    return time
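# Worked example: with subsampling_factor=4 and frame_shift_ms=10, each
# subsampled frame covers 0.04 s, so
#
#     convert_timestamp([0, 5, 10], subsampling_factor=4)
#     # -> [0.0, 0.2, 0.4]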


def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
    """
    Parse timestamp of each word.

    Args:
      tokens:
        List of tokens.
      timestamp:
        List of timestamp of each token.

    Returns:
      List of timestamp of each word.
    """
    start_token = b"\xe2\x96\x81".decode()  # '▁'
    assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
    ans = []
    for i in range(len(tokens)):
        flag = False
        if i == 0 or tokens[i].startswith(start_token):
            flag = True
            if len(tokens[i]) == 1 and tokens[i].startswith(start_token):
                # tokens[i] == start_token
                if i == len(tokens) - 1:
                    # it is the last token
                    flag = False
                elif tokens[i + 1].startswith(start_token):
                    # the next token also starts with start_token
                    flag = False
        if flag:
            ans.append(timestamp[i])
    return ans
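# Worked example: '▁' marks the start of a word in BPE pieces, so only the
# timestamps of word-initial pieces are kept:
#
#     parse_timestamp(["▁HE", "LL", "O", "▁WORLD"], [0.0, 0.04, 0.08, 0.2])
#     # -> [0.0, 0.2]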


def parse_hyp_and_timestamp(
    res: DecodingResults,
    subsampling_factor: int,
    frame_shift_ms: float = 10,
    sp: Optional[spm.SentencePieceProcessor] = None,
    word_table: Optional[k2.SymbolTable] = None,
) -> Tuple[List[List[str]], List[List[float]]]:
    """Parse hypotheses and timestamps.

    Args:
      res:
        A DecodingResults object.
      subsampling_factor:
        The integer subsampling factor.
      frame_shift_ms:
        The float frame shift used for feature extraction.
      sp:
        The BPE model.
      word_table:
        The word symbol table.

    Returns:
      Return a list of hypotheses and a list of timestamps.
    """
    hyps = []
    timestamps = []

    N = len(res.hyps)
    assert len(res.timestamps) == N, (len(res.timestamps), N)
    use_word_table = False
    if word_table is not None:
        assert sp is None
        use_word_table = True
    else:
        assert sp is not None and word_table is None

    for i in range(N):
        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
        if use_word_table:
            words = [word_table[w] for w in res.hyps[i]]
        else:
            tokens = sp.id_to_piece(res.hyps[i])
            words = sp.decode_pieces(tokens).split()
            time = parse_timestamp(tokens, time)
        assert len(time) == len(words), (len(time), len(words))

        hyps.append(words)
        timestamps.append(time)

    return hyps, timestamps


# `is_module_available` is copied from
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
def is_module_available(*modules: str) -> bool:
    r"""Returns whether a top-level module with :attr:`name` exists *without*
    importing it. This is generally safer than a try/except block around
    `import X`.

    Note: "borrowed" from torchaudio.
    """
    import importlib

    return all(importlib.util.find_spec(m) is not None for m in modules)
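# Usage sketch (illustrative):
#
#     if is_module_available("lhotse", "k2"):
#         import lhotse  # safe: both modules can be found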


def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
    """For an uneven-sized batch, the total duration after padding could
    cause OOM. Hence, for each batch, which is sorted in descending order
    by length, we simply drop the last few shortest samples, so that the
    retained total number of frames (after padding) does not exceed
    `allowed_max_frames`.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      allowed_max_frames:
        The allowed max number of frames in the batch.
    """
    features = batch["inputs"]
    supervisions = batch["supervisions"]

    N, T, _ = features.size()
    assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
    kept_num_utt = allowed_max_frames // T

    if kept_num_utt >= N or kept_num_utt == 0:
        return batch

    # Note: we assume the samples in the batch are sorted in descending
    # order by length.
    logging.info(
        f"Filtering uneven-sized batch, original batch size is {N}, "
        f"retained batch size is {kept_num_utt}."
    )
    batch["inputs"] = features[:kept_num_utt]
    for k, v in supervisions.items():
        assert len(v) == N, (len(v), N)
        batch["supervisions"][k] = v[:kept_num_utt]

    return batch
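# Worked example: if allowed_max_frames=10000 and the longest utterance in
# the batch has T=600 frames, then kept_num_utt = 10000 // 600 = 16, i.e.,
# at most 16 utterances (16 * 600 = 9600 padded frames) are retained.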


def parse_bpe_start_end_pairs(
    tokens: List[str], is_first_token: List[bool]
) -> List[Tuple[int, int]]:
    """Parse pairs of start and end frame indexes for each word.

    Args:
      tokens:
        List of BPE tokens.
      is_first_token:
        List of bool values, which indicates whether it is the first token,
        i.e., not repeat or blank.

    Returns:
      List of (start-frame-index, end-frame-index) pairs for each word.
    """
    assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token))

    start_token = b"\xe2\x96\x81".decode()  # '▁'
    blank_token = "<blk>"

    non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token]
    num_non_blank = len(non_blank_idx)

    pairs = []
    start = -1
    end = -1
    for j in range(num_non_blank):
        # The index in all frames
        i = non_blank_idx[j]

        found_start = False
        if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)):
            found_start = True
            if tokens[i] == start_token:
                if j == num_non_blank - 1:
                    # It is the last non-blank token
                    found_start = False
                elif is_first_token[non_blank_idx[j + 1]] and tokens[
                    non_blank_idx[j + 1]
                ].startswith(start_token):
                    # The next non-blank token is a first-token and also
                    # starts with start_token
                    found_start = False
        if found_start:
            start = i

        if start != -1:
            found_end = False
            if j == num_non_blank - 1:
                # It is the last non-blank token
                found_end = True
            elif is_first_token[non_blank_idx[j + 1]] and tokens[
                non_blank_idx[j + 1]
            ].startswith(start_token):
                # The next non-blank token is a first-token and also starts
                # with start_token
                found_end = True
            if found_end:
                end = i

        if start != -1 and end != -1:
            if not all([tokens[t] == start_token for t in range(start, end + 1)]):
                # Skip the case where all tokens in [start, end] are start_token
                pairs.append((start, end))
            # Reset start and end
            start = -1
            end = -1

    return pairs
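# Worked example (with '▁' as the word-start marker and no repeats):
#
#     tokens         = ["▁HE", "LL", "O", "<blk>", "▁HI"]
#     is_first_token = [True, True, True, True, True]
#     parse_bpe_start_end_pairs(tokens, is_first_token)
#     # -> [(0, 2), (4, 4)]  # "HELLO" spans frames 0-2, "HI" frame 4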


def parse_bpe_timestamps_and_texts(
    best_paths: k2.Fsa, sp: spm.SentencePieceProcessor
) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
    """Parse timestamps (frame indexes) and texts.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful). Its attributes `labels` and `aux_labels`
        are both BPE tokens.
      sp:
        The BPE model.

    Returns:
      utt_index_pairs:
        A list of pair lists. utt_index_pairs[i] is a list of
        (start-frame-index, end-frame-index) pairs for each word in
        utterance-i.
      utt_words:
        A list of str lists. utt_words[i] is the word list of utterance-i.
    """
    shape = best_paths.arcs.shape().remove_axis(1)

    # labels: [utt][arcs]
    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
    # remove -1's.
    labels = labels.remove_values_eq(-1)
    labels = labels.tolist()

    # aux_labels: [utt][arcs]
    aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous())

    # remove -1's.
    all_aux_labels = aux_labels.remove_values_eq(-1)
    # len(all_aux_labels[i]) is equal to the number of frames
    all_aux_labels = all_aux_labels.tolist()

    # remove 0's and -1's.
    out_aux_labels = aux_labels.remove_values_leq(0)
    # len(out_aux_labels[i]) is equal to the number of output BPE tokens
    out_aux_labels = out_aux_labels.tolist()

    utt_index_pairs = []
    utt_words = []
    for i in range(len(labels)):
        tokens = sp.id_to_piece(labels[i])
        words = sp.decode(out_aux_labels[i]).split()

        # Indicates whether it is the first token, i.e., not-repeat and not-blank.
        is_first_token = [a != 0 for a in all_aux_labels[i]]
        index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
        assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens)
        utt_index_pairs.append(index_pairs)
        utt_words.append(words)

    return utt_index_pairs, utt_words


def parse_timestamps_and_texts(
    best_paths: k2.Fsa, word_table: k2.SymbolTable
) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
    """Parse timestamps (frame indexes) and texts.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful). Attribute `labels` is the prediction unit,
        e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
      word_table:
        The word symbol table.

    Returns:
      utt_index_pairs:
        A list of pair lists. utt_index_pairs[i] is a list of
        (start-frame-index, end-frame-index) pairs for each word in
        utterance-i.
      utt_words:
        A list of str lists. utt_words[i] is the word list of utterance-i.
    """
    # [utt][words]
    word_ids = get_texts(best_paths)

    shape = best_paths.arcs.shape().remove_axis(1)

    # labels: [utt][arcs]
    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
    # remove -1's.
    labels = labels.remove_values_eq(-1)
    labels = labels.tolist()

    # aux_labels: [utt][arcs]
    aux_shape = shape.compose(best_paths.aux_labels.shape)
    aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous())
    aux_labels = aux_labels.tolist()

    utt_index_pairs = []
    utt_words = []
    for i, (label, aux_label) in enumerate(zip(labels, aux_labels)):
        num_arcs = len(label)
        # The last arc of aux_label is the arc entering the final state
        assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label))

        index_pairs = []
        start = -1
        end = -1
        for arc in range(num_arcs):
            # len(aux_label[arc]) is 0 or 1
            if label[arc] != 0 and len(aux_label[arc]) != 0:
                if start != -1 and end != -1:
                    index_pairs.append((start, end))
                start = arc
            if label[arc] != 0:
                end = arc
        if start != -1 and end != -1:
            index_pairs.append((start, end))

        words = [word_table[w] for w in word_ids[i]]
        assert len(index_pairs) == len(words), (len(index_pairs), len(words))

        utt_index_pairs.append(index_pairs)
        utt_words.append(words)

    return utt_index_pairs, utt_words


def parse_fsa_timestamps_and_texts(
    best_paths: k2.Fsa,
    sp: Optional[spm.SentencePieceProcessor] = None,
    word_table: Optional[k2.SymbolTable] = None,
    subsampling_factor: int = 4,
    frame_shift_ms: float = 10,
) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
    """Parse timestamps (in seconds) and texts for given decoded fsa paths.
    Currently it supports two cases:
    (1) ctc-decoding, where the attributes `labels` and `aux_labels`
        are both BPE tokens. In this case, sp should be provided.
    (2) HLG-based 1best, where the attribute `labels` is the prediction unit,
        e.g., phone or BPE tokens; attribute `aux_labels` is the word index.
        In this case, word_table should be provided.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.

    Returns:
      utt_time_pairs:
        A list of pair lists. utt_time_pairs[i] is a list of
        (start-time, end-time) pairs for each word in utterance-i.
      utt_words:
        A list of str lists. utt_words[i] is the word list of utterance-i.
    """
    if sp is not None:
        assert word_table is None, "word_table is not needed if sp is provided."
        utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(
            best_paths=best_paths, sp=sp
        )
    elif word_table is not None:
        assert sp is None, "sp is not needed if word_table is provided."
        utt_index_pairs, utt_words = parse_timestamps_and_texts(
            best_paths=best_paths, word_table=word_table
        )
    else:
        raise ValueError("Either sp or word_table should be provided.")

    utt_time_pairs = []
    for utt in utt_index_pairs:
        start = convert_timestamp(
            frames=[i[0] for i in utt],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        end = convert_timestamp(
            # The duration in frames is (end_frame_index - start_frame_index + 1)
            frames=[i[1] + 1 for i in utt],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        utt_time_pairs.append(list(zip(start, end)))

    return utt_time_pairs, utt_words
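# Usage sketch (illustrative; `best_paths` would come from k2.shortest_path
# on a decoding lattice):
#
#     # ctc-decoding: labels/aux_labels are BPE tokens
#     time_pairs, words = parse_fsa_timestamps_and_texts(best_paths, sp=sp)
#     # HLG-based 1best: aux_labels are word indexes
#     time_pairs, words = parse_fsa_timestamps_and_texts(
#         best_paths, word_table=word_table
#     )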


# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
def is_cjk(character):
    """
    Python port of Moses' code to check for CJK character.

    >>> is_cjk(u'\u33fe')
    True
    >>> is_cjk(u'\uFE5F')
    False

    :param character: The character that needs to be checked.
    :type character: char
    :return: bool
    """
    return any(
        [
            start <= ord(character) <= end
            for start, end in [
                (4352, 4607),
                (11904, 42191),
                (43072, 43135),
                (44032, 55215),
                (63744, 64255),
                (65072, 65103),
                (65381, 65500),
                (131072, 196607),
            ]
        ]
    )


def symlink_or_copy(exp_dir: Path, src: str, dst: str):
    """
    In the experiment directory, create a symlink pointing to src named dst.
    If symlink creation fails (e.g., on Windows), fall back to copyfile.
    """
    dir_fd = os.open(exp_dir, os.O_RDONLY)
    try:
        os.remove(dst, dir_fd=dir_fd)
    except FileNotFoundError:
        pass
    try:
        os.symlink(src=src, dst=dst, dir_fd=dir_fd)
    except OSError:
        copyfile(src=exp_dir / src, dst=exp_dir / dst)
    os.close(dir_fd)


def num_tokens(
    token_table: k2.SymbolTable, disambig_pattern: re.Pattern = re.compile(r"^#\d+$")
) -> int:
    """Return the number of tokens, excluding those from
    disambiguation symbols.

    Caution:
      0 is not a token ID, so it is excluded from the return value.
    """
    symbols = token_table.symbols
    ans = []
    for s in symbols:
        if not disambig_pattern.match(s):
            ans.append(token_table[s])
    num_tokens = len(ans)
    if 0 in ans:
        num_tokens -= 1
    return num_tokens
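# Worked example: for a token table containing
#     <blk> -> 0, a -> 1, b -> 2, #0 -> 3
# the disambiguation symbol "#0" is skipped and ID 0 is excluded, so
# num_tokens(token_table) returns 2.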


# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
def time_warp(
    features: torch.Tensor,
    p: float = 0.9,
    time_warp_factor: Optional[int] = 80,
    supervision_segments: Optional[torch.Tensor] = None,
):
    """Apply time warping on a batch of features."""
    if time_warp_factor is None or time_warp_factor < 1:
        return features
    assert (
        len(features.shape) == 3
    ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}"
    features = features.clone()
    if supervision_segments is None:
        # No supervisions - apply spec augment to full feature matrices.
        for sequence_idx in range(features.size(0)):
            if random.random() > p:
                # Randomly choose whether this transform is applied
                continue
            features[sequence_idx] = time_warp_impl(
                features[sequence_idx], factor=time_warp_factor
            )
    else:
        # Supervisions provided - we will apply time warping only on the supervised areas.
        for sequence_idx, start_frame, num_frames in supervision_segments:
            if random.random() > p:
                # Randomly choose whether this transform is applied
                continue
            end_frame = start_frame + num_frames
            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
            )

    return features
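# Usage sketch (illustrative; shapes are (batch, time, feature)):
#
#     feats = torch.randn(4, 1000, 80)
#     warped = time_warp(feats, p=0.9, time_warp_factor=80)
#     assert warped.shape == feats.shape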