Merge branch 'merge_refactor_param_cov_norank1_iter_batch_max4.0_pow0.5_fix2r_lrupdate200_2k_ns' into merge2_refactor_max4.0_pow0.5_200_1k_ma3.0
commit 3857a87b47
@@ -466,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"
 
-        self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True)
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
         self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
         self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
                                                 max_positive=1.0, max_abs=10.0)
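For orientation, a minimal runnable sketch (dimensions invented) of what this fused input projection computes now that it is a plain nn.Linear: one matmul produces query, key and value, which the attention code later splits.

```python
import torch
import torch.nn as nn

embed_dim = 512  # illustrative; icefall configures this per model

# The replacement layer from the hunk above: one fused q/k/v projection.
in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)

x = torch.randn(10, 2, embed_dim)        # (time, batch, embed_dim)
q, k, v = in_proj(x).chunk(3, dim=-1)    # each (time, batch, embed_dim)
assert q.shape == k.shape == v.shape == x.shape
```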
@@ -544,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.get_weight(),
-            self.in_proj.get_bias(),
+            self.in_proj.weight,
+            self.in_proj.bias,
             self.dropout,
             self.out_proj.weight,
             self.out_proj.bias,
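This hunk follows from the previous one: nn.Linear keeps its parameters as plain attributes, so StructuredLinear's get_weight()/get_bias() accessors give way to direct attribute access. A quick check:

```python
import torch.nn as nn

in_proj = nn.Linear(512, 3 * 512, bias=True)
print(in_proj.weight.shape)  # torch.Size([1536, 512])
print(in_proj.bias.shape)    # torch.Size([1536])
```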
@@ -881,9 +881,9 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
 
-        self.pointwise_conv1 = StructuredConv1d(
-            (channels,),
-            (2, channels),
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
             kernel_size=1,
             stride=1,
             padding=0,
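The 2 * channels output width is there because the conformer convolution module follows this pointwise convolution with a GLU, which gates one half of the channels with the other half and so halves the width again. A shape-only sketch (sizes invented):

```python
import torch
import torch.nn as nn

channels = 256
pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0)
glu = nn.GLU(dim=1)  # halves the channel dimension by gating

x = torch.randn(4, channels, 100)   # (batch, channels, time)
y = glu(pointwise_conv1(x))         # back to (batch, channels, time)
assert y.shape == x.shape
```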
@@ -619,8 +619,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # (where the stats permit).
             scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
 
-            if random.random() < 0.001:
-                logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
+            if random.random() < 0.01:
+                skip = 10 if size < 20 else 1
+                logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")
 
             # scale shape: (batch_size, 1, size, 1, 1)
             cur_p *= scale
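The new skip variable only thins out what reaches the log; the [::skip] slicing it uses is ordinary strided indexing. A standalone illustration (tensor values invented):

```python
import torch

scale = torch.linspace(0.5, 1.5, steps=50)  # stand-in for the real scale tensor
size = scale.numel()
skip = 10 if size < 20 else 1               # mirrors the hunk above
print(scale.flatten()[::skip])              # logs every skip-th element
```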
@@ -755,7 +756,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
             N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)  # ensure symmetric
             U, S, V = _svd(N_grad_cov)
-            if random.random() < 0.001:
+            if random.random() < 0.01:
                 logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
                              f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
 
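For context, a self-contained sketch of the symmetrize-then-SVD step above, using torch.linalg.svd in place of the file's _svd helper (shapes invented):

```python
import torch

# Batched covariance of projected gradients: (batch, dims, size, size).
grad_cov = torch.randn(2, 3, 8, 8)
Q = torch.randn(2, 3, 8, 8)

N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
# Adding the transpose makes the matrix exactly symmetric, so the SVD
# is an eigendecomposition up to signs despite floating-point error.
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)
U, S, Vh = torch.linalg.svd(N_grad_cov)  # batched over the two leading dims
```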
@@ -1820,16 +1821,6 @@ def _test_eve_cain():
     logging.info(f"input_magnitudes = {input_magnitudes}")
     logging.info(f"output_magnitudes = {output_magnitudes}")
 
-    def stddev(x):
-        return ((x-x.mean())**2).mean().sqrt()
-    logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
-    logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")
-
-    logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
-    logging.info("Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
-    logging.info("Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")
-
-    logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")
 
 def _test_svd():
     device = 'cuda'
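The deleted lines measured how uniform a weight matrix's row and column norms are (stddev of their logs, with values near 0 meaning nearly uniform). A self-contained sketch of the same diagnostic:

```python
import torch

def stddev(x):
    return ((x - x.mean()) ** 2).mean().sqrt()

w = torch.randn(100, 80)               # stand-in weight matrix
col_mags = (w ** 2).sum(dim=0).sqrt()  # per-column L2 norms
row_mags = (w ** 2).sum(dim=1).sqrt()  # per-row L2 norms
print(stddev(col_mags.log()), stddev(row_mags.log()))
```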
@@ -105,12 +105,14 @@ class TensorDiagnostic(object):
         opts:
           Options object.
         name:
-          The tensor name.
+          The name associated with this diagnostics object, will probably be {module_name}.X
+          where X is "output" or "grad", or {parameter_name}.Y where Y is param_value or param_grad.
         """
 
     def __init__(self, opts: TensorDiagnosticOptions, name: str):
-        self.name = name
         self.opts = opts
+        self.name = name
+        self.class_name = None  # will assign in accumulate()
 
         self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
 
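To make the expanded docstring concrete, names following its convention look like this (module and parameter names invented):

```python
# Hypothetical names matching the docstring's convention:
names = [
    "encoder.layers.0.self_attn.output",  # {module_name}.output
    "encoder.layers.0.self_attn.grad",    # {module_name}.grad
    "encoder.embed.weight.param_value",   # {parameter_name}.param_value
    "encoder.embed.weight.param_grad",    # {parameter_name}.param_grad
]
print("\n".join(names))
```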
@@ -124,8 +126,13 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-    def accumulate(self, x):
-        """Accumulate tensors."""
+    def accumulate(self, x, class_name: Optional[str] = None):
+        """
+        Accumulate tensors.
+        """
+        if class_name is not None:
+            self.class_name = class_name
         if isinstance(x, Tuple):
             x = x[0]
         if not isinstance(x, Tensor):
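A minimal sketch of where the new class_name argument comes from at the call sites (the diagnostic object itself is elided):

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
output = module(torch.randn(2, 4))
# Hooks forward the module's class name alongside the tensor, e.g.
#     diagnostic.accumulate(output, class_name=type(module).__name__)
print(type(module).__name__)  # "Linear" -- what gets stored in self.class_name
```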
@@ -240,7 +247,7 @@ class TensorDiagnostic(object):
                 ans += f", norm={norm:.2g}"
             mean = stats.mean().item()
             rms = (stats ** 2).mean().sqrt().item()
-            ans += f", mean={mean:.2g}, rms={rms:.2g}"
+            ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
         # OK, "ans" contains the actual stats, e.g.
         # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
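The only change here is one extra significant digit; a quick check of :.2g versus :.3g (values invented):

```python
mean, rms = 0.49731, 0.50318
print(f"mean={mean:.2g}, rms={rms:.2g}")  # mean=0.5, rms=0.5
print(f"mean={mean:.3g}, rms={rms:.3g}")  # mean=0.497, rms=0.503
```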
@@ -251,8 +258,9 @@ class TensorDiagnostic(object):
             if len(sizes) == 1
             else f"{min(sizes)}..{max(sizes)}"
         )
+        maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
         print(
-            f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
+            f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
         )
 
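Sketched with invented values, the amended print line comes out as:

```python
name = "encoder.layers.0.self_attn.output"
class_name = "RelPositionMultiheadAttention"  # set via accumulate(); may be None
maybe_class_name = f" type={class_name}," if class_name is not None else ""
print(f"module={name},{maybe_class_name} dim=1, size=512, abs percentiles ...")
# -> module=encoder.layers.0.self_attn.output, type=RelPositionMultiheadAttention, dim=1, size=512, abs percentiles ...
```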
@@ -316,20 +324,25 @@ def attach_diagnostics(
         def forward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
         ):
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output)
+                _model_diagnostic[f"{_name}.output"].accumulate(_output,
+                                                                class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                                                                         class_name=type(_module).__name__)
 
         def backward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
         ):
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output)
+                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
+                                                              class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                                                                       class_name=type(_module).__name__)
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
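A toy, self-contained version of the hook pattern these hunks extend (not icefall's attach_diagnostics itself; note that newer PyTorch prefers register_full_backward_hook over the deprecated register_backward_hook used here):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

def forward_hook(module, _input, output, _name="layers.0"):
    # Mirrors the diff: report the module's class name with the tensor.
    print(f"{_name}.output from {type(module).__name__}: shape={tuple(output.shape)}")

model[0].register_forward_hook(forward_hook)
model(torch.randn(2, 8))  # prints: layers.0.output from Linear: shape=(2, 8)
```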