diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 9eef88840..3598a4857 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( device ) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( device ) @@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) device = next(self.parameters()).device