Fix torch.nn.Embedding error for torch below 1.8.0 (#198)

This commit is contained in:
Wei Kang 2022-02-06 21:59:54 +08:00 committed by GitHub
parent 5ae80dfca7
commit 35ecd7e562
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 10 additions and 3 deletions

View File

@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
1, 1
)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0

View File

@ -99,6 +99,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out, _ = self.decoder(sos_y_padded)

View File

@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
1, 1
)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0

View File

@ -101,6 +101,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out, _ = self.decoder(sos_y_padded)

View File

@ -48,7 +48,7 @@ def greedy_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
[blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)

View File

@ -93,6 +93,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out = self.decoder(sos_y_padded)