diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 88b78e92c..f5d3cd3a6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -851,7 +851,8 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}"
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler.get_scale()}" if params.use_fp16 else "")
             )

             if tb_writer is not None:
@@ -865,6 +866,12 @@ def train_one_epoch(
                 tot_loss.write_summary(
                     tb_writer, "train/tot_", params.batch_idx_train
                 )
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        scaler.get_scale(),
+                        params.batch_idx_train,
+                    )

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index c26e37de4..ce35d2bdf 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -82,11 +82,18 @@ def get_tensor_stats(
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type == "value"
+        assert stats_type in ["value", "max", "min"]

     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
-        x = torch.sum(x, dim=sum_dims)
+        if stats_type == "max":
+            for d in reversed(sum_dims):
+                x = torch.max(x, dim=d)[0]
+        elif stats_type == "min":
+            for d in reversed(sum_dims):
+                x = torch.min(x, dim=d)[0]
+        else:
+            x = torch.sum(x, dim=sum_dims)
     x = x.flatten()
     return x, count
@@ -117,7 +124,7 @@ class TensorDiagnostic(object):
         self.stats = None  # we'll later assign a list to this data member. It's a list of dict.

         # the keys into self.stats[dim] are strings, whose values can be
-        # "abs", "value", "positive", "rms", "value".
+        # "abs", "max", "min", "value", "positive", "rms".
         # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
         # containing a tensor and its associated count (which is the sum of the other dims
         # that we aggregated over, e.g. the number of frames and/or batch elements and/or
@@ -149,11 +156,11 @@ class TensorDiagnostic(object):
         for dim in range(ndim):
             this_dim_stats = self.stats[dim]
             if ndim > 1:
-                stats_types = ["abs", "positive", "value", "rms"]
+                stats_types = ["abs", "max", "min", "positive", "value", "rms"]
                 if x.shape[dim] <= self.opts.max_eig_dim:
                     stats_types.append("eigs")
             else:
-                stats_types = ["value", "abs"]
+                stats_types = ["value", "abs", "max", "min"]

             for stats_type in stats_types:
                 stats, count = get_tensor_stats(x, dim, stats_type)
@@ -168,7 +175,12 @@ class TensorDiagnostic(object):
                     continue
                 for s in this_dim_stats[stats_type]:
                     if s.tensor.shape == stats.shape:
-                        s.tensor += stats
+                        if stats_type == "max":
+                            s.tensor = torch.maximum(s.tensor, stats)
+                        elif stats_type == "min":
+                            s.tensor = torch.minimum(s.tensor, stats)
+                        else:
+                            s.tensor += stats
                         s.count += count
                         done = True
                         break
@@ -199,13 +211,17 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue

+                def get_count(count):
+                    return 1 if stats_type in ["max", "min"] else count
+
                 if len(stats_list) == 1:
-                    stats = stats_list[0].tensor / stats_list[0].count
+                    stats = stats_list[0].tensor / get_count(stats_list[0].count)
                 else:
                     # a dimension that has variable size in different nnet
                     # forwards, e.g. a time dimension in an ASR model.
                     stats = torch.cat(
-                        [x.tensor / x.count for x in stats_list], dim=0
+                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
                     )
                 if stats_type == "eigs":
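For reference, here is a minimal standalone sketch of the reduction logic that the `get_tensor_stats` hunk introduces. The helper name `reduce_stat` is invented for illustration; `x`, `dim`, `sum_dims`, and the reversed-order reduction mirror the diff:

```python
import torch


def reduce_stat(x: torch.Tensor, dim: int, stats_type: str) -> torch.Tensor:
    # Collapse every axis except `dim`. For "max"/"min" we keep the extreme
    # value seen along the collapsed axes; otherwise we sum, and the caller
    # later divides by a count to obtain a mean.
    sum_dims = [d for d in range(x.ndim) if d != dim]
    if len(sum_dims) > 0:
        if stats_type == "max":
            # Reduce the highest axis first so remaining dim indices stay valid.
            for d in reversed(sum_dims):
                x = torch.max(x, dim=d)[0]
        elif stats_type == "min":
            for d in reversed(sum_dims):
                x = torch.min(x, dim=d)[0]
        else:
            x = torch.sum(x, dim=sum_dims)
    return x.flatten()


x = torch.tensor([[1.0, -2.0], [3.0, 4.0]])
print(reduce_stat(x, dim=1, stats_type="max"))  # tensor([3., 4.])
print(reduce_stat(x, dim=1, stats_type="min"))  # tensor([1., -2.])
```

This is also why the last hunk adds the `get_count` helper: "max"/"min" stats are merged across batches with `torch.maximum`/`torch.minimum` rather than summed, so at print time they are already extreme values and are divided by 1 instead of by the accumulated count.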