Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-10 06:25:27 +00:00)
Make 2 projections..
This commit is contained in:
parent f75d40c725
commit c67ae0f3a1
@@ -23,7 +23,8 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
         super().__init__()
 
-        self.output_linear = ScaledLinear(input_dim, output_dim)
+        self.output_linear = ScaledLinear(input_dim, output_dim,
+                                          initial_speed=0.5)
 
     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
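For context, the output_linear changed above sits at the end of the joiner, which combines the two streams before projecting to the vocabulary. Below is a minimal, hypothetical sketch of that wiring, assuming (as the comment in Transducer further down puts it) that the joiner adds the encoder (am) and decoder (lm) embeddings; the plain nn.Linear standing in for ScaledLinear and the tanh nonlinearity are illustrative assumptions, not taken from this diff.

import torch
import torch.nn as nn

class SketchJoiner(nn.Module):
    """Illustrative stand-in for the Joiner above, not the real icefall class."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # nn.Linear used here in place of icefall's ScaledLinear.
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor):
        # Add the acoustic (am) and label (lm) embeddings, then project to
        # the vocabulary; the tanh is an assumed nonlinearity for this sketch.
        logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))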
@@ -61,9 +61,15 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
-        # could perhaps separate this into 2 linear projections, one
-        # for lm and one for am.
-        self.simple_joiner = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size)
+        with torch.no_grad():
+            # Initialize the two projections to be the same; this will be
+            # convenient for the real joiner, which adds the encoder
+            # (acoustic-model/am) and decoder (language-model/lm) embeddings.
+            self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
+            self.simple_lm_proj.bias[:] = self.simple_am_proj.bias
+
 
     def forward(
         self,
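The point of the torch.no_grad() block above is that, at initialization, the two new projections behave exactly like the single simple_joiner they replace, and only diverge as training progresses. A small self-contained check of that property, using plain nn.Linear as a stand-in for ScaledLinear and made-up sizes:

import torch
import torch.nn as nn

embedding_dim, vocab_size = 512, 500  # hypothetical sizes, not from the diff
simple_am_proj = nn.Linear(embedding_dim, vocab_size)
simple_lm_proj = nn.Linear(embedding_dim, vocab_size)
with torch.no_grad():
    simple_lm_proj.weight[:] = simple_am_proj.weight
    simple_lm_proj.bias[:] = simple_am_proj.bias

# Identical parameters mean identical outputs at the start of training,
# i.e. the same behaviour as applying one shared projection to both streams.
x = torch.randn(4, embedding_dim)
assert torch.allclose(simple_am_proj(x), simple_lm_proj(x))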
@@ -140,8 +146,8 @@ class Transducer(nn.Module):
         boundary[:, 3] = x_lens
 
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_joiner(decoder_out),
-            am=self.simple_joiner(encoder_out),
+            lm=self.simple_lm_proj(decoder_out),
+            am=self.simple_am_proj(encoder_out),
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
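For reference, the two projections now feed the lm= and am= arguments as separate per-stream vocabulary logits. A hypothetical shape walk-through follows; the sizes and the usual RNN-T layout are assumptions for illustration, not taken from this diff, and nn.Linear again stands in for ScaledLinear.

import torch
import torch.nn as nn

N, T, U, embedding_dim, vocab_size = 2, 100, 20, 512, 500  # made-up sizes
encoder_out = torch.randn(N, T, embedding_dim)      # acoustic (am) stream
decoder_out = torch.randn(N, U + 1, embedding_dim)  # label (lm) stream

simple_am_proj = nn.Linear(embedding_dim, vocab_size)
simple_lm_proj = nn.Linear(embedding_dim, vocab_size)

am = simple_am_proj(encoder_out)  # (N, T, vocab_size), passed as am=...
lm = simple_lm_proj(decoder_out)  # (N, U + 1, vocab_size), passed as lm=...
assert am.shape == (N, T, vocab_size) and lm.shape == (N, U + 1, vocab_size)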