Adding activation balancers after simple_am_prob and simple_lm_prob

This commit is contained in:
Daniel Povey 2022-10-22 19:37:35 +08:00
parent 11886dc4f6
commit 1908123af9

View File

@ -21,6 +21,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
from scaling import ActivationBalancer
class Transducer(nn.Module):
@ -67,6 +68,16 @@ class Transducer(nn.Module):
)
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
self.lm_balancer = ActivationBalancer(vocab_size,
channel_dim=-1,
min_positive=0, max_positive=1,
min_abs=0, max_abs=50.0)
self.am_balancer = ActivationBalancer(vocab_size,
channel_dim=-1,
min_positive=0, max_positive=1,
min_abs=0, max_abs=50.0)
def forward(
self,
x: torch.Tensor,
@ -137,8 +148,8 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
lm = self.lm_balancer(self.simple_lm_proj(decoder_out))
am = self.am_balancer(self.simple_am_proj(encoder_out))
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(