diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index c77bd41da..517e7b5e9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -466,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"
 
-        self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True)
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
         self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
         self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
                                                 max_positive=1.0, max_abs=10.0)
@@ -544,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.get_weight(),
-            self.in_proj.get_bias(),
+            self.in_proj.weight,
+            self.in_proj.bias,
             self.dropout,
             self.out_proj.weight,
             self.out_proj.bias,
@@ -881,9 +881,9 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
 
-        self.pointwise_conv1 = StructuredConv1d(
-            (channels,),
-            (2, channels),
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
             kernel_size=1,
             stride=1,
             padding=0,
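For context on the in_proj change above: nn.Linear(embed_dim, 3 * embed_dim) packs the query, key and value projections into a single weight matrix, and the attention function now receives the raw in_proj.weight / in_proj.bias instead of going through StructuredLinear's getters. Below is a minimal sketch of how such a fused projection is typically consumed; the chunk order (q, k, v) and the shapes are assumptions for illustration, not icefall's exact rel-pos attention code.

import torch
import torch.nn as nn

embed_dim, num_heads = 512, 8
head_dim = embed_dim // num_heads
in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)

x = torch.randn(10, 2, embed_dim)  # (seq_len, batch, embed_dim)

# The functional attention path receives the raw parameters, which is why the
# diff switches from in_proj.get_weight()/get_bias() to in_proj.weight/.bias.
qkv = nn.functional.linear(x, in_proj.weight, in_proj.bias)
q, k, v = qkv.chunk(3, dim=-1)             # each (seq_len, batch, embed_dim)
q = q.reshape(10, 2, num_heads, head_dim)  # split heads out of the last dim
print(q.shape, k.shape, v.shape)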
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 898104ffb..5c20361be 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -619,8 +619,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 
         # (where the stats permit).
         scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
-        if random.random() < 0.001:
-            logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
+        if random.random() < 0.01:
+            skip = 1 if size < 20 else 10
+            logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")
         # scale shape: (batch_size, 1, size, 1, 1)
         cur_p *= scale
 
@@ -755,7 +756,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
             N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)  # ensure symmetric
             U, S, V = _svd(N_grad_cov)
-            if random.random() < 0.001:
+            if random.random() < 0.01:
                 logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
                              f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
 
@@ -1820,16 +1821,6 @@ def _test_eve_cain():
         logging.info(f"input_magnitudes = {input_magnitudes}")
         logging.info(f"output_magnitudes = {output_magnitudes}")
 
-        def stddev(x):
-            return ((x-x.mean())**2).mean().sqrt()
-        logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
-        logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")
-
-        logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
-        logging.info("Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
-        logging.info("Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")
-
-        logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")
 
 def _test_svd():
     device = 'cuda'
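The logging changes above follow a common pattern in these optimizer experiments: log with small probability so the output stays manageable over many steps, and stride the printed slice by the tensor's size so large tensors do not flood the log. Below is a minimal sketch of that pattern; maybe_log_scale and its arguments are hypothetical names for illustration, not functions from optim.py.

import logging
import random

import torch

logging.basicConfig(level=logging.INFO)


def maybe_log_scale(scale: torch.Tensor, prob: float = 0.01) -> None:
    # Log only a small fraction of calls, matching `if random.random() < 0.01:`.
    if random.random() >= prob:
        return
    size = scale[0].numel()
    # Print every element of small tensors, every 10th element of large ones.
    skip = 1 if size < 20 else 10
    logging.info(f"shape={tuple(scale.shape)}, scale={scale[0].flatten()[::skip]}")


# scale shape in the diff is (batch_size, 1, size, 1, 1); prob=1.0 here just
# makes the demo print deterministically.
maybe_log_scale(torch.randn(4, 1, 256, 1, 1), prob=1.0)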
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 4850308d9..01bf552cc 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -105,12 +105,14 @@ class TensorDiagnostic(object):
         opts:
             Options object.
         name:
-            The tensor name.
+            The name associated with this diagnostics object; it will probably be {module_name}.X
+            where X is "output" or "grad", or {parameter_name}.Y where Y is param_value or param_grad.
     """
 
     def __init__(self, opts: TensorDiagnosticOptions, name: str):
-        self.name = name
         self.opts = opts
+        self.name = name
+        self.class_name = None  # will assign in accumulate()
 
         self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
 
@@ -124,8 +126,13 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-    def accumulate(self, x):
-        """Accumulate tensors."""
+
+    def accumulate(self, x, class_name: Optional[str] = None):
+        """
+        Accumulate tensors.
+        """
+        if class_name is not None:
+            self.class_name = class_name
         if isinstance(x, Tuple):
             x = x[0]
         if not isinstance(x, Tensor):
@@ -240,7 +247,7 @@ class TensorDiagnostic(object):
                         ans += f", norm={norm:.2g}"
                     mean = stats.mean().item()
                     rms = (stats ** 2).mean().sqrt().item()
-                    ans += f", mean={mean:.2g}, rms={rms:.2g}"
+                    ans += f", mean={mean:.3g}, rms={rms:.3g}"
                     # OK, "ans" contains the actual stats, e.g.
                     # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
 
@@ -251,8 +258,9 @@ class TensorDiagnostic(object):
                     if len(sizes) == 1
                     else f"{min(sizes)}..{max(sizes)}"
                 )
+                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
-                    f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
+                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
 
 
@@ -316,20 +324,25 @@ def attach_diagnostics(
         def forward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
         ):
+
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output)
+                _model_diagnostic[f"{_name}.output"].accumulate(_output,
+                                                                class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                                                                         class_name=type(_module).__name__)
 
         def backward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
         ):
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output)
+                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
+                                                              class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                                                                       class_name=type(_module).__name__)
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
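Below is a self-contained sketch of the hook pattern the diagnostics.py changes above extend: forward hooks that tag each accumulated output with the class name of the module that produced it, so it can be printed as type=... in the stats line. The tiny model and the stats dict are illustrative stand-ins for icefall's TensorDiagnostic machinery, not the real implementation.

import torch
import torch.nn as nn

stats = {}  # key -> (class_name, list of output norms)


def attach(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if name == "":
            continue  # skip the root module for brevity

        # The default arg `_name=name` freezes the loop variable per hook,
        # just as the diff does with `_name=name` and `_model_diagnostic=ans`.
        def forward_hook(_module, _input, _output, _name=name):
            outputs = _output if isinstance(_output, tuple) else (_output,)
            for i, o in enumerate(outputs):
                key = (f"{_name}.output[{i}]" if isinstance(_output, tuple)
                       else f"{_name}.output")
                class_name, norms = stats.setdefault(key, (type(_module).__name__, []))
                norms.append(o.norm().item())

        module.register_forward_hook(forward_hook)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
attach(model)
model(torch.randn(4, 8))
for key, (class_name, norms) in stats.items():
    print(f"module={key}, type={class_name}, mean_norm={sum(norms) / len(norms):.3g}")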