mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix to diagnostics.py (fix for max being doubled), from scaled_adam_exp446; small cosmetic fixes.
This commit is contained in:
parent
4e21db07f6
commit
9fe6add587
@ -1419,7 +1419,7 @@ class NonlinAttentionModule(nn.Module):
|
|||||||
self.to_value = nn.Linear(channels, channels, bias=True)
|
self.to_value = nn.Linear(channels, channels, bias=True)
|
||||||
|
|
||||||
|
|
||||||
# deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule
|
# deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
|
||||||
self.deriv_balancer = ActivationBalancer(
|
self.deriv_balancer = ActivationBalancer(
|
||||||
channels, channel_dim=-1,
|
channels, channel_dim=-1,
|
||||||
min_positive=0.05, max_positive=1.0,
|
min_positive=0.05, max_positive=1.0,
|
||||||
|
|||||||
@ -16,7 +16,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple, List
|
from typing import Optional, Tuple, List
|
||||||
@ -24,7 +23,6 @@ from typing import Optional, Tuple, List
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class TensorDiagnosticOptions(object):
|
class TensorDiagnosticOptions(object):
|
||||||
"""Options object for tensor diagnostics:
|
"""Options object for tensor diagnostics:
|
||||||
|
|
||||||
@ -60,7 +58,8 @@ def get_tensor_stats(
|
|||||||
"abs" -> take abs() before summing
|
"abs" -> take abs() before summing
|
||||||
"positive" -> take (x > 0) before summing
|
"positive" -> take (x > 0) before summing
|
||||||
"rms" -> square before summing, we'll take sqrt later
|
"rms" -> square before summing, we'll take sqrt later
|
||||||
"value -> just sum x itself
|
"value" -> just sum x itself
|
||||||
|
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
|
||||||
Returns:
|
Returns:
|
||||||
stats: a Tensor of shape (x.shape[dim],).
|
stats: a Tensor of shape (x.shape[dim],).
|
||||||
count: an integer saying how many items were counted in each element
|
count: an integer saying how many items were counted in each element
|
||||||
@ -94,7 +93,7 @@ def get_tensor_stats(
|
|||||||
x = torch.min(x, dim=dim)[0]
|
x = torch.min(x, dim=dim)[0]
|
||||||
else:
|
else:
|
||||||
x = torch.sum(x, dim=sum_dims)
|
x = torch.sum(x, dim=sum_dims)
|
||||||
x = x.flatten()
|
x = x.flatten().clone()
|
||||||
return x, count
|
return x, count
|
||||||
|
|
||||||
|
|
||||||
@ -177,9 +176,11 @@ class TensorDiagnostic(object):
|
|||||||
if s.tensor.shape == stats.shape:
|
if s.tensor.shape == stats.shape:
|
||||||
if stats_type == "max":
|
if stats_type == "max":
|
||||||
s.tensor = torch.maximum(s.tensor, stats)
|
s.tensor = torch.maximum(s.tensor, stats)
|
||||||
|
|
||||||
elif stats_type == "min":
|
elif stats_type == "min":
|
||||||
s.tensor = torch.minimum(s.tensor, stats)
|
s.tensor = torch.minimum(s.tensor, stats)
|
||||||
else:
|
else:
|
||||||
|
assert stats_type != "max"
|
||||||
s.tensor += stats
|
s.tensor += stats
|
||||||
s.count += count
|
s.count += count
|
||||||
done = True
|
done = True
|
||||||
@ -204,7 +205,7 @@ class TensorDiagnostic(object):
|
|||||||
return
|
return
|
||||||
for dim, this_dim_stats in enumerate(self.stats):
|
for dim, this_dim_stats in enumerate(self.stats):
|
||||||
for stats_type, stats_list in this_dim_stats.items():
|
for stats_type, stats_list in this_dim_stats.items():
|
||||||
# stats_type could be "rms", "value", "abs", "eigs", "positive".
|
# stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max".
|
||||||
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
|
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
|
||||||
# shape of the stats), or None
|
# shape of the stats), or None
|
||||||
if stats_list is None:
|
if stats_list is None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user