diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 38394564e..6d8256c26 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1419,7 +1419,7 @@ class NonlinAttentionModule(nn.Module): self.to_value = nn.Linear(channels, channels, bias=True) - # deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule + # deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule self.deriv_balancer = ActivationBalancer( channels, channel_dim=-1, min_positive=0.05, max_positive=1.0, diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index b075aceac..a380cdb52 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -16,7 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import random from dataclasses import dataclass from typing import Optional, Tuple, List @@ -24,7 +23,6 @@ from typing import Optional, Tuple, List import torch from torch import Tensor, nn - class TensorDiagnosticOptions(object): """Options object for tensor diagnostics: @@ -60,7 +58,8 @@ def get_tensor_stats( "abs" -> take abs() before summing "positive" -> take (x > 0) before summing "rms" -> square before summing, we'll take sqrt later - "value -> just sum x itself + "value" -> just sum x itself + "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing Returns: stats: a Tensor of shape (x.shape[dim],). count: an integer saying how many items were counted in each element @@ -94,7 +93,7 @@ def get_tensor_stats( x = torch.min(x, dim=dim)[0] else: x = torch.sum(x, dim=sum_dims) - x = x.flatten() + x = x.flatten().clone() return x, count @@ -177,9 +176,11 @@ class TensorDiagnostic(object): if s.tensor.shape == stats.shape: if stats_type == "max": s.tensor = torch.maximum(s.tensor, stats) + elif stats_type == "min": s.tensor = torch.minimum(s.tensor, stats) else: + assert stats_type != "max" s.tensor += stats s.count += count done = True @@ -204,7 +205,7 @@ class TensorDiagnostic(object): return for dim, this_dim_stats in enumerate(self.stats): for stats_type, stats_list in this_dim_stats.items(): - # stats_type could be "rms", "value", "abs", "eigs", "positive". + # stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max". # "stats_list" could be a list of TensorAndCount (one list per distinct tensor # shape of the stats), or None if stats_list is None: