From b62fd917ae54fb0305a3f4fac931d850bfe231c1 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 18:17:05 +0800 Subject: [PATCH] remove redundant test lines --- icefall/rnn_lm/model.py | 88 ++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 59 deletions(-) diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 2552f65a6..a6144727a 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -18,7 +18,6 @@ import logging import torch import torch.nn.functional as F -import k2 from icefall.utils import add_eos, add_sos, make_pad_mask @@ -121,9 +120,6 @@ class RnnLmModel(torch.nn.Module): nll_loss = nll_loss.reshape(batch_size, -1) return nll_loss - - def get_init_states(self, sos): - p = next(self.parameters()) def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): device = next(self.parameters()).device @@ -153,35 +149,45 @@ class RnnLmModel(torch.nn.Module): for i in range(batch_size): mask[i, token_lens[i], :] = True logits = logits[mask].reshape(batch_size, -1) - - return logits[:,:].log_softmax(-1), states - + + return logits[:, :].log_softmax(-1), states + def clean_cache(self): self.cache = {} - + def score_token(self, tokens: torch.Tensor, state=None): device = next(self.parameters()).device batch_size = tokens.size(0) if state: - h,c = state + h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) - - embedding = self.input_embedding(tokens) - rnn_out, states = self.rnn(embedding, (h,c)) - logits = self.output_linear(rnn_out) - - return logits[:,0].log_softmax(-1), states + h = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) - def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None): + embedding = self.input_embedding(tokens) + rnn_out, states = self.rnn(embedding, (h, c)) + logits = self.output_linear(rnn_out) + + return logits[:, 0].log_softmax(-1), states + + def forward_with_state( + self, tokens, token_lens, sos_id, eos_id, blank_id, state=None + ): batch_size = len(token_lens) if state: - h,c = state + h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - + h = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) + device = next(self.parameters()).device sos_tokens = add_sos(tokens, sos_id) @@ -202,43 +208,7 @@ class RnnLmModel(torch.nn.Module): embedding = self.input_embedding(x_tokens) # Note: We use batch_first==True - rnn_out, states = self.rnn(embedding, (h,c)) + rnn_out, states = self.rnn(embedding, (h, c)) logits = self.output_linear(rnn_out) return logits, states - -if __name__=="__main__": - LM = RnnLmModel(500, 2048, 2048, 3, True) - h0 = torch.zeros(3, 1, 2048) - c0 = torch.zeros(3, 1, 2048) - seq = [[0,1,2,3]] - seq_lens = [len(s) for s in seq] - tokens = k2.RaggedTensor(seq) - output1, state = LM.forward_with_state( - tokens, - seq_lens, - 1, - 1, - 0, - state=(h0,c0) - ) - seq = [[0,1,2,3,4]] - seq_lens = [len(s) for s in seq] - tokens = k2.RaggedTensor(seq) - output2, _ = LM.forward_with_state( - tokens, - seq_lens, - 1, - 1, - 0, - state=(h0,c0) - ) - - seq = [[4]] - seq_lens = [len(s) for s in seq] - output3 = LM.score_token(seq, seq_lens, state) - - print("Finished") - - -