Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-10 06:25:27 +00:00)
Make 2 projections.

commit c67ae0f3a1 (parent f75d40c725)

@@ -23,7 +23,8 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
         super().__init__()
 
-        self.output_linear = ScaledLinear(input_dim, output_dim)
+        self.output_linear = ScaledLinear(input_dim, output_dim,
+                                          initial_speed=0.5)
 
     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
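Background on the initial_speed argument (not part of the diff): ScaledLinear is icefall's variant of nn.Linear that keeps a learned log-scale next to the raw weight, so the effective weight is weight * exp(weight_scale). The sketch below is a guess at the reparameterization idea, assuming initial_speed trades raw-weight magnitude against the log-scale so the effective initialization is unchanged while early updates shrink; ReparamLinear and all sizes here are hypothetical, not icefall's actual code.

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ReparamLinear(nn.Module):
        # Hypothetical sketch, not icefall's ScaledLinear.  Effective
        # weight = weight * exp(weight_scale); initial_speed < 1 inflates
        # the raw init and deflates the log-scale to compensate, leaving
        # the effective init unchanged.
        def __init__(self, in_features: int, out_features: int,
                     initial_speed: float = 1.0):
            super().__init__()
            raw_std = 0.1 / initial_speed
            self.weight = nn.Parameter(
                torch.randn(out_features, in_features) * raw_std)
            target_std = in_features ** -0.5  # effective init std we want
            self.weight_scale = nn.Parameter(
                torch.tensor(math.log(target_std / raw_std)))
            self.bias = nn.Parameter(torch.zeros(out_features))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.linear(x, self.weight * self.weight_scale.exp(),
                            self.bias)

    layer = ReparamLinear(512, 500, initial_speed=0.5)

Under plain SGD, one step moves the effective weight in proportion to exp(2 * weight_scale), so halving initial_speed roughly quarters how fast the layer moves at the start of training, consistent with the apparent intent of slowing the joiner's output projection down.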
@@ -61,9 +61,15 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
-        # could perhaps separate this into 2 linear projections, one
-        # for lm and one for am.
-        self.simple_joiner = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size)
+        with torch.no_grad():
+            # Initialize the two projections to be the same; this will be
+            # convenient for the real joiner, which adds the encoder
+            # (acoustic-model/am) and decoder (language-model/lm) embeddings.
+            self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
+            self.simple_lm_proj.bias[:] = self.simple_am_proj.bias
+
 
     def forward(
         self,
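For illustration (again, not part of the commit), a quick check of what the tied initialization buys, with plain nn.Linear standing in for ScaledLinear and made-up sizes: at init the two projections compute identical logits, so the am/lm split starts out equivalent to projecting both streams through the old single simple_joiner.

    import torch
    import torch.nn as nn

    embedding_dim, vocab_size = 512, 500  # made-up sizes

    simple_am_proj = nn.Linear(embedding_dim, vocab_size)
    simple_lm_proj = nn.Linear(embedding_dim, vocab_size)
    with torch.no_grad():
        # Same copy-at-init trick as in the diff above.
        simple_lm_proj.weight[:] = simple_am_proj.weight
        simple_lm_proj.bias[:] = simple_am_proj.bias

    e = torch.randn(4, embedding_dim)
    assert torch.allclose(simple_am_proj(e), simple_lm_proj(e))

The two projections then diverge during training as the encoder (am) and decoder (lm) streams pull them in different directions.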
@@ -140,8 +146,8 @@ class Transducer(nn.Module):
         boundary[:, 3] = x_lens
 
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_joiner(decoder_out),
-            am=self.simple_joiner(encoder_out),
+            lm=self.simple_lm_proj(decoder_out),
+            am=self.simple_am_proj(encoder_out),
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
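The hunk is cut off mid-call; below is a sketch of the full call on dummy tensors, with shapes following k2's convention (am: (B, T, V) from the encoder, lm: (B, S+1, V) from the decoder). return_grad=True is implied by the (px_grad, py_grad) unpacking; the am_only_scale value and the boundary=boundary argument are assumptions based on the surrounding code.

    import torch
    import k2

    B, T, S, V = 2, 20, 5, 500  # batch, frames, label length, vocab (made up)
    blank_id = 0

    am = torch.randn(B, T, V)      # stands in for simple_am_proj(encoder_out)
    lm = torch.randn(B, S + 1, V)  # stands in for simple_lm_proj(decoder_out)
    symbols = torch.randint(1, V, (B, S), dtype=torch.int64)  # y_padded

    # boundary rows are [0, 0, y_len, x_len]; the hunk shows the x_lens
    # column being filled just before the call.
    boundary = torch.zeros(B, 4, dtype=torch.int64)
    boundary[:, 2] = S
    boundary[:, 3] = T

    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=blank_id,
        lm_only_scale=0.25,  # lm_scale in the model; value assumed
        am_only_scale=0.0,   # assumed
        boundary=boundary,
        return_grad=True,
    )

px_grad and py_grad are the loss's gradients with respect to the lattice transition scores; pruned RNN-T uses them to pick a narrow band of symbol positions per frame on which to evaluate the full joiner.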