Adjust joiner and simple_lm/simple_am projections to account for larger activation dims

Daniel Povey 2022-12-29 12:52:11 +08:00
parent 03e1f7dc01
commit a8282bb6d7
2 changed files with 16 additions and 6 deletions


@@ -16,6 +16,9 @@
 import torch
 import torch.nn as nn
+from scaling import (
+    ScaledLinear
+)


 class Joiner(nn.Module):
@@ -28,8 +31,8 @@ class Joiner(nn.Module):
     ):
         super().__init__()

-        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
-        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
+        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
+        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)

     def forward(

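For reference, ScaledLinear in the recipe's scaling.py is essentially an nn.Linear whose default initialization is multiplied by initial_scale. A minimal sketch of that behavior (the exact implementation in scaling.py may differ in detail):

import torch
import torch.nn as nn


def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """Construct an nn.Linear, then shrink its initial parameters.

    Multiplying the default initialization by initial_scale (0.25 in
    this commit) makes the projection's initial outputs correspondingly
    smaller, compensating for the larger activation dims feeding it.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        # Scale down the default (Kaiming-uniform) weights.
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            # Re-initialize the bias to a similarly small range.
            nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans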

@@ -22,7 +22,11 @@ import random
 from encoder_interface import EncoderInterface
 from icefall.utils import add_sos
-from scaling import penalize_abs_values_gt
+from scaling import (
+    penalize_abs_values_gt,
+    ScaledLinear
+)


 class Transducer(nn.Module):
@@ -64,10 +68,13 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

-        self.simple_am_proj = nn.Linear(
-            encoder_dim, vocab_size,
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_scale=0.25,
         )
-        self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
+        self.simple_lm_proj = ScaledLinear(
+            decoder_dim, vocab_size, initial_scale=0.25,
+        )

     def forward(
         self,
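As a quick sanity check of the effect, using the ScaledLinear sketch above (the dims here are placeholders for illustration, not the recipe's actual config):

# Illustrative only: 512/500 stand in for the real encoder_dim/vocab_size.
encoder_dim, vocab_size = 512, 500

torch.manual_seed(0)
baseline = nn.Linear(encoder_dim, vocab_size)
scaled = ScaledLinear(encoder_dim, vocab_size, initial_scale=0.25)

# nn.Linear's default weights have std of about 1/sqrt(3 * fan_in);
# the scaled version starts ~4x smaller, so the simple_am/simple_lm
# logits (and the joiner projections) begin training ~4x smaller too.
print(baseline.weight.std().item())  # roughly 0.026 for fan_in=512
print(scaled.weight.std().item())    # roughly 0.0065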