mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Adjust joiner and simple_lm/simple_am projections to account for larger activation dims
This commit is contained in:
parent
03e1f7dc01
commit
a8282bb6d7
@ -16,6 +16,9 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
ScaledLinear
|
||||
)
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
@ -28,8 +31,8 @@ class Joiner(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
|
||||
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
|
||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
|
||||
@ -22,7 +22,11 @@ import random
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from scaling import penalize_abs_values_gt
|
||||
from scaling import (
|
||||
penalize_abs_values_gt,
|
||||
ScaledLinear
|
||||
)
|
||||
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -64,10 +68,13 @@ class Transducer(nn.Module):
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = nn.Linear(
|
||||
encoder_dim, vocab_size,
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim, vocab_size, initial_scale=0.25,
|
||||
)
|
||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim, vocab_size, initial_scale=0.25,
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user