Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
formatted the entire LibriSpeech recipe (#1270)
* formatted the entire librispeech recipe
* minor updates
This commit is contained in:
parent ef658d691e
commit ef5da4824d
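The hunks below are formatting-only rewrites (double-quote normalization, removal of redundant parentheses, re-wrapping of long calls), consistent with the black==22.3.0 pin that appears in the requirements hunks at the end of this diff. As a minimal sketch of how one such rewrite can be reproduced with black's own Python API (black.format_str and black.Mode); the exact command used for this commit is not recorded here:

    import black

    # One of the lines touched in the diagnostics hunks below, before formatting.
    src = "counts = self.counts.to('cpu')\n"
    # Default Mode() gives black's standard 88-column, double-quote style.
    print(black.format_str(src, mode=black.Mode()), end="")
    # prints: counts = self.counts.to("cpu")

In practice the same kind of bulk diff comes from running the black command line over the recipe directory (e.g. black egs/librispeech/ASR, path assumed), rather than formatting strings one at a time.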
@@ -557,7 +557,6 @@ def train_one_epoch(
)

if batch_idx % params.log_interval == 0:

if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@@ -43,6 +43,7 @@ from pathlib import Path

from tqdm.auto import tqdm


# This function is copied from lhotse
def tqdm_urlretrieve_hook(t):
"""Wraps tqdm instance.
@@ -236,7 +236,7 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

offset = 0
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@@ -507,7 +507,7 @@ def modified_beam_search(

offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@@ -162,7 +162,6 @@ def merge_chunks(

futures = []
with ThreadPoolExecutor(max_workers=1) as executor:

for cut in cuts_chunk:
cur_rec_id = cut.recording.id
if len(cut_list) == 0:
@@ -264,6 +264,7 @@ def decode_dataset(
- timestamps of reference transcript
- timestamps of predicted result
"""

# Background worker to add alignemnt and save cuts to disk.
def _save_worker(
cuts: List[Cut],
@@ -66,7 +66,6 @@ class Eve(Optimizer):
weight_decay=1e-3,
target_rms=0.1,
):

if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -719,7 +719,7 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

offset = 0
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@@ -1019,7 +1019,7 @@ def modified_beam_search(

offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@@ -1227,7 +1227,7 @@ def modified_beam_search_lm_rescore(

offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@@ -1427,7 +1427,7 @@ def modified_beam_search_lm_rescore_LODR(

offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@@ -2608,7 +2608,6 @@ def modified_beam_search_LODR(
context_score = 0
new_context_state = None if context_graph is None else hyp.context_state
if new_token not in (blank_id, unk_id):

if context_graph is not None:
(
context_score,
@@ -2758,7 +2757,7 @@ def modified_beam_search_lm_shallow_fusion(

offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
@@ -2900,7 +2899,6 @@ def modified_beam_search_lm_shallow_fusion(
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):

ys.append(new_token)
new_timestamp.append(t)

@@ -66,7 +66,6 @@ class Eve(Optimizer):
weight_decay=1e-3,
target_rms=0.1,
):

if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -528,7 +528,6 @@ class ScaledLSTM(nn.LSTM):
return

with torch.cuda.device_of(first_fw):

# Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
# an inplace operation on self._flat_weights
with torch.no_grad():
@@ -56,7 +56,6 @@ class CodebookIndexExtractor:
"""

def __init__(self, params: AttributeDict):

self.params = params
params.subsets = ["clean-100"]
if self.params.full_libri:
@@ -111,7 +111,7 @@ def batch_force_alignment(

offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@@ -543,7 +543,6 @@ def train_one_epoch(
)

if batch_idx % params.log_interval == 0:

if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@@ -463,7 +463,6 @@ def train_one_epoch(
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:

if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@@ -513,7 +513,6 @@ def train_one_epoch(
)

if batch_idx % params.log_interval == 0:

if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@@ -517,7 +517,6 @@ def train_one_epoch(
)

if batch_idx % params.log_interval == 0:

if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@@ -70,7 +70,7 @@ class PiecewiseLinear(object):
self.pairs = list(args[0].pairs)
else:
self.pairs = [(float(x), float(y)) for x, y in args]
for (x, y) in self.pairs:
for x, y in self.pairs:
assert isinstance(x, (float, int)), type(x)
assert isinstance(y, (float, int)), type(y)

@@ -1,12 +1,6 @@
# isort:skip_file

from . import (
checkpoint,
decode,
dist,
env,
utils
)
from . import checkpoint, decode, dist, env, utils

from .byte_utils import (
byte_decode,
@@ -227,7 +227,6 @@ class ContextGraph:
filename: Optional[str] = "",
symbol_table: Optional[Dict[int, str]] = None,
) -> "Digraph": # noqa

"""Visualize a ContextGraph via graphviz.

Render ContextGraph as an image via graphviz, and return the Digraph object;
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, List
import torch
from torch import Tensor, nn


class TensorDiagnosticOptions(object):
"""Options object for tensor diagnostics:

@@ -77,11 +78,11 @@ def get_tensor_stats(
elif stats_type == "abs":
x = x.abs()
elif stats_type == "rms":
x = x ** 2
x = x**2
elif stats_type == "positive":
x = (x > 0).to(dtype=torch.float)
else:
assert stats_type in [ "value", "max", "min" ]
assert stats_type in ["value", "max", "min"]

sum_dims = [d for d in range(x.ndim) if d != dim]
if len(sum_dims) > 0:
@@ -121,10 +122,10 @@ class TensorDiagnostic(object):
self.class_name = None # will assign in accumulate()

self.stats = None # we'll later assign a list to self.stats.
# It's a list of dicts, indexed by dim (i.e. by the
# axis of the tensor). The dicts, in turn, are
# indexed by `stats-type` which are strings in
# ["abs", "max", "min", "positive", "value", "rms"].
# It's a list of dicts, indexed by dim (i.e. by the
# axis of the tensor). The dicts, in turn, are
# indexed by `stats-type` which are strings in
# ["abs", "max", "min", "positive", "value", "rms"].

# scalar_stats contains some analysis of the activations and gradients,
self.scalar_stats = None
@@ -139,7 +140,6 @@ class TensorDiagnostic(object):
# only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.


def accumulate(self, x, class_name: Optional[str] = None):
"""
Accumulate tensors.
@@ -193,17 +193,12 @@ class TensorDiagnostic(object):
done = True
break
if not done:
if (
this_dim_stats[stats_type] != []
and stats_type == "eigs"
):
if this_dim_stats[stats_type] != [] and stats_type == "eigs":
# >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulat "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None
else:
this_dim_stats[stats_type].append(
TensorAndCount(stats, count)
)
this_dim_stats[stats_type].append(TensorAndCount(stats, count))

def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
@@ -220,8 +215,11 @@ class TensorDiagnostic(object):
for r, v in zip(rms_stats_list, value_stats_list):
stddev_stats_list.append(
# r.count and v.count should be the same, but we don't check this.
TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
r.count))
TensorAndCount(
r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
r.count,
)
)
this_dim_stats["stddev"] = stddev_stats_list

for stats_type, stats_list in this_dim_stats.items():
@@ -232,7 +230,6 @@ class TensorDiagnostic(object):
assert stats_type == "eigs"
continue


def get_count(count):
return 1 if stats_type in ["max", "min"] else count

@@ -250,22 +247,20 @@ class TensorDiagnostic(object):
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print(
"Error getting eigenvalues, trying another method."
)
print("Error getting eigenvalues, trying another method.")
eigs, _ = torch.eig(stats)
stats = eigs.norm(dim=1).sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance

if stats_type in [ "rms", "stddev" ]:
if stats_type in ["rms", "stddev"]:
# we stored the square; after aggregation we need to take sqrt.
stats = stats.sqrt()

# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (
len(stats_list) > 1
) or self.opts.dim_is_summarized(stats.numel())
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
stats.numel()
)
if summarize: # usually `summarize` will be true
# print out percentiles.
stats = stats.sort()[0]
@@ -282,15 +277,15 @@ class TensorDiagnostic(object):
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type in [ "value", "rms", "stddev", "eigs" ]:
if stats_type in ["value", "rms", "stddev", "eigs"]:
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
norm = (stats**2).sum().sqrt().item()
ans += f", norm={norm:.2g}"
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
rms = (stats**2).mean().sqrt().item()
ans += f", mean={mean:.3g}, rms={rms:.3g}"

# OK, "ans" contains the actual stats, e.g.
@@ -298,11 +293,11 @@ class TensorDiagnostic(object):

sizes = [x.tensor.shape[0] for x in stats_list]
size_str = (
f"{sizes[0]}"
if len(sizes) == 1
else f"{min(sizes)}..{max(sizes)}"
f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
)
maybe_class_name = (
f" type={self.class_name}," if self.class_name is not None else ""
)
maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
print(
f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
)
@@ -330,7 +325,6 @@ class ScalarDiagnostic(object):
self.sum_gradsq = None
self.sum_abs_grad = None


def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
"""
Called in forward pass.
@@ -347,8 +341,10 @@ class ScalarDiagnostic(object):

limit = 10
if len(self.saved_inputs) > limit:
print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
f" Will not accumulate scalar stats.")
print(
f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
f" Will not accumulate scalar stats."
)
self.is_ok = False
return
self.saved_inputs.append(x)
@@ -359,11 +355,15 @@ class ScalarDiagnostic(object):
if self.is_forward_pass:
self.is_forward_pass = False

last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
last_shape = (
"n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
)
if len(self.saved_inputs) == 0 or grad.shape != last_shape:
print(f"ERROR: shape mismatch or no forward activation present when backward "
f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
f", shape-of-last-saved-input={last_shape}")
print(
f"ERROR: shape mismatch or no forward activation present when backward "
f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
f", shape-of-last-saved-input={last_shape}"
)
self.is_ok = False
return

@@ -384,11 +384,19 @@ class ScalarDiagnostic(object):
self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)

# integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1]
self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device)
self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
self.counts = torch.zeros(
2 * num_ticks_per_side, dtype=torch.long, device=x.device
)
self.sum_grad = torch.zeros(
2 * num_ticks_per_side, dtype=torch.double, device=x.device
)
# sum_gradsq is for getting error bars.
self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
self.sum_gradsq = torch.zeros(
2 * num_ticks_per_side, dtype=torch.double, device=x.device
)
self.sum_abs_grad = torch.zeros(
2 * num_ticks_per_side, dtype=torch.double, device=x.device
)

# this will round down.
x = (x / self.tick_scale).to(torch.long)
@@ -397,20 +405,21 @@ class ScalarDiagnostic(object):

self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double))
self.sum_gradsq.index_add_(
dim=0, index=x, source=(grad * grad).to(torch.double)
)
self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))


def print_diagnostics(self):
"""Print diagnostics."""
if self.is_ok is False or self.counts is None:
print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
return

counts = self.counts.to('cpu')
sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32)
sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
counts = self.counts.to("cpu")
sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)

counts_cumsum = counts.cumsum(dim=0)
counts_tot = counts_cumsum[-1]
@@ -433,19 +442,22 @@ class ScalarDiagnostic(object):
bin_abs_grad = torch.zeros(num_bins)
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)

avg_grad = (bin_grad / bin_counts)
avg_grad = bin_grad / bin_counts
avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()

bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
bin_boundary_counts = (
torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
)
bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
# boundaries are the "x" values between the bins, e.g. corresponding to the
# locations of percentiles of the distribution.
num_ticks_per_side = counts.numel() // 2
bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale


bin_grad = bin_grad / (bin_counts + 1)
bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1) # consider this a standard deviation.
bin_conf_interval = bin_gradsq.sqrt() / (
bin_counts + 1
) # consider this a standard deviation.
# bin_grad / bin_abs_grad will give us a sense for how important in a practical sense,
# the gradients are.
bin_abs_grad = bin_abs_grad / (bin_counts + 1)
@@ -458,8 +470,9 @@ class ScalarDiagnostic(object):
x = "[" + " ".join(x) + "]"
return x


maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
maybe_class_name = (
f" type={self.class_name}," if self.class_name is not None else ""
)

print(
f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, "
@@ -467,7 +480,6 @@ class ScalarDiagnostic(object):
)



class ModelDiagnostic(object):
"""This class stores diagnostics for all tensors in the torch.nn.Module.

@@ -485,9 +497,8 @@ class ModelDiagnostic(object):
self.opts = opts
self.diagnostics = dict()


def __getitem__(self, name: str):
T = ScalarDiagnostic if name[-7:] == '.scalar' else TensorDiagnostic
T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
if name not in self.diagnostics:
self.diagnostics[name] = T(self.opts, name)
return self.diagnostics[name]
@@ -502,18 +513,19 @@ def get_class_name(module: nn.Module):
ans = type(module).__name__
# we put the below in try blocks in case anyone is using a different version of these modules that
# might have different member names.
if ans == 'Balancer' or ans == 'ActivationBalancer':
if ans == "Balancer" or ans == "ActivationBalancer":
try:
ans += f'[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]'
ans += f"[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]"
except:
pass
elif ans == 'AbsValuePenalizer':
elif ans == "AbsValuePenalizer":
try:
ans += f'[{module.limit}]'
ans += f"[{module.limit}]"
except:
pass
return ans


def attach_diagnostics(
model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
@@ -538,73 +550,85 @@ def attach_diagnostics(
if name == "":
name = "<top-level>"



# Setting model_diagnostic=ans and n=name below, instead of trying to
# capture the variables, ensures that we use the current values.
# (this matters for `name`, since the variable gets overwritten).
# These closures don't really capture by value, only by
# "the final value the variable got in the function" :-(
def forward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
if isinstance(_output, tuple) and len(_output) == 1:
_output = _output[0]

if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output"].accumulate(_output,
class_name=get_class_name(_module))
if isinstance(_output, Tensor) and _output.dtype in (
torch.float32,
torch.float16,
torch.float64,
):
_model_diagnostic[f"{_name}.output"].accumulate(
_output, class_name=get_class_name(_module)
)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
class_name=get_class_name(_module))
if o.dtype in (torch.float32, torch.float16, torch.float64):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(
o, class_name=get_class_name(_module)
)

def backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
if isinstance(_output, tuple) and len(_output) == 1:
_output = _output[0]
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
class_name=get_class_name(_module))
if isinstance(_output, Tensor) and _output.dtype in (
torch.float32,
torch.float16,
torch.float64,
):
_model_diagnostic[f"{_name}.grad"].accumulate(
_output, class_name=get_class_name(_module)
)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=get_class_name(_module))

if o.dtype in (torch.float32, torch.float16, torch.float64):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
o, class_name=get_class_name(_module)
)

module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)

if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]:
if type(module).__name__ in [
"Sigmoid",
"Tanh",
"ReLU",
"TanSwish",
"Swish",
"DoubleSwish",
"Swoosh",
]:
# For these specific module types, accumulate some additional diagnostics
# that can help us improve the activation function. These require a lot of memory,
# to save the forward activations, so limit this to some select classes.
# Note: this will not work correctly for all model types.
def scalar_forward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_input, tuple):
_input, = _input
(_input,) = _input
assert isinstance(_input, Tensor)
_model_diagnostic[f"{_name}.scalar"].accumulate_input(_input,
class_name=get_class_name(_module))
_model_diagnostic[f"{_name}.scalar"].accumulate_input(
_input, class_name=get_class_name(_module)
)

def scalar_backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, tuple):
_output, = _output
(_output,) = _output
assert isinstance(_output, Tensor)
_model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

module.register_forward_hook(scalar_forward_hook)
module.register_backward_hook(scalar_backward_hook)



for name, parameter in model.named_parameters():

def param_backward_hook(
@@ -70,25 +70,17 @@ class FlopsProfiler(object):
module_flop_count.append([])

if not hasattr(module, "__pre_hook_handle__"):
module.__pre_hook_handle__ = module.register_forward_pre_hook(
pre_hook
)
module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)

def post_hook(module, input, output):
if module_flop_count:
module.__flops__ += sum(
[elem[1] for elem in module_flop_count[-1]]
)
module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]])
module_flop_count.pop()

if not hasattr(module, "__post_hook_handle__"):
module.__post_hook_handle__ = module.register_forward_hook(
post_hook
)
module.__post_hook_handle__ = module.register_forward_hook(post_hook)

self.model.apply(
partial(register_module_hooks, ignore_list=ignore_list)
)
self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))
self.started = True
self.func_patched = True

@@ -194,9 +186,7 @@ def _prelu_flops_compute(input: Tensor, weight: Tensor):
return input.numel()


def _elu_flops_compute(
input: Tensor, alpha: float = 1.0, inplace: bool = False
):
def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False):
return input.numel()


@@ -259,9 +249,7 @@ def _conv_flops_compute(
output_dims.append(output_dim)

filters_per_channel = out_channels // groups
conv_per_position_macs = (
int(_prod(kernel_dims)) * in_channels * filters_per_channel
)
conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * int(_prod(output_dims))
overall_conv_macs = conv_per_position_macs * active_elements_count
overall_conv_flops = 2 * overall_conv_macs
@@ -297,7 +285,6 @@ def _conv_trans_flops_compute(

output_dims = []
for idx, input_dim in enumerate(input_dims):

output_dim = (
input_dim
+ 2 * paddings[idx]
@@ -310,9 +297,7 @@ def _conv_trans_flops_compute(
dilations = dilation if type(dilation) is tuple else (dilation, dilation)

filters_per_channel = out_channels // groups
conv_per_position_macs = (
int(_prod(kernel_dims)) * in_channels * filters_per_channel
)
conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * int(_prod(input_dims))
overall_conv_macs = conv_per_position_macs * active_elements_count
overall_conv_flops = 2 * overall_conv_macs
@@ -389,9 +374,7 @@ def _upsample_flops_compute(input, **kwargs):
else:
return int(size), 0
scale_factor = kwargs.get("scale_factor", None)
assert (
scale_factor is not None
), "either size or scale_factor should be defined"
assert scale_factor is not None, "either size or scale_factor should be defined"
flops = input.numel()
if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
flops * int(_prod(scale_factor))
@@ -593,12 +576,8 @@ def _patch_functionals():
F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)

# swoosh functions in k2
k2.swoosh_l_forward = wrapFunc(
k2.swoosh_l_forward, _k2_swoosh_flops_compute
)
k2.swoosh_r_forward = wrapFunc(
k2.swoosh_r_forward, _k2_swoosh_flops_compute
)
k2.swoosh_l_forward = wrapFunc(k2.swoosh_l_forward, _k2_swoosh_flops_compute)
k2.swoosh_r_forward = wrapFunc(k2.swoosh_r_forward, _k2_swoosh_flops_compute)
k2.swoosh_l = wrapFunc(k2.swoosh_l, _k2_swoosh_flops_compute)
k2.swoosh_r = wrapFunc(k2.swoosh_r, _k2_swoosh_flops_compute)

@@ -612,9 +591,7 @@ def _patch_tensor_methods():
torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute)

torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute)
torch.Tensor.addmm = wrapFunc(
torch.Tensor.addmm, _tensor_addmm_flops_compute
)
torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute)

torch.mul = wrapFunc(torch.mul, _mul_flops_compute)
torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute)
@@ -631,14 +608,10 @@ def _patch_tensor_methods():

torch.tanh = wrapFunc(torch.tanh, _tanh_flops_compute)

torch.Tensor.softmax = wrapFunc(
torch.Tensor.softmax, _softmax_flops_compute
)
torch.Tensor.softmax = wrapFunc(torch.Tensor.softmax, _softmax_flops_compute)

torch.sigmoid = wrapFunc(torch.sigmoid, _sigmoid_flops_compute)
torch.Tensor.sigmoid = wrapFunc(
torch.Tensor.sigmoid, _sigmoid_flops_compute
)
torch.Tensor.sigmoid = wrapFunc(torch.Tensor.sigmoid, _sigmoid_flops_compute)


def _reload_functionals():
@@ -732,15 +705,11 @@ def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
flops += rnn_module.hidden_size * 4
# two hadamard _product and add for C state
flops += (
rnn_module.hidden_size
+ rnn_module.hidden_size
+ rnn_module.hidden_size
rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
)
# final hadamard
flops += (
rnn_module.hidden_size
+ rnn_module.hidden_size
+ rnn_module.hidden_size
rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
)
return flops

@@ -112,7 +112,6 @@ def main():
for torch_v, onnx_v in zip(
(torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
):

assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
torch_v.shape,
onnx_v.shape,
@@ -463,7 +463,6 @@ def train_one_epoch(
cur_batch_idx = params.get("cur_batch_idx", 0)

for batch_idx, batch in enumerate(train_dl):

if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
@@ -225,7 +225,6 @@ class NgramCounts:
for n in range(0, self.ngram_order - 1):
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():

n_star_star = 0
for w in counts_for_hist.word_to_count.keys():
n_star_star += len(counts_for_hist.word_to_context[w])
@@ -424,7 +423,6 @@ class NgramCounts:


if __name__ == "__main__":

ngram_counts = NgramCounts(args.ngram_order)

if args.text is None:
@@ -103,7 +103,6 @@ class TransformerLM(torch.nn.Module):
return nll_loss

def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):

bs = x.size(0)

state = None
@@ -20,6 +20,7 @@ kaldialign==0.7.1
sentencepiece==0.1.96
tensorboard==2.8.0
typeguard==2.13.3
black==22.3.0
multi_quantization

onnx
@@ -5,3 +5,4 @@ sentencepiece>=0.1.96
tensorboard
typeguard
dill
black==22.3.0