Fixed rnn_lm model.py (#738)

This commit is contained in:
huangruizhe 2022-12-07 02:43:26 -05:00 committed by GitHub
parent 10472e7ffc
commit 0e325c8782
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module):
if state: if state:
h, c = state h, c = state
else: else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
device device
) )
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
device device
) )
@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module):
if state: if state:
h, c = state h, c = state
else: else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
device = next(self.parameters()).device device = next(self.parameters()).device