mirror of https://github.com/k2-fsa/icefall.git
Make training more efficient, avoid redoing some projections.
commit a5bbcd7b71
parent 99e9d6c4b8
@@ -35,7 +35,8 @@ class Joiner(nn.Module):
         self.output_linear = ScaledLinear(joiner_dim, vocab_size)
 
     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
+        project_input: bool = True
     ) -> torch.Tensor:
         """
         Args:
@@ -43,13 +44,20 @@ class Joiner(nn.Module):
             Output from the encoder. Its shape is (N, T, s_range, C).
           decoder_out:
             Output from the decoder. Its shape is (N, T, s_range, C).
+          project_input:
+            If true, apply input projections encoder_proj and decoder_proj.
+            If this is false, it is the user's responsibility to do this
+            manually.
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
 
-        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
+        if project_input:
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
+        else:
+            logit = encoder_out + decoder_out
 
         logit = self.output_linear(torch.tanh(logit))
 
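For context, here is a minimal, self-contained sketch of how the joiner reads after this change. It is an illustration only: plain nn.Linear stands in for icefall's ScaledLinear so the snippet runs on its own, and the constructor signature (encoder_dim, decoder_dim, joiner_dim, vocab_size) is an assumption rather than something shown in this diff.

import torch
import torch.nn as nn


class JoinerSketch(nn.Module):
    """Illustrative stand-in for the icefall Joiner after this commit."""

    def __init__(
        self, encoder_dim: int, decoder_dim: int, joiner_dim: int, vocab_size: int
    ) -> None:
        super().__init__()
        # icefall uses ScaledLinear here; nn.Linear keeps the sketch dependency-free.
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        assert encoder_out.ndim == decoder_out.ndim == 4
        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]

        if project_input:
            # Old behaviour: project both inputs to joiner_dim inside the joiner.
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            # New fast path: the caller has already applied encoder_proj / decoder_proj.
            logit = encoder_out + decoder_out

        return self.output_linear(torch.tanh(logit))

With project_input=False the caller passes tensors that already live in joiner_dim; that is exactly what the Transducer change in the next hunk does by applying the projections once, before pruning.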
@@ -164,11 +164,17 @@ class Transducer(nn.Module):
         # am_pruned : [B, T, prune_range, encoder_dim]
         # lm_pruned : [B, T, prune_range, decoder_dim]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=encoder_out, lm=decoder_out, ranges=ranges
+            am=self.joiner.encoder_proj(encoder_out),
+            lm=self.joiner.decoder_proj(decoder_out),
+            ranges=ranges
         )
 
         # logits : [B, T, prune_range, vocab_size]
-        logits = self.joiner(am_pruned, lm_pruned)
+
+        # project_input=False since we applied the decoder's input projections
+        # prior to do_rnnt_pruning (this is an optimization for speed).
+        logits = self.joiner(am_pruned, lm_pruned,
+                             project_input=False)
 
         pruned_loss = k2.rnnt_loss_pruned(
             logits=logits,
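The point of this hunk is that k2.do_rnnt_pruning only gathers existing frames, so projecting encoder_out and decoder_out before pruning yields the same logits while pushing roughly prune_range times fewer frames through each projection. The snippet below is a rough, self-contained illustration of that argument for the encoder side: a toy expand stands in for the pruning gather, nn.Linear stands in for the ScaledLinear projection, and all sizes are made up.

import torch
import torch.nn as nn

# Made-up sizes for illustration only.
B, T, s_range = 8, 100, 5
encoder_dim, joiner_dim = 512, 512

encoder_proj = nn.Linear(encoder_dim, joiner_dim)
encoder_out = torch.randn(B, T, encoder_dim)


def toy_prune(x: torch.Tensor, s_range: int) -> torch.Tensor:
    # Stand-in for the gather done by k2.do_rnnt_pruning: it only copies
    # existing frames, so it commutes with a per-frame linear projection.
    return x.unsqueeze(2).expand(-1, -1, s_range, -1).contiguous()


# Old order: prune first, then project inside the joiner.
# encoder_proj is applied to B * T * s_range frames.
projected_old = encoder_proj(toy_prune(encoder_out, s_range))

# New order (this commit): project once, then prune.
# encoder_proj is applied to only B * T frames, a factor of s_range fewer.
projected_new = toy_prune(encoder_proj(encoder_out), s_range)

# Same result either way, since pruning never mixes or alters frame contents.
assert torch.allclose(projected_old, projected_new, atol=1e-5)

The decoder projection benefits in the same way, which is why both projections are moved in front of do_rnnt_pruning and the joiner is then called with project_input=False.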