From 0e325c8782c8b9178cf0f2b030e49ae64f2b091d Mon Sep 17 00:00:00 2001 From: huangruizhe Date: Wed, 7 Dec 2022 02:43:26 -0500 Subject: [PATCH] Fixed rnn_lm model.py (#738) --- icefall/rnn_lm/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 9eef88840..3598a4857 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( device ) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( device ) @@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) device = next(self.parameters()).device