mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix to diagnostics.py (fix for max being doubled), from scaled_adam_exp446; small cosmetic fixes.
This commit is contained in:
parent
4e21db07f6
commit
9fe6add587
@ -1419,7 +1419,7 @@ class NonlinAttentionModule(nn.Module):
|
||||
self.to_value = nn.Linear(channels, channels, bias=True)
|
||||
|
||||
|
||||
# deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule
|
||||
# deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
|
||||
self.deriv_balancer = ActivationBalancer(
|
||||
channels, channel_dim=-1,
|
||||
min_positive=0.05, max_positive=1.0,
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, List
|
||||
@ -24,7 +23,6 @@ from typing import Optional, Tuple, List
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class TensorDiagnosticOptions(object):
|
||||
"""Options object for tensor diagnostics:
|
||||
|
||||
@ -60,7 +58,8 @@ def get_tensor_stats(
|
||||
"abs" -> take abs() before summing
|
||||
"positive" -> take (x > 0) before summing
|
||||
"rms" -> square before summing, we'll take sqrt later
|
||||
"value -> just sum x itself
|
||||
"value" -> just sum x itself
|
||||
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
|
||||
Returns:
|
||||
stats: a Tensor of shape (x.shape[dim],).
|
||||
count: an integer saying how many items were counted in each element
|
||||
@ -94,7 +93,7 @@ def get_tensor_stats(
|
||||
x = torch.min(x, dim=dim)[0]
|
||||
else:
|
||||
x = torch.sum(x, dim=sum_dims)
|
||||
x = x.flatten()
|
||||
x = x.flatten().clone()
|
||||
return x, count
|
||||
|
||||
|
||||
@ -177,9 +176,11 @@ class TensorDiagnostic(object):
|
||||
if s.tensor.shape == stats.shape:
|
||||
if stats_type == "max":
|
||||
s.tensor = torch.maximum(s.tensor, stats)
|
||||
|
||||
elif stats_type == "min":
|
||||
s.tensor = torch.minimum(s.tensor, stats)
|
||||
else:
|
||||
assert stats_type != "max"
|
||||
s.tensor += stats
|
||||
s.count += count
|
||||
done = True
|
||||
@ -204,7 +205,7 @@ class TensorDiagnostic(object):
|
||||
return
|
||||
for dim, this_dim_stats in enumerate(self.stats):
|
||||
for stats_type, stats_list in this_dim_stats.items():
|
||||
# stats_type could be "rms", "value", "abs", "eigs", "positive".
|
||||
# stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max".
|
||||
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
|
||||
# shape of the stats), or None
|
||||
if stats_list is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user