minor fixes

This commit is contained in:
JinZr 2023-07-25 16:05:57 +08:00
parent 90cb518398
commit 3b4fa4863f

View File

@ -265,7 +265,7 @@ class AsrModel(nn.Module):
# project_input=False since we applied the decoder's input projections # project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner( logits = self.joiner(
am_pruned, lm_pruned, encoder_out_lens, project_input=False am_pruned, lm_pruned, encoder_out_lens, apply_attn=True, project_input=False
) )
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):