Use penalize_abs_values_gt, not ActivationBalancer.
This commit is contained in:
parent
7a55cac346
commit
466176eeff
@ -18,10 +18,11 @@
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import random
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
from scaling import ActivationBalancer
|
from scaling import penalize_abs_values_gt
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
@ -68,18 +69,6 @@ class Transducer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||||
|
|
||||||
|
|
||||||
self.lm_balancer = ActivationBalancer(vocab_size,
|
|
||||||
channel_dim=-1,
|
|
||||||
min_positive=0, max_positive=1,
|
|
||||||
min_abs=0, max_abs=50.0,
|
|
||||||
max_factor=0.1)
|
|
||||||
self.am_balancer = ActivationBalancer(vocab_size,
|
|
||||||
channel_dim=-1,
|
|
||||||
min_positive=0, max_positive=1,
|
|
||||||
min_abs=0, max_abs=30.0,
|
|
||||||
max_factor=0.1)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -150,8 +139,13 @@ class Transducer(nn.Module):
|
|||||||
boundary[:, 2] = y_lens
|
boundary[:, 2] = y_lens
|
||||||
boundary[:, 3] = x_lens
|
boundary[:, 3] = x_lens
|
||||||
|
|
||||||
lm = self.lm_balancer(self.simple_lm_proj(decoder_out))
|
lm = self.simple_lm_proj(decoder_out)
|
||||||
am = self.am_balancer(self.simple_am_proj(encoder_out))
|
am = self.simple_am_proj(encoder_out)
|
||||||
|
|
||||||
|
if self.training and random.random() < 0.25:
|
||||||
|
lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||||
|
if self.training and random.random() < 0.25:
|
||||||
|
am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user