Make training more efficient; avoid redoing some projections.

Daniel Povey 2022-04-04 14:10:38 +08:00
parent 99e9d6c4b8
commit a5bbcd7b71
2 changed files with 18 additions and 4 deletions


@@ -35,7 +35,8 @@ class Joiner(nn.Module):
         self.output_linear = ScaledLinear(joiner_dim, vocab_size)
 
     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
+        project_input: bool = True
     ) -> torch.Tensor:
         """
         Args:
@@ -43,13 +44,20 @@ class Joiner(nn.Module):
             Output from the encoder. Its shape is (N, T, s_range, C).
           decoder_out:
             Output from the decoder. Its shape is (N, T, s_range, C).
+          project_input:
+            If true, apply input projections encoder_proj and decoder_proj.
+            If this is false, it is the user's responsibility to do this
+            manually.
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
 
-        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
+        if project_input:
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
+        else:
+            logit = encoder_out + decoder_out
 
         logit = self.output_linear(torch.tanh(logit))

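In words: Joiner.forward gains a project_input flag, so a caller that has already mapped its inputs to joiner_dim can skip encoder_proj and decoder_proj. Below is a self-contained sketch of the resulting interface, not the recipe's actual code: plain nn.Linear stands in for icefall's ScaledLinear, and all sizes are made up.

import torch
import torch.nn as nn

class Joiner(nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int,
                 joiner_dim: int, vocab_size: int):
        super().__init__()
        # nn.Linear stands in for ScaledLinear (a simplification for this sketch).
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        assert encoder_out.ndim == decoder_out.ndim == 4
        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
        if project_input:
            # Inputs still carry encoder_dim / decoder_dim channels.
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            # Caller has already projected both inputs to joiner_dim.
            logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))

# Both call styles give the same result; sizes here are illustrative.
joiner = Joiner(encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500)
enc = torch.randn(2, 10, 5, 512)  # (N, T, s_range, encoder_dim)
dec = torch.randn(2, 10, 5, 512)  # (N, T, s_range, decoder_dim)
out1 = joiner(enc, dec)
out2 = joiner(joiner.encoder_proj(enc), joiner.decoder_proj(dec),
              project_input=False)
assert torch.allclose(out1, out2, atol=1e-6)

Defaulting project_input to True keeps every existing call site working unchanged; only the pruned-loss path below opts out.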

@@ -164,11 +164,17 @@ class Transducer(nn.Module):
         # am_pruned : [B, T, prune_range, encoder_dim]
         # lm_pruned : [B, T, prune_range, decoder_dim]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=encoder_out, lm=decoder_out, ranges=ranges
+            am=self.joiner.encoder_proj(encoder_out),
+            lm=self.joiner.decoder_proj(decoder_out),
+            ranges=ranges,
         )
 
         # logits : [B, T, prune_range, vocab_size]
-        logits = self.joiner(am_pruned, lm_pruned)
+        # project_input=False since we already applied the joiner's input
+        # projections (encoder_proj and decoder_proj) before do_rnnt_pruning;
+        # this is an optimization for speed.
+        logits = self.joiner(am_pruned, lm_pruned,
+                             project_input=False)
 
         pruned_loss = k2.rnnt_loss_pruned(
             logits=logits,
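Why this is faster: before this commit, encoder_proj and decoder_proj ran on am_pruned and lm_pruned, i.e. once per (frame, pruned-symbol) pair; now each encoder frame and each decoder position is projected exactly once, and do_rnnt_pruning just gathers rows that are already in joiner_dim. A back-of-envelope count, with all sizes hypothetical:

# Rows pushed through the input projections, before vs. after this commit.
# All sizes are made up for illustration; S is the max symbol-sequence length.
B, T, S, prune_range = 32, 500, 100, 5

# Old order (project after pruning): both pruned tensors have shape
# [B, T, prune_range, dim], so each projection sees B*T*prune_range rows.
rows_old = 2 * B * T * prune_range

# New order (project before pruning): encoder_out is [B, T, dim] and
# decoder_out is [B, S + 1, dim], so each row is projected only once.
rows_new = B * T + B * (S + 1)

print(f"rows projected, old order: {rows_old:,}")  # 160,000
print(f"rows projected, new order: {rows_new:,}")  #  19,232
print(f"saving: {rows_old / rows_new:.1f}x")       # ~8.3x

The larger the prune_range, and the longer the utterances relative to their transcripts, the bigger the saving; the joiner's output_linear still runs on the pruned 4-D tensor either way.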