Fix to diagnostics.py (fix for max being doubled), from scaled_adam_exp446; small cosmetic fixes.

Daniel Povey 2022-11-21 14:00:55 +08:00
parent 4e21db07f6
commit 9fe6add587
2 changed files with 7 additions and 6 deletions


@@ -1419,7 +1419,7 @@ class NonlinAttentionModule(nn.Module):
         self.to_value = nn.Linear(channels, channels, bias=True)
-        # deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule
+        # deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
         self.deriv_balancer = ActivationBalancer(
             channels, channel_dim=-1,
             min_positive=0.05, max_positive=1.0,


@@ -16,7 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import random
 from dataclasses import dataclass
 from typing import Optional, Tuple, List
@@ -24,7 +23,6 @@ from typing import Optional, Tuple, List
 import torch
 from torch import Tensor, nn
 
 
-
 class TensorDiagnosticOptions(object):
     """Options object for tensor diagnostics:
@@ -60,7 +58,8 @@ def get_tensor_stats(
        "abs" -> take abs() before summing
        "positive" -> take (x > 0) before summing
        "rms" -> square before summing, we'll take sqrt later
-       "value -> just sum x itself
+       "value" -> just sum x itself
+       "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
     Returns:
       stats: a Tensor of shape (x.shape[dim],).
       count: an integer saying how many items were counted in each element
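As a reading aid for the docstring above, the following standalone sketch mirrors the documented semantics: per-channel statistics along dim are formed by reducing over every other dimension, by summation for most stats types but by an element-wise max/min for "max"/"min". The name tensor_stats_sketch is invented for illustration; the real get_tensor_stats in diagnostics.py differs in detail (it also supports "eigs", for example).

import torch

# Hypothetical standalone sketch of the documented behaviour; illustration only,
# not the actual implementation in diagnostics.py.
def tensor_stats_sketch(x: torch.Tensor, dim: int, stats_type: str):
    other_dims = tuple(d for d in range(x.ndim) if d != dim)
    count = x.numel() // x.shape[dim]
    if stats_type == "abs":
        stats = x.abs().sum(dim=other_dims)
    elif stats_type == "positive":
        stats = (x > 0).to(torch.float32).sum(dim=other_dims)
    elif stats_type == "rms":
        stats = (x ** 2).sum(dim=other_dims)   # sqrt is taken later, when printing
    elif stats_type == "value":
        stats = x.sum(dim=other_dims)
    elif stats_type == "max":
        stats = x.amax(dim=other_dims)         # reduce over all dims but dim; not a sum
    elif stats_type == "min":
        stats = x.amin(dim=other_dims)
    else:
        raise ValueError(f"unknown stats_type: {stats_type}")
    return stats.flatten().clone(), count

x = torch.randn(4, 8)
stats, count = tensor_stats_sketch(x, dim=1, stats_type="max")
print(stats.shape, count)   # torch.Size([8]) 4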
@@ -94,7 +93,7 @@ def get_tensor_stats(
         x = torch.min(x, dim=dim)[0]
     else:
         x = torch.sum(x, dim=sum_dims)
-    x = x.flatten()
+    x = x.flatten().clone()
    return x, count
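The .clone() added above guards against aliasing. Tensor.flatten() is documented as possibly returning "the original object, a view, or copy"; when no copy is made, the returned stats share storage with the input, so a later in-place update to the stored stats would also mutate whatever tensor they alias. A minimal illustration of that hazard (not the exact code path through diagnostics.py):

import torch

x = torch.tensor([1.0, 2.0, 3.0])

aliased = x.flatten()           # x is already 1-D, so this shares x's storage here
safe = x.flatten().clone()      # clone() forces a private copy

print(aliased.data_ptr() == x.data_ptr())   # True: same underlying memory
print(safe.data_ptr() == x.data_ptr())      # False

aliased += 1.0                  # in-place accumulation also changes x
print(x)                        # tensor([2., 3., 4.])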
@@ -177,9 +176,11 @@ class TensorDiagnostic(object):
                 if s.tensor.shape == stats.shape:
                     if stats_type == "max":
                         s.tensor = torch.maximum(s.tensor, stats)
                     elif stats_type == "min":
                         s.tensor = torch.minimum(s.tensor, stats)
                     else:
+                        assert stats_type != "max"
                         s.tensor += stats
                     s.count += count
                     done = True
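The new assert records the invariant behind the commit title: "max" stats must never fall through to the += branch, since summing running maxima roughly doubles them on each accumulation (presumably the doubling the commit message refers to), while torch.maximum keeps the true peak. A tiny demonstration:

import torch

batch1_max = torch.tensor([1.0, 2.0, 3.0])   # per-channel maxima from one batch
batch2_max = torch.tensor([1.0, 2.0, 3.0])   # the same maxima observed in the next batch

wrong = batch1_max.clone()
wrong += batch2_max                          # summing maxima: the value doubles each time
print(wrong)                                 # tensor([2., 4., 6.])

right = batch1_max.clone()
right = torch.maximum(right, batch2_max)     # element-wise maximum keeps the true peak
print(right)                                 # tensor([1., 2., 3.])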
@@ -204,7 +205,7 @@ class TensorDiagnostic(object):
             return
         for dim, this_dim_stats in enumerate(self.stats):
             for stats_type, stats_list in this_dim_stats.items():
-                # stats_type could be "rms", "value", "abs", "eigs", "positive".
+                # stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max".
                 # "stats_list" could be a list of TensorAndCount (one list per distinct tensor
                 # shape of the stats), or None
                 if stats_list is None: