Merge branch 'merge_refactor_param_cov_norank1_iter_batch_max4.0_pow0.5_fix2r_lrupdate200_2k_ns' into merge2_refactor_max4.0_pow0.5_200_1k_ma3.0

This commit is contained in:
Daniel Povey 2022-07-17 15:32:43 +08:00
commit 3857a87b47
3 changed files with 33 additions and 29 deletions

View File

@ -466,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module):
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True)
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
max_positive=1.0, max_abs=10.0)
@ -544,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.get_weight(),
self.in_proj.get_bias(),
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
@ -881,9 +881,9 @@ class ConvolutionModule(nn.Module):
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = StructuredConv1d(
(channels,),
(2, channels),
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,

View File

@ -619,8 +619,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# (where the stats permit).
scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
if random.random() < 0.001:
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
if random.random() < 0.01:
skip = 10 if size < 20 else 1
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")
# scale shape: (batch_size, 1, size, 1, 1)
cur_p *= scale
@ -755,7 +756,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3) # ensure symmetric
U, S, V = _svd(N_grad_cov)
if random.random() < 0.001:
if random.random() < 0.01:
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
@ -1820,16 +1821,6 @@ def _test_eve_cain():
logging.info(f"input_magnitudes = {input_magnitudes}")
logging.info(f"output_magnitudes = {output_magnitudes}")
def stddev(x):
return ((x-x.mean())**2).mean().sqrt()
logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")
logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
logging.info("Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
logging.info("Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")
logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")
def _test_svd():
device = 'cuda'

View File

@ -105,12 +105,14 @@ class TensorDiagnostic(object):
opts:
Options object.
name:
The tensor name.
The name associated with this diagnostics object, will probably be {module_name}.X
where X is "output" or "grad", or {parameter_name}.Y where Y is param_value or param_grad.
"""
def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name
self.opts = opts
self.name = name
self.class_name = None # will assign in accumulate()
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
@ -124,8 +126,13 @@ class TensorDiagnostic(object):
# only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
def accumulate(self, x):
"""Accumulate tensors."""
def accumulate(self, x, class_name: Optional[str] = None):
"""
Accumulate tensors.
"""
if class_name is not None:
self.class_name = class_name
if isinstance(x, Tuple):
x = x[0]
if not isinstance(x, Tensor):
@ -240,7 +247,7 @@ class TensorDiagnostic(object):
ans += f", norm={norm:.2g}"
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
ans += f", mean={mean:.3g}, rms={rms:.3g}"
# OK, "ans" contains the actual stats, e.g.
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
@ -251,8 +258,9 @@ class TensorDiagnostic(object):
if len(sizes) == 1
else f"{min(sizes)}..{max(sizes)}"
)
maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
print(
f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
)
@ -316,20 +324,25 @@ def attach_diagnostics(
def forward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.output"].accumulate(_output)
_model_diagnostic[f"{_name}.output"].accumulate(_output,
class_name=type(_module).__name__)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
class_name=type(_module).__name__)
def backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
class_name=type(_module).__name__)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=type(_module).__name__)
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)