Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
Fix torch.nn.Embedding error for torch below 1.8.0 (#198)
This commit is contained in:
parent 5ae80dfca7
commit 35ecd7e562
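The error this commit works around: torch.nn.Embedding in PyTorch releases before 1.8.0 only accepts int64 (LongTensor) indices, so an int32 index tensor reaching the decoder's embedding layer fails with a RuntimeError. Each hunk below therefore either constructs the decoder input with an explicit dtype=torch.int64 or casts it with .to(torch.int64). A minimal sketch of the failure mode, with made-up sizes and names not taken from the icefall sources:

import torch
import torch.nn as nn

# Illustrative embedding; the sizes are placeholders for this sketch.
embedding = nn.Embedding(num_embeddings=500, embedding_dim=512)

blank_id = 0
idx_int32 = torch.tensor([blank_id], dtype=torch.int32).reshape(1, 1)
idx_int64 = torch.tensor([blank_id], dtype=torch.int64).reshape(1, 1)

try:
    # On torch < 1.8.0 this lookup raises a RuntimeError because the
    # indices are not int64; newer releases also accept int32.
    embedding(idx_int32)
except RuntimeError as exc:
    print("int32 indices rejected:", exc)

out = embedding(idx_int64)  # int64 indices work on every torch version
print(out.shape)  # torch.Size([1, 1, 512])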
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -99,6 +99,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out, _ = self.decoder(sos_y_padded)
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -101,6 +101,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=sos_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out, _ = self.decoder(sos_y_padded)
@@ -48,7 +48,7 @@ def greedy_search(
     device = model.device
 
     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
     ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -93,6 +93,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out = self.decoder(sos_y_padded)
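A note on the Transducer.forward hunks: sos_y.pad(mode="constant", padding_value=blank_id) pads a k2 ragged tensor, and k2 ragged tensors built from integer labels typically carry int32 values, which is presumably why the explicit cast to int64 is added before the padded labels reach the decoder's embedding. A hedged sketch of that pattern, assuming k2 is installed and using hypothetical label values:

import k2
import torch

# Hypothetical labels; in icefall, y holds the transcript token ids.
y = k2.RaggedTensor([[3, 5, 7], [9, 2]])
y_padded = y.pad(mode="constant", padding_value=0)
print(y_padded.dtype)  # expected: torch.int32, which nn.Embedding rejects on torch < 1.8.0
y_padded = y_padded.to(torch.int64)  # explicit cast, mirroring the hunks above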