icefall/icefall/utils.py

# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                       Mingshuang Luo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import collections
import logging
import os
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
import kaldialign
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
Pathlike = Union[str, Path]
@contextmanager
def get_executor():
# We'll either return a process pool or a distributed worker pool.
# Note that this has to be a context manager because we might use multiple
# context manager ("with" clauses) inside, and this way everything will
# free up the resources at the right time.
try:
# If this is executed on the CLSP grid, we will try to use the
# Grid Engine to distribute the tasks.
# Other clusters can also benefit from that, provided a
# cluster-specific wrapper.
# (see https://github.com/pzelasko/plz for reference)
#
# The following must be installed:
# $ pip install dask distributed
# $ pip install git+https://github.com/pzelasko/plz
name = subprocess.check_output("hostname -f", shell=True, text=True)
if name.strip().endswith(".clsp.jhu.edu"):
import plz
from distributed import Client
with plz.setup_cluster() as cluster:
cluster.scale(80)
yield Client(cluster)
return
except Exception:
pass
# No need to return anything - compute_and_store_features
# will just instantiate the pool itself.
yield None
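
# A minimal usage sketch (this demo function is not part of the original
# module). get_executor() yields either a distributed worker pool (on the
# CLSP grid) or None, in which case the caller falls back to local
# processing, e.g. by letting feature-extraction APIs create their own pool.
def _demo_get_executor() -> None:
    with get_executor() as ex:
        if ex is None:
            logging.info("No cluster detected; work will run locally.")
        else:
            logging.info("Distributing work over the cluster: %s", ex)
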
def str2bool(v):
"""Used in argparse.ArgumentParser.add_argument to indicate
that a type is a bool type and user can enter
- yes, true, t, y, 1, to represent True
- no, false, f, n, 0, to represent False
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
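
# A minimal argparse sketch for str2bool (this demo and the flag name are
# illustrative, not part of the original file). Using `type=str2bool` lets
# users pass yes/no, true/false, t/f, y/n or 1/0 on the command line.
def _demo_str2bool() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Accepts yes/no, true/false, t/f, y/n, 1/0.",
    )
    args = parser.parse_args(["--use-averaged-model", "false"])
    assert args.use_averaged_model is False
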
def setup_logger(
log_filename: Pathlike, log_level: str = "info", use_console: bool = True
) -> None:
"""Setup log level.
Args:
log_filename:
The filename to save the log.
log_level:
The log level to use, e.g., "debug", "info", "warning", "error",
"critical"
"""
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
log_filename = f"{log_filename}-{date_time}-{rank}"
else:
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
log_filename = f"{log_filename}-{date_time}"
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
level = logging.ERROR
if log_level == "debug":
level = logging.DEBUG
elif log_level == "info":
level = logging.INFO
elif log_level == "warning":
level = logging.WARNING
elif log_level == "critical":
level = logging.CRITICAL
logging.basicConfig(
filename=log_filename, format=formatter, level=level, filemode="w"
)
if use_console:
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
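
# Usage sketch for setup_logger (illustrative; the log path below is made
# up). In recipes it is typically called once at startup; the date, and the
# rank under DDP, are appended to the filename automatically.
def _demo_setup_logger() -> None:
    setup_logger(log_filename="exp/log/log-train", log_level="info")
    logging.info("This message goes to the log file and the console.")
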
class AttributeDict(dict):
def __getattr__(self, key):
if key in self:
return self[key]
raise AttributeError(f"No such attribute '{key}'")
def __setattr__(self, key, value):
self[key] = value
def __delattr__(self, key):
if key in self:
del self[key]
return
raise AttributeError(f"No such attribute '{key}'")
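
# A short sketch of AttributeDict (the demo values are illustrative). It is
# a dict whose keys are also readable/writable as attributes, which is how
# recipe hyper-parameters ("params") are usually passed around.
def _demo_attribute_dict() -> None:
    params = AttributeDict({"feature_dim": 80, "subsampling_factor": 4})
    assert params.feature_dim == 80  # same as params["feature_dim"]
    params.vocab_size = 500  # same as params["vocab_size"] = 500
    del params.vocab_size
    assert "vocab_size" not in params
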
def encode_supervisions(
supervisions: dict, subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
"""
Encodes Lhotse's ``batch["supervisions"]`` dict into
    a pair of a torch Tensor and a list of transcription strings.
The supervision tensor has shape ``(batch_size, 3)``.
Its second dimension contains information about sequence index [0],
start frames [1] and num frames [2].
The batch items might become re-ordered during this operation -- the
returned tensor and list of strings are guaranteed to be consistent with
each other.
"""
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // subsampling_factor,
supervisions["num_frames"] // subsampling_factor,
),
1,
).to(torch.int32)
indices = torch.argsort(supervision_segments[:, 2], descending=True)
supervision_segments = supervision_segments[indices]
texts = supervisions["text"]
texts = [texts[idx] for idx in indices]
return supervision_segments, texts
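
# A toy sketch of encode_supervisions (the batch below is made up). It shows
# that segments are sorted by num_frames in descending order, with the texts
# permuted consistently.
def _demo_encode_supervisions() -> None:
    supervisions = {
        "sequence_idx": torch.tensor([0, 1]),
        "start_frame": torch.tensor([0, 0]),
        "num_frames": torch.tensor([80, 120]),
        "text": ["SHORT UTTERANCE", "THE LONGER UTTERANCE"],
    }
    segments, texts = encode_supervisions(supervisions, subsampling_factor=4)
    # 120 // 4 = 30 frames sorts first, then 80 // 4 = 20 frames.
    assert segments.tolist() == [[1, 0, 30], [0, 0, 20]]
    assert texts == ["THE LONGER UTTERANCE", "SHORT UTTERANCE"]
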
def get_texts(
best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
"""Extract the texts (as word IDs) from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
return_ragged:
True to return a ragged tensor with two axes [utt][word_id].
        False to return a list of lists of word IDs.

    Returns:
      Returns a list of lists of int containing the decoded word IDs,
      or a k2.RaggedTensor with two axes if return_ragged is True.
"""
if isinstance(best_paths.aux_labels, k2.RaggedTensor):
# remove 0's and -1's.
aux_labels = best_paths.aux_labels.remove_values_leq(0)
# TODO: change arcs.shape() to arcs.shape
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
# remove the states and arcs axes.
aux_shape = aux_shape.remove_axis(1)
aux_shape = aux_shape.remove_axis(1)
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
else:
# remove axis corresponding to states.
aux_shape = best_paths.arcs.shape().remove_axis(1)
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
# remove 0's and -1's.
aux_labels = aux_labels.remove_values_leq(0)
assert aux_labels.num_axes == 2
if return_ragged:
return aux_labels
else:
return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
"""Extract labels or aux_labels from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
kind:
Possible values are: "labels" and "aux_labels". Caution: When it is
"labels", the resulting alignments contain repeats.
Returns:
Returns a list of lists of int, containing the token sequences we
      decoded. For `ans[i]`, its length equals the number of frames
after subsampling of the i-th utterance in the batch.
Example:
When `kind` is `labels`, one possible alignment example is (with
repeats)::
c c c blk a a blk blk t t t blk blk
If `kind` is `aux_labels`, the above example changes to::
c blk blk blk a blk blk blk t blk blk blk blk
"""
assert kind in ("labels", "aux_labels")
    # arcs.shape() has axes [fsa][state][arc]; we remove the "state" axis here
token_shape = best_paths.arcs.shape().remove_axis(1)
# token_shape has axes [fsa][arc]
tokens = k2.RaggedTensor(
token_shape, getattr(best_paths, kind).contiguous()
)
tokens = tokens.remove_values_eq(-1)
return tokens.tolist()
def save_alignments(
alignments: Dict[str, List[int]],
subsampling_factor: int,
filename: str,
) -> None:
"""Save alignments to a file.
Args:
alignments:
A dict containing alignments. Keys of the dict are utterances and
values are the corresponding framewise alignments after subsampling.
subsampling_factor:
The subsampling factor of the model.
filename:
Path to save the alignments.
Returns:
Return None.
"""
ali_dict = {
"subsampling_factor": subsampling_factor,
"alignments": alignments,
}
torch.save(ali_dict, filename)
def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
"""Load alignments from a file.
Args:
filename:
Path to the file containing alignment information.
The file should be saved by :func:`save_alignments`.
Returns:
Return a tuple containing:
- subsampling_factor: The subsampling_factor used to compute
the alignments.
- alignments: A dict containing utterances and their corresponding
framewise alignment, after subsampling.
"""
ali_dict = torch.load(filename)
subsampling_factor = ali_dict["subsampling_factor"]
alignments = ali_dict["alignments"]
return subsampling_factor, alignments
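
# A round-trip sketch for save_alignments/load_alignments (utterance IDs and
# alignment values are made up). The file is a plain torch checkpoint that
# stores the subsampling factor together with the alignment dict.
def _demo_alignments_roundtrip() -> None:
    import tempfile

    alignments = {"utt-001": [0, 0, 25, 0, 7], "utt-002": [0, 13, 0]}
    with tempfile.NamedTemporaryFile(suffix=".pt") as f:
        save_alignments(alignments, subsampling_factor=4, filename=f.name)
        factor, loaded = load_alignments(f.name)
    assert factor == 4
    assert loaded == alignments
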
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:
"""Save predicted results and reference transcripts to a file.
Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
Returns:
Return None.
"""
with open(filename, "w") as f:
for ref, hyp in texts:
print(f"ref={ref}", file=f)
print(f"hyp={hyp}", file=f)
def write_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
" ".join(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted(
[(v, k) for k, v in subs.items()], reverse=True
):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print(
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
        # Unpack into fresh names so we do not shadow the `ins`/`dels`
        # dicts defined above.
        (corr, ref_sub, hyp_sub, ins_cnt, del_cnt) = counts
        tot_errs = ref_sub + hyp_sub + ins_cnt + del_cnt
        ref_count = corr + ref_sub + del_cnt
        hyp_count = corr + hyp_sub + ins_cnt
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
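
# A minimal sketch of write_error_stats on toy data (illustrative). Note that
# each element of `results` is a pair of word *lists*, since kaldialign.align
# operates on word sequences; io.StringIO keeps the example self-contained.
def _demo_write_error_stats() -> None:
    import io

    results = [
        ("THE GOOD MORNING".split(), "THE GOOD MORNING".split()),
        (
            "FOR THE FIRST DAY SIR I THINK".split(),
            "FOR THE FIRST DAY I THINK".split(),
        ),
    ]
    f = io.StringIO()
    wer = write_error_stats(f, "toy-test-set", results, enable_log=False)
    # One deletion (SIR) over 10 reference words -> 10.00% WER.
    assert wer == 10.0
    assert "PER-UTT DETAILS" in f.getvalue()
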
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
        # This class plays the role of a metrics tracker: it can record many
        # different metrics, including but not limited to losses.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
frames = "%.2f" % self["frames"]
ans += "over " + str(frames) + " frames."
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self["frames"] if "frames" in self else 1
ans = []
for k, v in self.items():
if k != "frames":
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
        Sum the metrics over all processes via torch.distributed
        (all_reduce with ReduceOp.SUM), so every process gets the total.
"""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
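
# Usage sketch for MetricsTracker (the numbers are made up). Metrics from
# different batches accumulate via `+`; norm_items() and __str__ report
# values normalized by the number of frames.
def _demo_metrics_tracker() -> None:
    info_a = MetricsTracker()
    info_a["frames"] = 100
    info_a["loss"] = 25.0

    info_b = MetricsTracker()
    info_b["frames"] = 300
    info_b["loss"] = 45.0

    total = info_a + info_b  # 400 frames, summed loss of 70.0
    assert total.norm_items() == [("loss", 0.175)]  # 70.0 / 400
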
def concat(
ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
"""Prepend a value to the beginning of each sublist or append a value.
to the end of each sublist.
Args:
ragged:
A ragged tensor with two axes.
value:
The value to prepend or append.
direction:
It can be either "left" or "right". If it is "left", we
prepend the value to the beginning of each sublist;
if it is "right", we append the value to the end of each
sublist.
Returns:
Return a new ragged tensor, whose sublists either start with
or end with the given value.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> concat(a, value=0, direction="left")
[ [ 0 1 3 ] [ 0 5 ] ]
>>> concat(a, value=0, direction="right")
[ [ 1 3 0 ] [ 5 0 ] ]
"""
dtype = ragged.dtype
device = ragged.device
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
pad_values = torch.full(
size=(ragged.tot_size(0), 1),
fill_value=value,
device=device,
dtype=dtype,
)
pad = k2.RaggedTensor(pad_values)
if direction == "left":
ans = k2.ragged.cat([pad, ragged], axis=1)
elif direction == "right":
ans = k2.ragged.cat([ragged, pad], axis=1)
else:
        raise ValueError(
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right"'
        )
return ans
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
"""Add SOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
sos_id:
The ID of the SOS symbol.
Returns:
Return a new ragged tensor, where each sublist starts with SOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_sos(a, sos_id=0)
[ [ 0 1 3 ] [ 0 5 ] ]
"""
return concat(ragged, sos_id, direction="left")
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
"""Add EOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
eos_id:
The ID of the EOS symbol.
Returns:
Return a new ragged tensor, where each sublist ends with EOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_eos(a, eos_id=0)
[ [ 1 3 0 ] [ 5 0 ] ]
"""
return concat(ragged, eos_id, direction="right")
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = lengths.max()
n = lengths.size(0)
    expanded_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
    return expanded_lengths >= lengths.unsqueeze(1)
def l1_norm(x):
return torch.sum(torch.abs(x))
def l2_norm(x):
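    # Note: this is the *squared* L2 norm, i.e. the sum of squared elements.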
return torch.sum(torch.pow(x, 2))
def linf_norm(x):
return torch.max(torch.abs(x))
def measure_weight_norms(
model: nn.Module, norm: str = "l2"
) -> Dict[str, float]:
"""
Compute the norms of the model's parameters.
:param model: a torch.nn.Module instance
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
:return: a dict mapping from parameter's name to its norm.
"""
with torch.no_grad():
norms = {}
for name, param in model.named_parameters():
if norm == "l1":
val = l1_norm(param)
elif norm == "l2":
val = l2_norm(param)
elif norm == "linf":
val = linf_norm(param)
else:
raise ValueError(f"Unknown norm type: {norm}")
norms[name] = val.item()
return norms
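
# Illustrative usage of measure_weight_norms with a tiny model. Note that
# norm="l2" uses l2_norm above, so the reported values are squared L2 norms
# (sums of squares), not square-rooted.
def _demo_measure_weight_norms() -> None:
    model = nn.Linear(4, 2)
    norms = measure_weight_norms(model, norm="l2")
    # One entry per parameter: "weight" and "bias".
    assert set(norms.keys()) == {"weight", "bias"}
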
def measure_gradient_norms(
model: nn.Module, norm: str = "l1"
) -> Dict[str, float]:
"""
Compute the norms of the gradients for each of model's parameters.
:param model: a torch.nn.Module instance
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
:return: a dict mapping from parameter's name to its gradient's norm.
"""
with torch.no_grad():
norms = {}
for name, param in model.named_parameters():
if norm == "l1":
val = l1_norm(param.grad)
elif norm == "l2":
val = l2_norm(param.grad)
elif norm == "linf":
val = linf_norm(param.grad)
else:
raise ValueError(f"Unknown norm type: {norm}")
norms[name] = val.item()
return norms
def optim_step_and_measure_param_change(
model: nn.Module,
old_parameters: Dict[str, nn.parameter.Parameter],
) -> Dict[str, float]:
"""
Measure the "relative change in parameters per minibatch."
It is understood as a ratio between the L2 norm of the difference between original and updates parameters,
and the L2 norm of the original parameter. It is given by the formula:
.. math::
\begin{aligned}
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
\end{aligned}
This function is supposed to be used as follows:
.. code-block:: python
old_parameters = {
n: p.detach().clone() for n, p in model.named_parameters()
}
optimizer.step()
deltas = optim_step_and_measure_param_change(old_parameters)
Args:
model: A torch.nn.Module instance.
old_parameters:
A Dict of named_parameters before optimizer.step().
Return:
A Dict containing the relative change for each parameter.
"""
relative_change = {}
with torch.no_grad():
for n, p_new in model.named_parameters():
p_orig = old_parameters[n]
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item()
return relative_change
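
# An end-to-end sketch of optim_step_and_measure_param_change (the tiny
# model and SGD settings are made up). Parameters are snapshotted *before*
# optimizer.step() and the function is called *after* it.
def _demo_param_change() -> None:
    model = nn.Linear(8, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(2, 8)).sum()
    loss.backward()

    old_parameters = {
        n: p.detach().clone() for n, p in model.named_parameters()
    }
    optimizer.step()

    deltas = optim_step_and_measure_param_change(model, old_parameters)
    for name, rel_change in deltas.items():
        logging.debug("%s: %.3e", name, rel_change)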