import argparse
import collections
import json
import logging
import os
import pathlib
import random
import re
import subprocess
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import kaldialign
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

Pathlike = Union[str, Path]


def get_world_size():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1


def get_rank():
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    elif dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0


def get_local_rank():
    if "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])
    elif dist.is_available() and dist.is_initialized():
        # torch.distributed has no get_local_rank(); PyTorch >= 2.4 exposes
        # get_node_local_rank() instead, so fall back to 0 on older versions.
        if hasattr(dist, "get_node_local_rank"):
            return dist.get_node_local_rank(fallback_rank=0)
        return 0
    else:
        return 0
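
# A minimal usage sketch of the helpers above (values are illustrative).
# torchrun exports WORLD_SIZE, RANK and LOCAL_RANK, so these work both
# before and after torch.distributed is initialized:
#
#     # torchrun --nproc_per_node 4 train.py
#     world_size = get_world_size()  # 4
#     rank = get_rank()              # 0..3, unique per process
#     local_rank = get_local_rank()  # GPU index on this node
#     device = torch.device("cuda", local_rank)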


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and the user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
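
# A minimal usage sketch (the flag name is illustrative):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--use-fp16", type=str2bool, default=False)
#     args = parser.parse_args(["--use-fp16", "yes"])
#     assert args.use_fp16 is True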


class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")

    def __str__(self, indent: int = 2):
        tmp = {}
        for k, v in self.items():
            # PosixPath and torch.device are not JSON serializable
            if isinstance(v, (pathlib.Path, torch.device)):
                v = str(v)
            tmp[k] = v
        return json.dumps(tmp, indent=indent, sort_keys=True)
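
# A minimal usage sketch (keys and values are illustrative). Keys are
# readable and writable as attributes:
#
#     params = AttributeDict({"lr": 1e-3, "exp_dir": Path("exp")})
#     params.batch_size = 32
#     print(params.lr)  # 0.001
#     print(params)     # JSON with sorted keys; Path values become strings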


def setup_logger(
    log_filename: Pathlike,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Set up the log level and log format.

    Args:
      log_filename:
        The filename to save the log to. A timestamp (and, under
        torch.distributed, the rank) is appended to it.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical".
      use_console:
        True to also print logs to console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        log_filename = f"{log_filename}-{date_time}"

    log_dir = os.path.dirname(log_filename)
    if log_dir:
        # os.makedirs("") raises FileNotFoundError, so skip bare filenames
        # that have no directory component.
        os.makedirs(log_dir, exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
        force=True,
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)
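
# A minimal usage sketch (paths are illustrative). Each run gets a
# timestamped file, e.g. exp/log/log-train-2024-01-01-12-00-00, plus a
# "-<rank>" suffix when torch.distributed is initialized:
#
#     setup_logger(log_filename="exp/log/log-train", log_level="info")
#     logging.info("Training started")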


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class plays the role of a metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            # `v - v == 0` is False for inf and nan, so this skips values
            # that would otherwise corrupt the accumulated sums.
            if v - v == 0:
                ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans_frames = ""
        ans_utterances = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            if "utt_" not in k:
                ans_frames += str(k) + "=" + str(norm_value) + ", "
            else:
                ans_utterances += str(k) + "=" + str(norm_value)
                if k == "utt_duration":
                    ans_utterances += " frames, "
                elif k == "utt_pad_proportion":
                    ans_utterances += ", "
                else:
                    raise ValueError(f"Unexpected key: {k}")
        frames = "%.2f" % self["frames"]
        ans_frames += "over " + str(frames) + " frames. "
        if ans_utterances != "":
            utterances = "%.2f" % self["utterances"]
            ans_utterances += "over " + str(utterances) + " utterances."

        return ans_frames + ans_utterances

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self["frames"] if "frames" in self else 1
        num_utterances = self["utterances"] if "utterances" in self else 1
        ans = []
        for k, v in self.items():
            if k == "frames" or k == "utterances":
                continue
            norm_value = (
                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
            )
            ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Sum the metrics over all processes with torch.distributed.
        all_reduce with ReduceOp.SUM ensures that every process ends
        up with the same totals.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
          tb_writer: a TensorBoard writer
          prefix: a prefix for the name of the loss, e.g. "train/valid_",
            or "train/current_"
          batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)
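
# A minimal usage sketch (keys and values are illustrative). Per-batch
# trackers accumulate with `+`; str() and write_summary() report each
# metric normalized by the number of frames:
#
#     total = MetricsTracker()
#     for batch_idx, batch in enumerate(loader):  # illustrative loop
#         info = MetricsTracker()
#         info["frames"] = 100
#         info["ctc_loss"] = 25.0
#         total = total + info
#         total.write_summary(tb_writer, "train/current_", batch_idx)
#     print(total)  # e.g. "ctc_loss=0.25, over 100.00 frames. "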


def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
    """Save predicted results and reference transcripts to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        If it is a multi-talker ASR system, the ref and hyp may also be lists of
        strings.
      char_level:
        If True, split the reference and hypothesis into characters before
        saving, which is useful for computing CER.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as f:
        for cut_id, ref, hyp in texts:
            if char_level:
                ref = list("".join(ref))
                hyp = list("".join(hyp))
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str]]],
    enable_log: bool = True,
    compute_CER: bool = False,
    sclite_mode: bool = False,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted as `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      f:
        The file to write the statistics to.
      test_set_name:
        The name of the test set, used in the logged summary.
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
      compute_CER:
        If True, compute CER instead of WER by splitting the reference and
        hypothesis into characters.
      sclite_mode:
        If True, use sclite-style alignment in kaldialign.
    Returns:
      Return the total error rate (WER, or CER if compute_CER is True) as a
      float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    if compute_CER:
        for i, res in enumerate(results):
            cut_id, ref, hyp = res
            ref = list("".join(ref))
            hyp = list("".join(hyp))
            results[i] = (cut_id, ref, hyp)

    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        # Renamed from `ins`/`dels` to avoid shadowing the dicts above.
        (corr, ref_sub, hyp_sub, ins_count, del_count) = counts
        tot_word_errs = ref_sub + hyp_sub + ins_count + del_count
        ref_count = corr + ref_sub + del_count
        hyp_count = corr + hyp_sub + ins_count

        print(f"{word} {corr} {tot_word_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)
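
# A minimal end-to-end sketch (paths and contents are illustrative): save
# the decoded results, then write detailed error statistics to a file:
#
#     results = [("cut-001", ["THE", "CAT"], ["THE", "HAT"])]
#     store_transcripts(filename="exp/recogs-test.txt", texts=results)
#     with open("exp/errs-test.txt", "w") as f:
#         wer = write_error_stats(f, test_set_name="test", results=results)
#     print(f"WER: {wer}%")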