Make 2 projections.

Daniel Povey 2022-03-31 13:02:40 +08:00
parent f75d40c725
commit c67ae0f3a1
2 changed files with 13 additions and 6 deletions


@@ -23,7 +23,8 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
         super().__init__()
-        self.output_linear = ScaledLinear(input_dim, output_dim)
+        self.output_linear = ScaledLinear(input_dim, output_dim,
+                                          initial_speed=0.5)

     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
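The `initial_speed=0.5` argument is the substance of this first change. As a rough illustration of what it does (a minimal sketch, not icefall's actual `ScaledLinear`; the class name and the base std of 0.1 below are assumptions): the layer pairs its weight matrix with a learnable scale, and `initial_speed < 1` enlarges the initial weights while shrinking the scale by the same factor, so the initial output distribution is unchanged but gradient steps on the weights move the output less, slowing this layer's learning early in training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchScaledLinear(nn.Module):
    """Sketch of a linear layer with a learnable scalar scale.

    The effective weight is (weight * scale).  initial_speed < 1 makes
    `weight` larger and `scale` smaller by the same factor: the product
    is unchanged at init, but gradients w.r.t. `weight` are multiplied
    by the (now smaller) `scale`, so early learning is slower.
    """

    def __init__(self, in_features: int, out_features: int,
                 initial_speed: float = 1.0):
        super().__init__()
        std = 0.1 / initial_speed  # assumed base std of 0.1
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * std)
        self.bias = nn.Parameter(torch.zeros(out_features))
        # Compensating scale so (weight * scale) has std ~ 1/sqrt(fan_in),
        # independent of initial_speed.
        self.scale = nn.Parameter(torch.tensor(in_features ** -0.5 / std))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight * self.scale, self.bias)
```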


@@ -61,9 +61,15 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

-        # could perhaps separate this into 2 linear projections, one
-        # for lm and one for am.
-        self.simple_joiner = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size)
+
+        with torch.no_grad():
+            # Initialize the two projections to be the same; this will be
+            # convenient for the real joiner, which adds the encoder
+            # (acoustic-model/am) and decoder (language-model/lm) embeddings.
+            self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
+            self.simple_lm_proj.bias[:] = self.simple_am_proj.bias

     def forward(
         self,
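A quick sanity check of the initialization above (a standalone sketch using plain `nn.Linear` in place of `ScaledLinear`; the dimensions are arbitrary): with the lm projection's parameters copied from the am projection, the new pair reproduces the old single `simple_joiner` exactly at initialization, while remaining free to diverge during training.

```python
import torch
import torch.nn as nn

embedding_dim, vocab_size = 512, 500
am_proj = nn.Linear(embedding_dim, vocab_size)
lm_proj = nn.Linear(embedding_dim, vocab_size)
with torch.no_grad():
    lm_proj.weight[:] = am_proj.weight
    lm_proj.bias[:] = am_proj.bias

enc = torch.randn(4, 100, embedding_dim)  # (batch, T, embedding_dim)
dec = torch.randn(4, 21, embedding_dim)   # (batch, S + 1, embedding_dim)

# Old behaviour: one projection applied to both streams.
old_lm, old_am = am_proj(dec), am_proj(enc)
# New behaviour at init: identical outputs, but now trainable separately.
assert torch.equal(lm_proj(dec), old_lm)
assert torch.equal(am_proj(enc), old_am)
```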
@@ -140,8 +146,8 @@ class Transducer(nn.Module):
         boundary[:, 3] = x_lens

         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_joiner(decoder_out),
-            am=self.simple_joiner(encoder_out),
+            lm=self.simple_lm_proj(decoder_out),
+            am=self.simple_am_proj(encoder_out),
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
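For the tensor shapes involved, here is an illustrative sketch of the call as it looks after this change (sizes, `blank_id = 0`, and the `lm_only_scale` value are assumptions, and `return_grad=True`, which makes the call return the `(px_grad, py_grad)` pair, is presumably passed in the truncated remainder of the call):

```python
import torch
import k2  # assumes k2 with rnnt_loss_smoothed is installed

B, T, S, C = 4, 100, 20, 500            # batch, frames, symbol count, vocab_size
lm = torch.randn(B, S + 1, C)           # self.simple_lm_proj(decoder_out)
am = torch.randn(B, T, C)               # self.simple_am_proj(encoder_out)
y_padded = torch.randint(1, C, (B, S))  # padded symbol ids, no blanks
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S                      # per-utterance symbol lengths (y_lens)
boundary[:, 3] = T                      # per-utterance frame lengths (x_lens)

simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=y_padded,
    termination_symbol=0,  # blank_id, assumed 0
    lm_only_scale=0.25,    # stand-in for lm_scale
    boundary=boundary,
    return_grad=True,      # needed to get (px_grad, py_grad)
)
```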