diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 9839a3fb9..88cafeb88 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -296,7 +296,7 @@ class SoftmaxFunction(torch.autograd.Function):
         dim = ctx.dim
         if dim < 0:
-            dim = dim + ans.dim
+            dim = dim + ans.ndim
         split_dim = 0 if dim != 0 else 1
         # split_dim is the dimension we split up ans on.
         num_split = min(8, ans.shape[split_dim])
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 51e816105..4ae0538ad 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -498,6 +498,22 @@ class ModelDiagnostic(object):
             self.diagnostics[k].print_diagnostics()
 
 
+def get_class_name(module: nn.Module):
+    ans = type(module).__name__
+    # we put the below in try blocks in case anyone is using a different version of these modules that
+    # might have different member names.
+    if ans == 'Balancer' or ans == 'ActivationBalancer':
+        try:
+            ans += f'[{module.min_positive},{module.max_positive},{module.min_abs},{module.max_abs}]'
+        except:
+            pass
+    elif ans == 'AbsValuePenalizer':
+        try:
+            ans += f'[{module.limit}]'
+        except:
+            pass
+    return ans
+
 def attach_diagnostics(
     model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
 ) -> ModelDiagnostic:
@@ -537,12 +553,12 @@ def attach_diagnostics(
             if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
                 _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                                                                class_name=type(_module).__name__)
+                                                                class_name=get_class_name(_module))
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
                         _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                             class_name=type(_module).__name__)
+                                                                             class_name=get_class_name(_module))
 
         def backward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
@@ -551,12 +567,12 @@
             _output = _output[0]
             if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
                 _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                                                              class_name=type(_module).__name__)
+                                                              class_name=get_class_name(_module))
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
                         _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                           class_name=type(_module).__name__)
+                                                                           class_name=get_class_name(_module))
 
         module.register_forward_hook(forward_hook)
@@ -574,7 +590,7 @@ def attach_diagnostics(
                 _input, = _input
                 assert isinstance(_input, Tensor)
             _model_diagnostic[f"{_name}.scalar"].accumulate_input(_input,
                                                                   class_name=get_class_name(_module))
 
         def scalar_backward_hook(
             _module, _input, _output, _model_diagnostic=ans, _name=name
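
Note: the new get_class_name() helper only decorates the class-name string that the diagnostics code attaches to each accumulated tensor; it does not change what is accumulated. The per-instance suffix makes it possible to tell apart several Balancer/ActivationBalancer layers with different limits when reading the diagnostics printout. Below is a minimal, standalone sketch of that behaviour. The ActivationBalancer class here is a hypothetical stand-in (the real module lives in scaling.py and exposes the same attribute names), and the bare except is narrowed to AttributeError for clarity; the sketch is illustrative only, not part of the patch.

    import torch.nn as nn


    def get_class_name(module: nn.Module) -> str:
        # Mirror of the helper added in icefall/diagnostics.py: append the
        # constraint values to the class name so that diagnostics output shows
        # which settings each Balancer/ActivationBalancer instance was using.
        ans = type(module).__name__
        if ans in ('Balancer', 'ActivationBalancer'):
            try:
                ans += f'[{module.min_positive},{module.max_positive},{module.min_abs},{module.max_abs}]'
            except AttributeError:
                pass
        elif ans == 'AbsValuePenalizer':
            try:
                ans += f'[{module.limit}]'
            except AttributeError:
                pass
        return ans


    class ActivationBalancer(nn.Module):
        # Hypothetical stand-in for the real module in scaling.py; only the
        # attributes read by get_class_name() are defined here.
        def __init__(self):
            super().__init__()
            self.min_positive = 0.05
            self.max_positive = 0.95
            self.min_abs = 0.2
            self.max_abs = 100.0

        def forward(self, x):
            return x


    if __name__ == '__main__':
        print(get_class_name(ActivationBalancer()))
        # ActivationBalancer[0.05,0.95,0.2,100.0]
        print(get_class_name(nn.Linear(4, 4)))
        # Linear  (modules without special handling keep the plain class name)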