# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
#                                        Mingshuang Luo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import collections
import logging
import os
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union

import k2
import k2.version
import kaldialign
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

Pathlike = Union[str, Path]


@contextmanager
def get_executor():
    # We'll either return a process pool or a distributed worker pool.
    # Note that this has to be a context manager because we might use multiple
    # context managers ("with" clauses) inside, and this way everything will
    # free up the resources at the right time.
    try:
        # If this is executed on the CLSP grid, we will try to use the
        # Grid Engine to distribute the tasks.
        # Other clusters can also benefit from that, provided a
        # cluster-specific wrapper.
        # (see https://github.com/pzelasko/plz for reference)
        #
        # The following must be installed:
        # $ pip install dask distributed
        # $ pip install git+https://github.com/pzelasko/plz
        name = subprocess.check_output("hostname -f", shell=True, text=True)
        if name.strip().endswith(".clsp.jhu.edu"):
            import plz
            from distributed import Client

            with plz.setup_cluster() as cluster:
                cluster.scale(80)
                yield Client(cluster)
            return
    except Exception:
        pass
    # No need to return anything - compute_and_store_features
    # will just instantiate the pool itself.
    yield None


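# Usage sketch (illustrative only, not part of the original file): get_executor()
# is meant to wrap feature extraction so that it transparently uses a Dask
# cluster on the CLSP grid and falls back to a local pool elsewhere.  The call
# to lhotse's CutSet.compute_and_store_features below is an assumption about
# the caller, with made-up argument values:
#
#     with get_executor() as ex:
#         cut_set = cut_set.compute_and_store_features(
#             extractor=extractor,
#             storage_path="data/fbank/feats",
#             executor=ex,
#         )
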
def str2bool(v):
    """Used as the `type` argument of argparse.ArgumentParser.add_argument
    to parse boolean values. The user can enter

    - yes, true, t, y, 1, to represent True
    - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


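# Usage sketch (illustrative only, not part of the original file): str2bool is
# passed as the `type` of an argparse option so that "--use-fp16 false" is
# parsed as a real boolean.  The option name is hypothetical.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--use-fp16", type=str2bool, default=False)
#     args = parser.parse_args(["--use-fp16", "true"])
#     assert args.use_fp16 is True
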
def setup_logger(
    log_filename: Pathlike, log_level: str = "info", use_console: bool = True
) -> None:
    """Set up logging to a file and, optionally, to the console.

    Args:
      log_filename:
        The filename to save the log. A timestamp (and the rank, in the
        distributed case) is appended to it.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical".
      use_console:
        If True, also print log messages to the console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = (
            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        )
        log_filename = f"{log_filename}-{date_time}"

    os.makedirs(os.path.dirname(log_filename), exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename, format=formatter, level=level, filemode="w"
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)


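# Usage sketch (illustrative only, not part of the original file): a typical
# call from a training script.  The path is hypothetical; setup_logger appends
# a timestamp (and the rank, when torch.distributed is initialized) to it.
#
#     setup_logger("exp/log/log-train")
#     logging.info("Training started")
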
class AttributeDict(dict):
    """A dict whose keys can also be read, set, and deleted as attributes."""

    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")


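# Usage sketch (illustrative only, not part of the original file): AttributeDict
# is handy for hyper-parameter containers, where keys can be read either way.
# The key names below are made up.
#
#     params = AttributeDict({"feature_dim": 80, "subsampling_factor": 4})
#     assert params.feature_dim == params["feature_dim"] == 80
#     params.exp_dir = Path("exp")  # attribute assignment adds a key
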
def encode_supervisions(
    supervisions: dict, subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
    """
    Encodes Lhotse's ``batch["supervisions"]`` dict into
    a pair of a torch Tensor and a list of transcription strings.

    The supervision tensor has shape ``(batch_size, 3)``.
    Its second dimension contains information about the sequence index [0],
    start frames [1] and num frames [2].

    The batch items might become re-ordered during this operation -- the
    returned tensor and list of strings are guaranteed to be consistent with
    each other.
    """
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // subsampling_factor,
            supervisions["num_frames"] // subsampling_factor,
        ),
        1,
    ).to(torch.int32)

    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions["text"]
    texts = [texts[idx] for idx in indices]

    return supervision_segments, texts


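# Usage sketch (illustrative only, not part of the original file): a toy
# supervisions dict with two utterances.  The second utterance is longer after
# subsampling, so it moves to the front of the returned segments and the texts
# are re-ordered to match.
#
#     supervisions = {
#         "sequence_idx": torch.tensor([0, 1]),
#         "start_frame": torch.tensor([0, 0]),
#         "num_frames": torch.tensor([80, 120]),
#         "text": ["HELLO", "GOOD MORNING"],
#     }
#     segments, texts = encode_supervisions(supervisions, subsampling_factor=4)
#     # segments: [[1, 0, 30], [0, 0, 20]], texts: ["GOOD MORNING", "HELLO"]
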
def get_texts(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
    """Extract the texts (as word IDs) from the best-path FSAs.
    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.
    Returns:
      Returns a list of lists of int, containing the label sequences we
      decoded, or a k2.RaggedTensor with two axes if return_ragged is True.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = aux_labels.remove_values_leq(0)

    assert aux_labels.num_axes == 2
    if return_ragged:
        return aux_labels
    else:
        return aux_labels.tolist()


def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
    """Extract labels or aux_labels from the best-path FSAs.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      kind:
        Possible values are: "labels" and "aux_labels". Caution: When it is
        "labels", the resulting alignments contain repeats.
    Returns:
      Returns a list of lists of int, containing the token sequences we
      decoded. For `ans[i]`, its length equals the number of frames
      after subsampling of the i-th utterance in the batch.

    Example:
      When `kind` is `labels`, one possible alignment example is (with
      repeats)::

        c c c blk a a blk blk t t t blk blk

      If `kind` is `aux_labels`, the above example changes to::

        c blk blk blk a blk blk blk t blk blk blk blk

    """
    assert kind in ("labels", "aux_labels")
    # arcs.shape() has axes [fsa][state][arc]; we remove the "state" axis here
    token_shape = best_paths.arcs.shape().remove_axis(1)
    # token_shape has axes [fsa][arc]
    tokens = k2.RaggedTensor(
        token_shape, getattr(best_paths, kind).contiguous()
    )
    tokens = tokens.remove_values_eq(-1)
    return tokens.tolist()


def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Save alignments to a file.

    Args:
      alignments:
        A dict containing alignments. Keys of the dict are utterances and
        values are the corresponding framewise alignments after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      filename:
        Path to save the alignments.
    Returns:
      Return None.
    """
    ali_dict = {
        "subsampling_factor": subsampling_factor,
        "alignments": alignments,
    }
    torch.save(ali_dict, filename)


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Load alignments from a file.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling_factor used to compute
          the alignments.
        - alignments: A dict containing utterances and their corresponding
          framewise alignment, after subsampling.
    """
    ali_dict = torch.load(filename)
    subsampling_factor = ali_dict["subsampling_factor"]
    alignments = ali_dict["alignments"]
    return subsampling_factor, alignments


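# Usage sketch (illustrative only, not part of the original file): the two
# functions above form a round trip through torch.save / torch.load.  The
# alignment values and filename are made up.
#
#     ali = {"utt-1": [0, 0, 5, 5, 9], "utt-2": [0, 3, 3]}
#     save_alignments(ali, subsampling_factor=4, filename="ali.pt")
#     factor, loaded = load_alignments("ali.pt")
#     assert factor == 4 and loaded == ali
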
def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:
    """Save predicted results and reference transcripts to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the reference transcript
        while the second element is the predicted result.
    Returns:
      Return None.
    """
    with open(filename, "w") as f:
        for ref, hyp in texts:
            print(f"ref={ref}", file=f)
            print(f"hyp={hyp}", file=f)


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str]],
    enable_log: bool = True,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

      - WER
      - number of insertions, deletions, substitutions, corrects and total
        reference words. For example::

          Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
          reference words (2337 correct)

      - The difference between the reference transcript and predicted result.
        An instance is given below::

          THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

        The above example shows that the reference word is `EDISON`,
        but it is predicted to `ADDISON` (a substitution error).

        Another example is::

          FOR THE FIRST DAY (SIR->*) I THINK

        The reference word `SIR` is missing in the predicted
        results (a deletion error).

    Args:
      f:
        The file to write the statistics to.
      test_set_name:
        Name of the test set; it is used only in the log message.
      results:
        An iterable of tuples. The first element is the reference transcript
        while the second element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      Return the total word error rate (in percent) as a float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            " ".join(
                (
                    ref_word
                    if ref_word == hyp_word
                    else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted(
        [(v, k) for k, v in subs.items()], reverse=True
    ):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print(
        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
    )
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)


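# Usage sketch (illustrative only, not part of the original file): `results`
# pairs a reference word list with a hypothesis word list per utterance; the
# return value is the overall WER in percent (here 1 substitution over 5
# reference words, i.e. 20.0).  The filename is hypothetical.
#
#     results = [
#         (["THE", "CAT", "SAT"], ["THE", "CAT", "SAT"]),
#         (["GOOD", "MORNING"], ["GOOD", "EVENING"]),
#     ]
#     with open("errs-test.txt", "w") as f:
#         wer = write_error_stats(f, "test", results)
#     # wer == 20.0
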
class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class will play a role as metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            ans += str(k) + "=" + str(norm_value) + ", "
        frames = "%.2f" % self["frames"]
        ans += "over " + str(frames) + " frames."
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self["frames"] if "frames" in self else 1
        ans = []
        for k, v in self.items():
            if k != "frames":
                norm_value = float(v) / num_frames
                ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Reduce using torch.distributed, which I believe ensures that
        all processes get the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
          tb_writer: a TensorBoard writer
          prefix: a prefix for the name of the loss, e.g. "train/valid_",
            or "train/current_"
          batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)


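# Usage sketch (illustrative only, not part of the original file): trackers
# from individual batches are summed, and norm_items() / __str__ report
# per-frame values.  The numbers and the tb_writer call are made up.
#
#     info = MetricsTracker()
#     info["frames"] = 100
#     info["loss"] = 25.0
#     tot = info + info            # frames=200, loss=50.0
#     print(tot)                   # "loss=0.25, over 200.00 frames."
#     # tot.write_summary(tb_writer, "train/current_", batch_idx)
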
def concat(
    ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
    """Prepend a value to the beginning of each sublist or append a value
    to the end of each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      value:
        The value to prepend or append.
      direction:
        It can be either "left" or "right". If it is "left", we
        prepend the value to the beginning of each sublist;
        if it is "right", we append the value to the end of each
        sublist.

    Returns:
      Return a new ragged tensor, whose sublists either start with
      or end with the given value.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> concat(a, value=0, direction="left")
    [ [ 0 1 3 ] [ 0 5 ] ]
    >>> concat(a, value=0, direction="right")
    [ [ 1 3 0 ] [ 5 0 ] ]

    """
    dtype = ragged.dtype
    device = ragged.device

    assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
    pad_values = torch.full(
        size=(ragged.tot_size(0), 1),
        fill_value=value,
        device=device,
        dtype=dtype,
    )
    pad = k2.RaggedTensor(pad_values)

    if direction == "left":
        ans = k2.ragged.cat([pad, ragged], axis=1)
    elif direction == "right":
        ans = k2.ragged.cat([ragged, pad], axis=1)
    else:
        raise ValueError(
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right"'
        )
    return ans


def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
    """Add SOS to each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      sos_id:
        The ID of the SOS symbol.

    Returns:
      Return a new ragged tensor, where each sublist starts with SOS.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> add_sos(a, sos_id=0)
    [ [ 0 1 3 ] [ 0 5 ] ]

    """
    return concat(ragged, sos_id, direction="left")


def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
    """Add EOS to each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      eos_id:
        The ID of the EOS symbol.

    Returns:
      Return a new ragged tensor, where each sublist ends with EOS.

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> add_eos(a, eos_id=0)
    [ [ 1 3 0 ] [ 5 0 ] ]

    """
    return concat(ragged, eos_id, direction="right")


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim

    max_len = lengths.max()
    n = lengths.size(0)

    expanded_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)

    return expanded_lengths >= lengths.unsqueeze(1)


def l1_norm(x):
    return torch.sum(torch.abs(x))


def l2_norm(x):
    # Note: this is the *squared* L2 norm (sum of squares, no sqrt).
    return torch.sum(torch.pow(x, 2))


def linf_norm(x):
    return torch.max(torch.abs(x))


def measure_weight_norms(
    model: nn.Module, norm: str = "l2"
) -> Dict[str, float]:
    """
    Compute the norms of the model's parameters.

    :param model: a torch.nn.Module instance
    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
    :return: a dict mapping from parameter's name to its norm.
    """
    with torch.no_grad():
        norms = {}
        for name, param in model.named_parameters():
            if norm == "l1":
                val = l1_norm(param)
            elif norm == "l2":
                val = l2_norm(param)
            elif norm == "linf":
                val = linf_norm(param)
            else:
                raise ValueError(f"Unknown norm type: {norm}")
            norms[name] = val.item()
        return norms


def measure_gradient_norms(
    model: nn.Module, norm: str = "l1"
) -> Dict[str, float]:
    """
    Compute the norms of the gradients for each of model's parameters.

    :param model: a torch.nn.Module instance
    :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
    :return: a dict mapping from parameter's name to its gradient's norm.
    """
    with torch.no_grad():
        norms = {}
        for name, param in model.named_parameters():
            if norm == "l1":
                val = l1_norm(param.grad)
            elif norm == "l2":
                val = l2_norm(param.grad)
            elif norm == "linf":
                val = linf_norm(param.grad)
            else:
                raise ValueError(f"Unknown norm type: {norm}")
            norms[name] = val.item()
        return norms


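# Usage sketch (illustrative only, not part of the original file): the two
# helpers above are typically called from the training loop and logged to
# TensorBoard.  The tb_writer, batch_idx, and tag names are assumptions.
#
#     for name, value in measure_weight_norms(model, norm="l2").items():
#         tb_writer.add_scalar(f"train/weight_norms/{name}", value, batch_idx)
#     loss.backward()
#     for name, value in measure_gradient_norms(model, norm="l1").items():
#         tb_writer.add_scalar(f"train/grad_norms/{name}", value, batch_idx)
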
def optim_step_and_measure_param_change(
    model: nn.Module,
    old_parameters: Dict[str, nn.parameter.Parameter],
) -> Dict[str, float]:
    r"""
    Measure the "relative change in parameters per minibatch."
    It is understood as a ratio between the (squared) L2 norm of the difference
    between the original and the updated parameters, and the (squared) L2 norm
    of the original parameter. It is given by the formula:

    .. math::
        \begin{aligned}
            \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
        \end{aligned}

    This function is supposed to be used as follows:

    .. code-block:: python

        old_parameters = {
            n: p.detach().clone() for n, p in model.named_parameters()
        }

        optimizer.step()

        deltas = optim_step_and_measure_param_change(model, old_parameters)

    Args:
      model: A torch.nn.Module instance.
      old_parameters:
        A Dict of named_parameters before optimizer.step().

    Return:
      A Dict containing the relative change for each parameter.
    """
    relative_change = {}
    with torch.no_grad():
        for n, p_new in model.named_parameters():
            p_orig = old_parameters[n]
            delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
            relative_change[n] = delta.item()
    return relative_change
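

# Usage sketch (illustrative only, not part of the original file): a
# self-contained check of the helper above with a tiny model and plain SGD;
# the model, data, and learning rate are made up.
#
#     model = nn.Linear(4, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     old_parameters = {
#         n: p.detach().clone() for n, p in model.named_parameters()
#     }
#     loss = model(torch.randn(8, 4)).pow(2).mean()
#     loss.backward()
#     optimizer.step()
#     deltas = optim_step_and_measure_param_change(model, old_parameters)
#     # e.g. {"weight": 0.012, "bias": 0.045}  (values depend on the data)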