mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Adding activation balancers after simple_am_proj and simple_lm_proj
This commit is contained in:
parent
11886dc4f6
commit
1908123af9
@ -21,6 +21,7 @@ import torch.nn as nn
|
|||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
from scaling import ActivationBalancer
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
@ -67,6 +68,16 @@ class Transducer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||||
|
|
||||||
|
|
||||||
|
self.lm_balancer = ActivationBalancer(vocab_size,
|
||||||
|
channel_dim=-1,
|
||||||
|
min_positive=0, max_positive=1,
|
||||||
|
min_abs=0, max_abs=50.0)
|
||||||
|
self.am_balancer = ActivationBalancer(vocab_size,
|
||||||
|
channel_dim=-1,
|
||||||
|
min_positive=0, max_positive=1,
|
||||||
|
min_abs=0, max_abs=50.0)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -137,8 +148,8 @@ class Transducer(nn.Module):
|
|||||||
boundary[:, 2] = y_lens
|
boundary[:, 2] = y_lens
|
||||||
boundary[:, 3] = x_lens
|
boundary[:, 3] = x_lens
|
||||||
|
|
||||||
lm = self.simple_lm_proj(decoder_out)
|
lm = self.lm_balancer(self.simple_lm_proj(decoder_out))
|
||||||
am = self.simple_am_proj(encoder_out)
|
am = self.am_balancer(self.simple_am_proj(encoder_out))
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user