# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang
#                                                    Mingshuang Luo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import collections
import logging
import os
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union

import k2
import k2.version
import kaldialign
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

Pathlike = Union[str, Path]


@contextmanager
def get_executor():
    # We'll either return a process pool or a distributed worker pool.
    # Note that this has to be a context manager because we might use
    # multiple context managers ("with" clauses) inside, and this way
    # everything will free up the resources at the right time.
    try:
        # If this is executed on the CLSP grid, we will try to use the
        # Grid Engine to distribute the tasks.
        # Other clusters can also benefit from that, provided a
        # cluster-specific wrapper.
        # (see https://github.com/pzelasko/plz for reference)
        #
        # The following must be installed:
        # $ pip install dask distributed
        # $ pip install git+https://github.com/pzelasko/plz
        name = subprocess.check_output("hostname -f", shell=True, text=True)
        if name.strip().endswith(".clsp.jhu.edu"):
            import plz
            from distributed import Client

            with plz.setup_cluster() as cluster:
                cluster.scale(80)
                yield Client(cluster)
            return
    except Exception:
        pass
    # No need to return anything - compute_and_store_features
    # will just instantiate the pool itself.
    yield None


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that the argument is of bool type; the user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
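
# Example (illustrative; the flag name below is arbitrary, not one defined
# in this file). `str2bool` is meant to be passed as the `type=` of an
# argparse argument, so that values like "yes"/"no" parse correctly:
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--use-feats", type=str2bool, default=True)
#   args = parser.parse_args(["--use-feats", "no"])
#   assert args.use_feats is False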


def setup_logger(
    log_filename: Pathlike, log_level: str = "info", use_console: bool = True
) -> None:
    """Set up the logging module: log to a file and, optionally, to the
    console.

    Args:
      log_filename:
        The filename to save the log. A date-time suffix (and, in the
        distributed case, the rank) is appended to it.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical".
      use_console:
        If True, also log to the console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = (
            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        )
        log_filename = f"{log_filename}-{date_time}"

    os.makedirs(os.path.dirname(log_filename), exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename, format=formatter, level=level, filemode="w"
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)
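
# Example (illustrative; the path is arbitrary). Note that a date-time
# suffix (and the rank, when running distributed) is appended to the given
# filename:
#
#   setup_logger(log_filename="exp/log/log-train", log_level="info")
#   logging.info("Training started")  # goes to the file and the console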


class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")
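
# Example: fields of an AttributeDict can be accessed either as dict keys
# or as attributes (the field names below are arbitrary):
#
#   params = AttributeDict({"feature_dim": 80, "subsampling_factor": 4})
#   assert params.feature_dim == params["feature_dim"] == 80
#   params.lr = 1e-3  # equivalent to params["lr"] = 1e-3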


def encode_supervisions(
    supervisions: dict, subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
    """
    Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of
    a torch Tensor and a list of transcription strings.

    The supervision tensor has shape ``(batch_size, 3)``.
    Its second dimension contains information about the sequence index [0],
    start frames [1] and num frames [2].

    The batch items might become re-ordered during this operation (they are
    sorted by the number of frames, in descending order) -- the returned
    tensor and list of strings are guaranteed to be consistent with each
    other.
    """
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // subsampling_factor,
            supervisions["num_frames"] // subsampling_factor,
        ),
        1,
    ).to(torch.int32)

    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions["text"]
    texts = [texts[idx] for idx in indices]

    return supervision_segments, texts
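
# Example (illustrative; a hand-made dict standing in for Lhotse's
# batch["supervisions"]):
#
#   supervisions = {
#       "sequence_idx": torch.tensor([0, 1]),
#       "start_frame": torch.tensor([0, 0]),
#       "num_frames": torch.tensor([8, 16]),
#       "text": ["short utt", "longer utt"],
#   }
#   segments, texts = encode_supervisions(supervisions, subsampling_factor=4)
#   # segments: [[1, 0, 4], [0, 0, 2]] -- sorted by num frames, descending
#   # texts: ["longer utt", "short utt"] -- re-ordered consistently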


def get_texts(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
    """Extract the texts (as word IDs) from the best-path FSAs.
    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-lists of word IDs.
    Returns:
      Returns a list of lists of int, containing the label sequences we
      decoded, or a k2.RaggedTensor with two axes if return_ragged is True.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = aux_labels.remove_values_leq(0)

    assert aux_labels.num_axes == 2
    if return_ragged:
        return aux_labels
    else:
        return aux_labels.tolist()
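
# Typical decoding flow (sketch; `lattice` and `word_table` are assumed to
# come from the surrounding decoding script, not from this file):
#
#   best_path = k2.shortest_path(lattice, use_double_scores=True)
#   hyp_ids = get_texts(best_path)  # List[List[int]] of word IDs
#   hyps = [[word_table[i] for i in ids] for ids in hyp_ids]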


def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
    """Extract labels or aux_labels from the best-path FSAs.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      kind:
        Possible values are: "labels" and "aux_labels". Caution: When it is
        "labels", the resulting alignments contain repeats.
    Returns:
      Returns a list of lists of int, containing the token sequences we
      decoded. For `ans[i]`, its length equals the number of frames
      after subsampling of the i-th utterance in the batch.

    Example:
      When `kind` is `labels`, one possible alignment example is (with
      repeats)::

        c c c blk a a blk blk t t t blk blk

      If `kind` is `aux_labels`, the above example changes to::

        c blk blk blk a blk blk blk t blk blk blk blk

    """
    assert kind in ("labels", "aux_labels")
    # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
    token_shape = best_paths.arcs.shape().remove_axis(1)
    # token_shape has axes [fsa][arc]
    tokens = k2.RaggedTensor(
        token_shape, getattr(best_paths, kind).contiguous()
    )
    tokens = tokens.remove_values_eq(-1)
    return tokens.tolist()


def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Save alignments to a file.

    Args:
      alignments:
        A dict containing alignments. Keys of the dict are utterances and
        values are the corresponding framewise alignments after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      filename:
        Path to save the alignments.
    Returns:
      Return None.
    """
    ali_dict = {
        "subsampling_factor": subsampling_factor,
        "alignments": alignments,
    }
    torch.save(ali_dict, filename)


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Load alignments from a file.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling_factor used to compute
          the alignments.
        - alignments: A dict containing utterances and their corresponding
          framewise alignment, after subsampling.
    """
    ali_dict = torch.load(filename)
    subsampling_factor = ali_dict["subsampling_factor"]
    alignments = ali_dict["alignments"]
    return subsampling_factor, alignments
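
# Round-trip example (illustrative; the utterance IDs, alignments, and
# filename are arbitrary):
#
#   ali = {"utt-1": [0, 0, 21, 0, 5], "utt-2": [0, 7, 0]}
#   save_alignments(ali, subsampling_factor=4, filename="ali.pt")
#   factor, ali2 = load_alignments("ali.pt")
#   assert factor == 4 and ali2 == ali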


def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:
    """Save predicted results and reference transcripts to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the reference transcript
        while the second element is the predicted result.
    Returns:
      Return None.
    """
    with open(filename, "w") as f:
        for ref, hyp in texts:
            print(f"ref={ref}", file=f)
            print(f"hyp={hyp}", file=f)
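
# Example (illustrative):
#
#   store_transcripts(
#       "recogs-test.txt",
#       [("brian is in the kitchen", "brian is in a kitchen")],
#   )
#   # The file then contains:
#   #   ref=brian is in the kitchen
#   #   hyp=brian is in a kitchen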


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[List[str], List[str]]],
    enable_log: bool = True,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted
          result. An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted as `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      f:
        The file to which the statistics are written.
      test_set_name:
        The name of the test set; it is used in the log message.
      results:
        An iterable of tuples. The first element is the reference transcript
        while the second element is the predicted result, each given as a
        list of words.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      Return the total word error rate in percent, e.g., 11.17 means a WER
      of 11.17%.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            " ".join(
                (
                    ref_word
                    if ref_word == hyp_word
                    else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted(
        [(v, k) for k, v in subs.items()], reverse=True
    ):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print(
        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
    )
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)
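
# Example (illustrative). Each pair holds a tokenized reference and
# hypothesis; kaldialign aligns them word by word:
#
#   results = [
#       (["the", "cat", "sat"], ["the", "cat", "sat"]),
#       (["hello", "world"], ["hello", "word"]),
#   ]
#   with open("errs-test.txt", "w") as f:
#       wer = write_error_stats(f, "test", results)
#   # wer == 20.0 (1 substitution over 5 reference words)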


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class plays the role of a metrics tracker: it can
        # record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            ans += str(k) + "=" + str(norm_value) + ", "
        frames = "%.2f" % self["frames"]
        ans += "over " + str(frames) + " frames."
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self["frames"] if "frames" in self else 1
        ans = []
        for k, v in self.items():
            if k != "frames":
                norm_value = float(v) / num_frames
                ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Reduce using torch.distributed, which I believe ensures that
        all processes get the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
          tb_writer: a TensorBoard writer
          prefix: a prefix for the name of the loss, e.g. "train/valid_",
            or "train/current_"
          batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)
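
# Example (illustrative): accumulate per-batch statistics and print them
# normalized by the number of frames:
#
#   info = MetricsTracker()
#   info["frames"] = 300
#   info["ctc_loss"] = 150.0
#   tot = info + info  # sums every key
#   print(tot)  # "ctc_loss=0.5, over 600.00 frames."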


def concat(
    ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
    """Prepend a value to the beginning of each sublist or append a value
    to the end of each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      value:
        The value to prepend or append.
      direction:
        It can be either "left" or "right". If it is "left", we
        prepend the value to the beginning of each sublist;
        if it is "right", we append the value to the end of each
        sublist.

    Returns:
      Return a new ragged tensor, whose sublists either start with
      or end with the given value.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> concat(a, value=0, direction="left")
    [ [ 0 1 3 ] [ 0 5 ] ]
    >>> concat(a, value=0, direction="right")
    [ [ 1 3 0 ] [ 5 0 ] ]

    """
    dtype = ragged.dtype
    device = ragged.device

    assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
    pad_values = torch.full(
        size=(ragged.tot_size(0), 1),
        fill_value=value,
        device=device,
        dtype=dtype,
    )
    pad = k2.RaggedTensor(pad_values)

    if direction == "left":
        ans = k2.ragged.cat([pad, ragged], axis=1)
    elif direction == "right":
        ans = k2.ragged.cat([ragged, pad], axis=1)
    else:
        raise ValueError(
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right"'
        )
    return ans


def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
    """Add SOS to each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      sos_id:
        The ID of the SOS symbol.

    Returns:
      Return a new ragged tensor, where each sublist starts with SOS.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> add_sos(a, sos_id=0)
    [ [ 0 1 3 ] [ 0 5 ] ]

    """
    return concat(ragged, sos_id, direction="left")


def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
    """Add EOS to each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      eos_id:
        The ID of the EOS symbol.

    Returns:
      Return a new ragged tensor, where each sublist ends with EOS.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> add_eos(a, eos_id=0)
    [ [ 1 3 0 ] [ 5 0 ] ]

    """
    return concat(ragged, eos_id, direction="right")


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim

    max_len = lengths.max()
    n = lengths.size(0)

    expanded_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)

    return expanded_lengths >= lengths.unsqueeze(1)


def l1_norm(x):
    # Sum of absolute values.
    return torch.sum(torch.abs(x))


def l2_norm(x):
    # Note: this is the *squared* L2 norm, i.e., the sum of squares.
    return torch.sum(torch.pow(x, 2))


def linf_norm(x):
    # Maximum absolute value.
    return torch.max(torch.abs(x))


def measure_weight_norms(
    model: nn.Module, norm: str = "l2"
) -> Dict[str, float]:
    """
    Compute the norms of the model's parameters.

    :param model: a torch.nn.Module instance
    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
    :return: a dict mapping from parameter's name to its norm.
    """
    with torch.no_grad():
        norms = {}
        for name, param in model.named_parameters():
            if norm == "l1":
                val = l1_norm(param)
            elif norm == "l2":
                val = l2_norm(param)
            elif norm == "linf":
                val = linf_norm(param)
            else:
                raise ValueError(f"Unknown norm type: {norm}")
            norms[name] = val.item()
        return norms


def measure_gradient_norms(
    model: nn.Module, norm: str = "l1"
) -> Dict[str, float]:
    """
    Compute the norms of the gradients for each of model's parameters.

    :param model: a torch.nn.Module instance
    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
    :return: a dict mapping from parameter's name to its gradient's norm.
    """
    with torch.no_grad():
        norms = {}
        for name, param in model.named_parameters():
            if norm == "l1":
                val = l1_norm(param.grad)
            elif norm == "l2":
                val = l2_norm(param.grad)
            elif norm == "linf":
                val = linf_norm(param.grad)
            else:
                raise ValueError(f"Unknown norm type: {norm}")
            norms[name] = val.item()
        return norms
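
# Example (illustrative; `model` and `loss` are assumed to come from the
# surrounding training loop). Log the five largest gradient norms after
# the backward pass:
#
#   loss.backward()
#   grad_norms = measure_gradient_norms(model, norm="linf")
#   top = sorted(grad_norms.items(), key=lambda kv: kv[1], reverse=True)[:5]
#   for name, value in top:
#       logging.info(f"{name}: {value:.3e}")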


def optim_step_and_measure_param_change(
    model: nn.Module,
    old_parameters: Dict[str, nn.parameter.Parameter],
) -> Dict[str, float]:
    r"""
    Measure the "relative change in parameters per minibatch."
    It is understood as a ratio between the L2 norm of the difference
    between the original and updated parameters, and the L2 norm of the
    original parameter. It is given by the formula:

    .. math::
        \begin{aligned}
            \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
        \end{aligned}

    This function is supposed to be used as follows:

    .. code-block:: python

        old_parameters = {
            n: p.detach().clone() for n, p in model.named_parameters()
        }

        optimizer.step()

        deltas = optim_step_and_measure_param_change(model, old_parameters)

    Args:
      model: A torch.nn.Module instance.
      old_parameters:
        A Dict of named_parameters before optimizer.step().

    Return:
      A Dict containing the relative change for each parameter.
    """
    relative_change = {}
    with torch.no_grad():
        for n, p_new in model.named_parameters():
            p_orig = old_parameters[n]
            delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
            relative_change[n] = delta.item()
    return relative_change