mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fixed rnn_lm model.py (#738)
This commit is contained in:
parent
10472e7ffc
commit
0e325c8782
@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module):
|
||||
if state:
|
||||
h, c = state
|
||||
else:
|
||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
|
||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
|
||||
device
|
||||
)
|
||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
|
||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
|
||||
device
|
||||
)
|
||||
|
||||
@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module):
|
||||
if state:
|
||||
h, c = state
|
||||
else:
|
||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
|
||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
|
||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
|
||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user