Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Use penalize_abs_values_gt, not ActivationBalancer, on the outputs of simple_lm_proj and simple_am_proj.
This commit is contained in:
parent
7a55cac346
commit
466176eeff
@ -18,10 +18,11 @@
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import random
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from scaling import ActivationBalancer
|
||||
from scaling import penalize_abs_values_gt
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -68,18 +69,6 @@ class Transducer(nn.Module):
|
||||
)
|
||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||
|
||||
|
||||
self.lm_balancer = ActivationBalancer(vocab_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0, max_positive=1,
|
||||
min_abs=0, max_abs=50.0,
|
||||
max_factor=0.1)
|
||||
self.am_balancer = ActivationBalancer(vocab_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0, max_positive=1,
|
||||
min_abs=0, max_abs=30.0,
|
||||
max_factor=0.1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -150,8 +139,13 @@ class Transducer(nn.Module):
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.lm_balancer(self.simple_lm_proj(decoder_out))
|
||||
am = self.am_balancer(self.simple_am_proj(encoder_out))
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
if self.training and random.random() < 0.25:
|
||||
lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
if self.training and random.random() < 0.25:
|
||||
am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user