mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Delete utils.py
This commit is contained in:
parent
f837818af3
commit
988d7af85c
495
icefall/utils.py
495
icefall/utils.py
@ -1,495 +0,0 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import collections
|
||||
import os
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, TextIO, Tuple, Union
|
||||
|
||||
import k2
|
||||
import kaldialign
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
Pathlike = Union[str, Path]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_executor():
|
||||
# We'll either return a process pool or a distributed worker pool.
|
||||
# Note that this has to be a context manager because we might use multiple
|
||||
# context manager ("with" clauses) inside, and this way everything will
|
||||
# free up the resources at the right time.
|
||||
try:
|
||||
# If this is executed on the CLSP grid, we will try to use the
|
||||
# Grid Engine to distribute the tasks.
|
||||
# Other clusters can also benefit from that, provided a
|
||||
# cluster-specific wrapper.
|
||||
# (see https://github.com/pzelasko/plz for reference)
|
||||
#
|
||||
# The following must be installed:
|
||||
# $ pip install dask distributed
|
||||
# $ pip install git+https://github.com/pzelasko/plz
|
||||
name = subprocess.check_output("hostname -f", shell=True, text=True)
|
||||
if name.strip().endswith(".clsp.jhu.edu"):
|
||||
import plz
|
||||
from distributed import Client
|
||||
|
||||
with plz.setup_cluster() as cluster:
|
||||
cluster.scale(80)
|
||||
yield Client(cluster)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
# No need to return anything - compute_and_store_features
|
||||
# will just instantiate the pool itself.
|
||||
yield None
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
"""Used in argparse.ArgumentParser.add_argument to indicate
|
||||
that a type is a bool type and user can enter
|
||||
|
||||
- yes, true, t, y, 1, to represent True
|
||||
- no, false, f, n, 0, to represent False
|
||||
|
||||
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
|
||||
def setup_logger(
|
||||
log_filename: Pathlike, log_level: str = "info", use_console: bool = True
|
||||
) -> None:
|
||||
"""Setup log level.
|
||||
|
||||
Args:
|
||||
log_filename:
|
||||
The filename to save the log.
|
||||
log_level:
|
||||
The log level to use, e.g., "debug", "info", "warning", "error",
|
||||
"critical"
|
||||
"""
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
||||
log_filename = f"{log_filename}-{date_time}-{rank}"
|
||||
else:
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
log_filename = f"{log_filename}-{date_time}"
|
||||
|
||||
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
|
||||
|
||||
level = logging.ERROR
|
||||
if log_level == "debug":
|
||||
level = logging.DEBUG
|
||||
elif log_level == "info":
|
||||
level = logging.INFO
|
||||
elif log_level == "warning":
|
||||
level = logging.WARNING
|
||||
elif log_level == "critical":
|
||||
level = logging.CRITICAL
|
||||
|
||||
logging.basicConfig(
|
||||
filename=log_filename, format=formatter, level=level, filemode="w"
|
||||
)
|
||||
if use_console:
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(level)
|
||||
console.setFormatter(logging.Formatter(formatter))
|
||||
logging.getLogger("").addHandler(console)
|
||||
|
||||
|
||||
def get_env_info():
|
||||
"""
|
||||
TODO:
|
||||
"""
|
||||
return {
|
||||
"k2-git-sha1": None,
|
||||
"k2-version": None,
|
||||
"lhotse-version": None,
|
||||
"torch-version": None,
|
||||
"icefall-sha1": None,
|
||||
"icefall-version": None,
|
||||
}
|
||||
|
||||
|
||||
class AttributeDict(dict):
|
||||
def __getattr__(self, key):
|
||||
if key in self:
|
||||
return self[key]
|
||||
raise AttributeError(f"No such attribute '{key}'")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
if key in self:
|
||||
del self[key]
|
||||
return
|
||||
raise AttributeError(f"No such attribute '{key}'")
|
||||
|
||||
|
||||
def encode_supervisions(
|
||||
supervisions: dict, subsampling_factor: int
|
||||
) -> Tuple[torch.Tensor, List[str]]:
|
||||
"""
|
||||
Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
|
||||
and a list of transcription strings.
|
||||
|
||||
The supervision tensor has shape ``(batch_size, 3)``.
|
||||
Its second dimension contains information about sequence index [0],
|
||||
start frames [1] and num frames [2].
|
||||
|
||||
The batch items might become re-ordered during this operation -- the
|
||||
returned tensor and list of strings are guaranteed to be consistent with
|
||||
each other.
|
||||
"""
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // subsampling_factor,
|
||||
supervisions["num_frames"] // subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
indices = torch.argsort(supervision_segments[:, 2], descending=True)
|
||||
supervision_segments = supervision_segments[indices]
|
||||
texts = supervisions["text"]
|
||||
texts = [texts[idx] for idx in indices]
|
||||
|
||||
return supervision_segments, texts
|
||||
|
||||
|
||||
def get_texts(
|
||||
best_paths: k2.Fsa, return_ragged: bool = False
|
||||
) -> Union[List[List[int]], k2.RaggedTensor]:
|
||||
"""Extract the texts (as word IDs) from the best-path FSAs.
|
||||
Args:
|
||||
best_paths:
|
||||
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
|
||||
containing multiple FSAs, which is expected to be the result
|
||||
of k2.shortest_path (otherwise the returned values won't
|
||||
be meaningful).
|
||||
return_ragged:
|
||||
True to return a ragged tensor with two axes [utt][word_id].
|
||||
False to return a list-of-list word IDs.
|
||||
Returns:
|
||||
Returns a list of lists of int, containing the label sequences we
|
||||
decoded.
|
||||
"""
|
||||
if isinstance(best_paths.aux_labels, k2.RaggedTensor):
|
||||
# remove 0's and -1's.
|
||||
aux_labels = best_paths.aux_labels.remove_values_leq(0)
|
||||
# TODO: change arcs.shape() to arcs.shape
|
||||
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
|
||||
|
||||
# remove the states and arcs axes.
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
|
||||
else:
|
||||
# remove axis corresponding to states.
|
||||
aux_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
|
||||
# remove 0's and -1's.
|
||||
aux_labels = aux_labels.remove_values_leq(0)
|
||||
|
||||
assert aux_labels.num_axes == 2
|
||||
if return_ragged:
|
||||
return aux_labels
|
||||
else:
|
||||
return aux_labels.tolist()
|
||||
|
||||
|
||||
def store_transcripts(
|
||||
filename: Pathlike, texts: Iterable[Tuple[str, str]]
|
||||
) -> None:
|
||||
"""Save predicted results and reference transcripts to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
File to save the results to.
|
||||
texts:
|
||||
An iterable of tuples. The first element is the reference transcript
|
||||
while the second element is the predicted result.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
for ref, hyp in texts:
|
||||
print(f"ref={ref}", file=f)
|
||||
print(f"hyp={hyp}", file=f)
|
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts.
|
||||
|
||||
It will write the following to the given file:
|
||||
|
||||
- WER
|
||||
- number of insertions, deletions, substitutions, corrects and total
|
||||
reference words. For example::
|
||||
|
||||
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||
reference words (2337 correct)
|
||||
|
||||
- The difference between the reference transcript and predicted results.
|
||||
An instance is given below::
|
||||
|
||||
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||
|
||||
The above example shows that the reference word is `EDISON`, but it is
|
||||
predicted to `ADDISON` (a substitution error).
|
||||
|
||||
Another example is::
|
||||
|
||||
FOR THE FIRST DAY (SIR->*) I THINK
|
||||
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. The first element is the reference transcript
|
||||
while the second element is the predicted result.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
Otherwise, it is written only to the given file.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
for ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
elif hyp_word == ERR:
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
elif hyp_word != ref_word:
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
ref_len = sum([len(r) for r, _ in results])
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
for ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
" ".join(
|
||||
(
|
||||
ref_word
|
||||
if ref_word == hyp_word
|
||||
else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted(
|
||||
[(v, k) for k, v in subs.items()], reverse=True
|
||||
):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print(
|
||||
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
|
||||
)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
ref_count = corr + ref_sub + dels
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate)
|
||||
|
||||
|
||||
class LossRecord(collections.defaultdict):
|
||||
def __init__(self):
|
||||
# Passing the type 'int' to the base-class constructor
|
||||
# makes undefined items default to int() which is zero.
|
||||
super(LossRecord, self).__init__(int)
|
||||
|
||||
def __add__(self, other: 'LossRecord') -> 'LossRecord':
|
||||
ans = LossRecord()
|
||||
for k, v in self.items():
|
||||
ans[k] = v
|
||||
for k, v in other.items():
|
||||
ans[k] = ans[k] + v
|
||||
return ans
|
||||
|
||||
def __mul__(self, alpha: float) -> 'LossRecord':
|
||||
ans = LossRecord()
|
||||
for k, v in self.items():
|
||||
ans[k] = v * alpha
|
||||
return ans
|
||||
|
||||
def __str__(self) -> str:
|
||||
ans = ''
|
||||
for k, v in self.norm_items():
|
||||
norm_value = '%.4g' % v
|
||||
ans += (str(k) + '=' + str(norm_value) + ', ')
|
||||
frames = str(self['frames'])
|
||||
ans += 'over ' + frames + ' frames.'
|
||||
return ans
|
||||
|
||||
def norm_items(self) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
Returns a list of pairs, like:
|
||||
[('ctc_loss', 0.1), ('att_loss', 0.07)]
|
||||
"""
|
||||
num_frames = self['frames'] if 'frames' in self else 1
|
||||
ans = []
|
||||
for k, v in self.items():
|
||||
if k != 'frames':
|
||||
norm_value = float(v) / num_frames
|
||||
ans.append((k, norm_value))
|
||||
return ans
|
||||
|
||||
def reduce(self, device):
|
||||
"""
|
||||
Reduce using torch.distributed, which I believe ensures that
|
||||
all processes get the total.
|
||||
"""
|
||||
keys = sorted(self.keys())
|
||||
s = torch.tensor([float(self[k]) for k in keys],
|
||||
device=device)
|
||||
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||
for k, v in zip(keys, s.cpu().tolist()):
|
||||
self[k] = v
|
||||
|
||||
def write_summary(
|
||||
self,
|
||||
tb_writer: SummaryWriter,
|
||||
prefix: str,
|
||||
batch_idx: int,
|
||||
) -> None:
|
||||
"""Add logging information to a TensorBoard writer.
|
||||
|
||||
Args:
|
||||
tb_writer: a TensorBoard writer
|
||||
prefix: a prefix for the name of the loss, e.g. "train/valid_",
|
||||
or "train/current_"
|
||||
batch_idx: The current batch index, used as the x-axis of the plot.
|
||||
"""
|
||||
for k, v in self.norm_items():
|
||||
tb_writer.add_scalar(prefix + k, v, batch_idx)
|
Loading…
x
Reference in New Issue
Block a user