Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Adjust joiner and simple_lm/simple_am projections to account for larger activation dims
commit a8282bb6d7 (parent 03e1f7dc01)
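The patch swaps plain nn.Linear projections for ScaledLinear(..., initial_scale=0.25) in the joiner and in the simple LM/AM projections. ScaledLinear is defined in the recipe's scaling.py; the sketch below is only a rough approximation of that helper, assuming it simply shrinks the initial parameters of a standard nn.Linear by initial_scale so these projections start out with small outputs when the activation dimensions are large.

# Hedged sketch of a ScaledLinear-style helper; the actual implementation in
# icefall's scaling.py may differ (e.g. in how the bias is initialized).
import torch
import torch.nn as nn


def scaled_linear(in_features: int, out_features: int,
                  bias: bool = True, initial_scale: float = 1.0) -> nn.Linear:
    """Return an nn.Linear whose initial weights (and bias) are scaled down."""
    proj = nn.Linear(in_features, out_features, bias=bias)
    with torch.no_grad():
        proj.weight.mul_(initial_scale)      # shrink the default initialization
        if proj.bias is not None:
            proj.bias.mul_(initial_scale)
    return proj
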
Changes to the Joiner module:

@@ -16,6 +16,9 @@
 import torch
 import torch.nn as nn
+from scaling import (
+    ScaledLinear
+)


 class Joiner(nn.Module):

@@ -28,8 +31,8 @@ class Joiner(nn.Module):
     ):
         super().__init__()

-        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
-        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
+        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
+        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)

     def forward(
Changes to the Transducer model:

@@ -22,7 +22,11 @@ import random
 from encoder_interface import EncoderInterface

 from icefall.utils import add_sos
-from scaling import penalize_abs_values_gt
+from scaling import (
+    penalize_abs_values_gt,
+    ScaledLinear
+)


 class Transducer(nn.Module):

@@ -64,10 +68,13 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

-        self.simple_am_proj = nn.Linear(
-            encoder_dim, vocab_size,
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_scale=0.25,
         )
-        self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
+        self.simple_lm_proj = ScaledLinear(
+            decoder_dim, vocab_size, initial_scale=0.25,
+        )


     def forward(
         self,
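For a quick feel of what initial_scale=0.25 buys, the usage sketch below reuses the imports and the scaled_linear helper from the sketch above, with made-up dimensions (the encoder_dim, decoder_dim, joiner_dim, and vocab_size values are placeholders, not taken from the recipe). The projections created this way start with parameters at a quarter of the PyTorch default magnitude, so their outputs stay small early in training even when the input dimensions are large.

# Usage sketch with placeholder dimensions (not values from the recipe).
encoder_dim, decoder_dim, joiner_dim, vocab_size = 512, 512, 512, 500

encoder_proj = scaled_linear(encoder_dim, joiner_dim, initial_scale=0.25)
decoder_proj = scaled_linear(decoder_dim, joiner_dim, initial_scale=0.25)
simple_am_proj = scaled_linear(encoder_dim, vocab_size, initial_scale=0.25)
simple_lm_proj = scaled_linear(decoder_dim, vocab_size, initial_scale=0.25)

# Compare the initial output magnitude against a plain nn.Linear baseline.
x = torch.randn(8, encoder_dim)
baseline = nn.Linear(encoder_dim, joiner_dim)
print(encoder_proj(x).abs().mean().item())   # roughly 0.25x the baseline below
print(baseline(x).abs().mean().item())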