mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove unused code LearnedScale.
This commit is contained in:
parent
cf450908c6
commit
00841f0f49
@ -29,7 +29,6 @@ from s import (
|
||||
DoubleSwish,
|
||||
ScaledConv1d,
|
||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||
LearnedScale,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@ -326,30 +326,6 @@ def ScaledConv1d(*args,
|
||||
return ans
|
||||
|
||||
|
||||
class LearnedScale(torch.nn.Module):
|
||||
"""
|
||||
Module that learns a scale dependent on some kind of mask that is typically going to be 0 or 1
|
||||
in training. The scale will be 1.0 if the mask is 1.0, but may be a different (learned) value
|
||||
if the mask value is not 1.0.
|
||||
|
||||
The idea is that if we have some kind of feature mask that would always be 1.0 in
|
||||
test mode but might sometimes be 0.0 in training mode, we might want the multiply
|
||||
the remaining features by a value dependent on this mask.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LearnedScale, self).__init__()
|
||||
self.alpha = nn.Parameter(torch.tensor(0.0))
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
mask: Tensor):
|
||||
"""
|
||||
Mask should either be a number (probably 1.0) or a tensors that broadcasts with x.
|
||||
"""
|
||||
if self.training and mask is 1.0:
|
||||
return x
|
||||
return x * (1.0 + self.alpha * (1.0 - mask))
|
||||
|
||||
|
||||
class ActivationBalancer(torch.nn.Module):
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user