Add more diagnostics to debug gradient scale problems

This commit is contained in:
Daniel Povey 2022-10-22 12:49:29 +08:00
parent 476fb9e9f3
commit 1d2fe8e3c2
2 changed files with 32 additions and 9 deletions

View File

@ -851,7 +851,8 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, " f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], " f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}" f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler.scale}" if params.use_fp16 else "")
) )
if tb_writer is not None: if tb_writer is not None:
@ -865,6 +866,12 @@ def train_one_epoch(
tot_loss.write_summary( tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train tb_writer, "train/tot_", params.batch_idx_train
) )
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", scaler.scale, params.batch_idx_train
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss") logging.info("Computing validation loss")

View File

@ -82,11 +82,18 @@ def get_tensor_stats(
elif stats_type == "positive": elif stats_type == "positive":
x = (x > 0).to(dtype=torch.float) x = (x > 0).to(dtype=torch.float)
else: else:
assert stats_type == "value" assert stats_type in ["value", "max", "min"]
sum_dims = [d for d in range(x.ndim) if d != dim] sum_dims = [d for d in range(x.ndim) if d != dim]
if len(sum_dims) > 0: if len(sum_dims) > 0:
x = torch.sum(x, dim=sum_dims) if stats_type == "max":
for dim in reversed(sum_dims):
x = torch.max(x, dim=dim)[0]
elif stats_type == "min":
for dim in reversed(sum_dims):
x = torch.min(x, dim=dim)[0]
else:
x = torch.sum(x, dim=sum_dims)
x = x.flatten() x = x.flatten()
return x, count return x, count
@ -117,7 +124,7 @@ class TensorDiagnostic(object):
self.stats = None # we'll later assign a list to this data member. It's a list of dict. self.stats = None # we'll later assign a list to this data member. It's a list of dict.
# the keys into self.stats[dim] are strings, whose values can be # the keys into self.stats[dim] are strings, whose values can be
# "abs", "value", "positive", "rms". # "abs", "max", "min", "value", "positive", "rms".
# The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount, # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
# containing a tensor and its associated count (which is the sum of the other dims # containing a tensor and its associated count (which is the sum of the other dims
# that we aggregated over, e.g. the number of frames and/or batch elements and/or # that we aggregated over, e.g. the number of frames and/or batch elements and/or
@ -149,11 +156,11 @@ class TensorDiagnostic(object):
for dim in range(ndim): for dim in range(ndim):
this_dim_stats = self.stats[dim] this_dim_stats = self.stats[dim]
if ndim > 1: if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"] stats_types = ["abs", "max", "min", "positive", "value", "rms"]
if x.shape[dim] <= self.opts.max_eig_dim: if x.shape[dim] <= self.opts.max_eig_dim:
stats_types.append("eigs") stats_types.append("eigs")
else: else:
stats_types = ["value", "abs"] stats_types = ["value", "abs", "max", "min"]
for stats_type in stats_types: for stats_type in stats_types:
stats, count = get_tensor_stats(x, dim, stats_type) stats, count = get_tensor_stats(x, dim, stats_type)
@ -168,7 +175,12 @@ class TensorDiagnostic(object):
continue continue
for s in this_dim_stats[stats_type]: for s in this_dim_stats[stats_type]:
if s.tensor.shape == stats.shape: if s.tensor.shape == stats.shape:
s.tensor += stats if stats_type == "max":
s.tensor = torch.maximum(s.tensor, stats)
elif stats_type == "min":
s.tensor = torch.minimum(s.tensor, stats)
else:
s.tensor += stats
s.count += count s.count += count
done = True done = True
break break
@ -199,13 +211,17 @@ class TensorDiagnostic(object):
assert stats_type == "eigs" assert stats_type == "eigs"
continue continue
def get_count(count):
return 1 if stats_type in ["max", "min"] else count
if len(stats_list) == 1: if len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count stats = stats_list[0].tensor / get_count(stats_list[0].count)
else: else:
# a dimension that has variable size in different nnet # a dimension that has variable size in different nnet
# forwards, e.g. a time dimension in an ASR model. # forwards, e.g. a time dimension in an ASR model.
stats = torch.cat( stats = torch.cat(
[x.tensor / x.count for x in stats_list], dim=0 [x.tensor / get_count(x.count) for x in stats_list], dim=0
) )
if stats_type == "eigs": if stats_type == "eigs":