remove redundant test lines

This commit is contained in:
marcoyang 2022-11-02 18:17:05 +08:00
parent fb45b95c90
commit b62fd917ae

View File

@ -18,7 +18,6 @@ import logging
import torch
import torch.nn.functional as F
import k2
from icefall.utils import add_eos, add_sos, make_pad_mask
@ -122,9 +121,6 @@ class RnnLmModel(torch.nn.Module):
return nll_loss
def get_init_states(self, sos):
p = next(self.parameters())
def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
device = next(self.parameters()).device
batch_size = len(token_lens)
@ -165,8 +161,12 @@ class RnnLmModel(torch.nn.Module):
if state:
h, c = state
else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
h = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
).to(device)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
).to(device)
embedding = self.input_embedding(tokens)
rnn_out, states = self.rnn(embedding, (h, c))
@ -174,13 +174,19 @@ class RnnLmModel(torch.nn.Module):
return logits[:, 0].log_softmax(-1), states
def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None):
def forward_with_state(
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
):
batch_size = len(token_lens)
if state:
h, c = state
else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
h = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
)
device = next(self.parameters()).device
@ -206,39 +212,3 @@ class RnnLmModel(torch.nn.Module):
logits = self.output_linear(rnn_out)
return logits, states
if __name__=="__main__":
LM = RnnLmModel(500, 2048, 2048, 3, True)
h0 = torch.zeros(3, 1, 2048)
c0 = torch.zeros(3, 1, 2048)
seq = [[0,1,2,3]]
seq_lens = [len(s) for s in seq]
tokens = k2.RaggedTensor(seq)
output1, state = LM.forward_with_state(
tokens,
seq_lens,
1,
1,
0,
state=(h0,c0)
)
seq = [[0,1,2,3,4]]
seq_lens = [len(s) for s in seq]
tokens = k2.RaggedTensor(seq)
output2, _ = LM.forward_with_state(
tokens,
seq_lens,
1,
1,
0,
state=(h0,c0)
)
seq = [[4]]
seq_lens = [len(s) for s in seq]
output3 = LM.score_token(seq, seq_lens, state)
print("Finished")