mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fixed rnn_lm model.py (#738)
This commit is contained in:
parent
10472e7ffc
commit
0e325c8782
@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module):
|
|||||||
if state:
|
if state:
|
||||||
h, c = state
|
h, c = state
|
||||||
else:
|
else:
|
||||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
|
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
|
||||||
device
|
device
|
||||||
)
|
)
|
||||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
|
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
|
||||||
device
|
device
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module):
|
|||||||
if state:
|
if state:
|
||||||
h, c = state
|
h, c = state
|
||||||
else:
|
else:
|
||||||
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
|
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
|
||||||
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
|
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
|
||||||
|
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user