Fix bug with import

This commit is contained in:
Daniel Povey 2022-03-18 16:40:24 +08:00
parent 2dfcd8f117
commit c9f1aeb7d1

View File

@ -459,6 +459,7 @@ class ScaledEmbedding(nn.Module):
self.weight[self.padding_idx].fill_(0) self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
scale = self.scale.exp() scale = self.scale.exp()
if input.numel() < self.num_embeddings: if input.numel() < self.num_embeddings:
return F.embedding( return F.embedding(