diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 1a2324775..088ef14cb 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -31,32 +31,33 @@ class TensorDiagnosticOptions(object): -def get_sum_abs_stats(x: Tensor, dim: int, +def get_tensor_stats(x: Tensor, dim: int, stats_type: str) -> Tuple[Tensor, int]: """ - Returns the sum-of-absolute-value of this Tensor, for each - index into the specified axis/dim of the tensor. + Returns the specified transformation of the Tensor (either x or x.abs() + or (x > 0), summed over all but the index `dim`. + Args: x: Tensor, tensor to be analyzed dim: dimension with 0 <= dim < x.ndim - stats_type: either "mean-abs" in which case the stats represent the - mean absolute value, or "pos-ratio" in which case the - stats represent the proportion of positive values (actually: - the tensor is count of positive values, count is the count of - all values). - Returns (sum_abs, count) - where sum_abs is a Tensor of shape (x.shape[dim],), and the count + stats_type: + "mean-abs" or "abs-value" -> take abs() before summing + "pos-ratio" -> take (x > 0) before summing + "value -> just sum x itself + Returns (stats, count) + where stats is a Tensor of shape (x.shape[dim],), and the count is an integer saying how many items were counted in each element - of sum_abs. + of stats. """ - if stats_type == "mean-abs": + if stats_type == "mean-abs" or stats_type == "abs-value": x = x.abs() - else: - assert stats_type == "pos-ratio" + elif stats_type == "pos-ratio": x = (x > 0).to(dtype=torch.float) + else: + assert stats_type == "value" orig_numel = x.numel() sum_dims = [ d for d in range(x.ndim) if d != dim ] - if len(sum_dims) != 0: + if len(sum_dims) > 0: x = torch.sum(x, dim=sum_dims) count = orig_numel // x.numel() x = x.flatten() @@ -80,7 +81,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], see the code. """ # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ] + stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats = [ x[0] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ] if sizes_same: @@ -115,9 +116,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions): + ndim = tensors[0].ndim + # options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the + # normal case. + stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] - for stats_type in options.stats_types(): - # stats_type will be "mean-abs" or "pos-ratio". + for stats_type in stats_types: sizes = [ x.shape[dim] for x in tensors ] sizes_same = all([ x == sizes[0] for x in sizes ]) s = get_diagnostics_for_dim(dim, tensors,