Replace nn.Linear with ScaledLinear in simple joiner

This commit is contained in:
Daniel Povey 2022-03-31 12:18:31 +08:00
parent 9a0c2e7fee
commit f75d40c725

View File

@ -63,7 +63,7 @@ class Transducer(nn.Module):
# could perhaps separate this into 2 linear projections, one
# for lm and one for am.
self.simple_joiner = nn.Linear(embedding_dim, vocab_size)
self.simple_joiner = ScaledLinear(embedding_dim, vocab_size)
def forward(
self,