mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add more diagnostics to debug gradient scale problems
This commit is contained in:
parent
476fb9e9f3
commit
1d2fe8e3c2
@ -851,7 +851,8 @@ def train_one_epoch(
|
|||||||
f"Epoch {params.cur_epoch}, "
|
f"Epoch {params.cur_epoch}, "
|
||||||
f"batch {batch_idx}, loss[{loss_info}], "
|
f"batch {batch_idx}, loss[{loss_info}], "
|
||||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||||
f"lr: {cur_lr:.2e}"
|
f"lr: {cur_lr:.2e}, " +
|
||||||
|
(f"grad_scale: {scaler.scale}" if params.use_fp16 else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
@ -865,6 +866,12 @@ def train_one_epoch(
|
|||||||
tot_loss.write_summary(
|
tot_loss.write_summary(
|
||||||
tb_writer, "train/tot_", params.batch_idx_train
|
tb_writer, "train/tot_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
if params.use_fp16:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/grad_scale", scaler.scale, params.batch_idx_train
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
|
|||||||
@ -82,11 +82,18 @@ def get_tensor_stats(
|
|||||||
elif stats_type == "positive":
|
elif stats_type == "positive":
|
||||||
x = (x > 0).to(dtype=torch.float)
|
x = (x > 0).to(dtype=torch.float)
|
||||||
else:
|
else:
|
||||||
assert stats_type == "value"
|
assert stats_type in [ "value", "max", "min" ]
|
||||||
|
|
||||||
sum_dims = [d for d in range(x.ndim) if d != dim]
|
sum_dims = [d for d in range(x.ndim) if d != dim]
|
||||||
if len(sum_dims) > 0:
|
if len(sum_dims) > 0:
|
||||||
x = torch.sum(x, dim=sum_dims)
|
if stats_type == "max":
|
||||||
|
for dim in reversed(sum_dims):
|
||||||
|
x = torch.max(x, dim=dim)[0]
|
||||||
|
elif stats_type == "min":
|
||||||
|
for dim in reversed(sum_dims):
|
||||||
|
x = torch.min(x, dim=dim)[0]
|
||||||
|
else:
|
||||||
|
x = torch.sum(x, dim=sum_dims)
|
||||||
x = x.flatten()
|
x = x.flatten()
|
||||||
return x, count
|
return x, count
|
||||||
|
|
||||||
@ -117,7 +124,7 @@ class TensorDiagnostic(object):
|
|||||||
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
|
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
|
||||||
|
|
||||||
# the keys into self.stats[dim] are strings, whose values can be
|
# the keys into self.stats[dim] are strings, whose values can be
|
||||||
# "abs", "value", "positive", "rms", "value".
|
# "abs", "max", "min" ,"value", "positive", "rms", "value".
|
||||||
# The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
|
# The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
|
||||||
# containing a tensor and its associated count (which is the sum of the other dims
|
# containing a tensor and its associated count (which is the sum of the other dims
|
||||||
# that we aggregated over, e.g. the number of frames and/or batch elements and/or
|
# that we aggregated over, e.g. the number of frames and/or batch elements and/or
|
||||||
@ -149,11 +156,11 @@ class TensorDiagnostic(object):
|
|||||||
for dim in range(ndim):
|
for dim in range(ndim):
|
||||||
this_dim_stats = self.stats[dim]
|
this_dim_stats = self.stats[dim]
|
||||||
if ndim > 1:
|
if ndim > 1:
|
||||||
stats_types = ["abs", "positive", "value", "rms"]
|
stats_types = ["abs", "max", "min", "positive", "value", "rms"]
|
||||||
if x.shape[dim] <= self.opts.max_eig_dim:
|
if x.shape[dim] <= self.opts.max_eig_dim:
|
||||||
stats_types.append("eigs")
|
stats_types.append("eigs")
|
||||||
else:
|
else:
|
||||||
stats_types = ["value", "abs"]
|
stats_types = ["value", "abs", "max", "min"]
|
||||||
|
|
||||||
for stats_type in stats_types:
|
for stats_type in stats_types:
|
||||||
stats, count = get_tensor_stats(x, dim, stats_type)
|
stats, count = get_tensor_stats(x, dim, stats_type)
|
||||||
@ -168,7 +175,12 @@ class TensorDiagnostic(object):
|
|||||||
continue
|
continue
|
||||||
for s in this_dim_stats[stats_type]:
|
for s in this_dim_stats[stats_type]:
|
||||||
if s.tensor.shape == stats.shape:
|
if s.tensor.shape == stats.shape:
|
||||||
s.tensor += stats
|
if stats_type == "max":
|
||||||
|
s.tensor = torch.maximum(s.tensor, stats)
|
||||||
|
elif stats_type == "min":
|
||||||
|
s.tensor = torch.minimum(s.tensor, stats)
|
||||||
|
else:
|
||||||
|
s.tensor += stats
|
||||||
s.count += count
|
s.count += count
|
||||||
done = True
|
done = True
|
||||||
break
|
break
|
||||||
@ -199,13 +211,17 @@ class TensorDiagnostic(object):
|
|||||||
assert stats_type == "eigs"
|
assert stats_type == "eigs"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def get_count(count):
|
||||||
|
return 1 if stats_type in ["max", "min"] else count
|
||||||
|
|
||||||
if len(stats_list) == 1:
|
if len(stats_list) == 1:
|
||||||
stats = stats_list[0].tensor / stats_list[0].count
|
stats = stats_list[0].tensor / get_count(stats_list[0].count)
|
||||||
else:
|
else:
|
||||||
# a dimension that has variable size in different nnet
|
# a dimension that has variable size in different nnet
|
||||||
# forwards, e.g. a time dimension in an ASR model.
|
# forwards, e.g. a time dimension in an ASR model.
|
||||||
stats = torch.cat(
|
stats = torch.cat(
|
||||||
[x.tensor / x.count for x in stats_list], dim=0
|
[x.tensor / get_count(x.count) for x in stats_list], dim=0
|
||||||
)
|
)
|
||||||
|
|
||||||
if stats_type == "eigs":
|
if stats_type == "eigs":
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user