remove redundant test lines

This commit is contained in:
marcoyang 2022-11-02 18:17:05 +08:00
parent fb45b95c90
commit b62fd917ae

View File

@ -18,7 +18,6 @@ import logging
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import k2
from icefall.utils import add_eos, add_sos, make_pad_mask from icefall.utils import add_eos, add_sos, make_pad_mask
@ -122,9 +121,6 @@ class RnnLmModel(torch.nn.Module):
return nll_loss return nll_loss
def get_init_states(self, sos):
p = next(self.parameters())
def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
device = next(self.parameters()).device device = next(self.parameters()).device
batch_size = len(token_lens) batch_size = len(token_lens)
@ -154,7 +150,7 @@ class RnnLmModel(torch.nn.Module):
mask[i, token_lens[i], :] = True mask[i, token_lens[i], :] = True
logits = logits[mask].reshape(batch_size, -1) logits = logits[mask].reshape(batch_size, -1)
return logits[:,:].log_softmax(-1), states return logits[:, :].log_softmax(-1), states
def clean_cache(self): def clean_cache(self):
self.cache = {} self.cache = {}
@ -163,24 +159,34 @@ class RnnLmModel(torch.nn.Module):
device = next(self.parameters()).device device = next(self.parameters()).device
batch_size = tokens.size(0) batch_size = tokens.size(0)
if state: if state:
h,c = state h, c = state
else: else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) h = torch.zeros(
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) self.rnn.num_layers, batch_size, self.rnn.input_size
).to(device)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
).to(device)
embedding = self.input_embedding(tokens) embedding = self.input_embedding(tokens)
rnn_out, states = self.rnn(embedding, (h,c)) rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out) logits = self.output_linear(rnn_out)
return logits[:,0].log_softmax(-1), states return logits[:, 0].log_softmax(-1), states
def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None): def forward_with_state(
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
):
batch_size = len(token_lens) batch_size = len(token_lens)
if state: if state:
h,c = state h, c = state
else: else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) h = torch.zeros(
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) self.rnn.num_layers, batch_size, self.rnn.input_size
)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
)
device = next(self.parameters()).device device = next(self.parameters()).device
@ -202,43 +208,7 @@ class RnnLmModel(torch.nn.Module):
embedding = self.input_embedding(x_tokens) embedding = self.input_embedding(x_tokens)
# Note: We use batch_first==True # Note: We use batch_first==True
rnn_out, states = self.rnn(embedding, (h,c)) rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out) logits = self.output_linear(rnn_out)
return logits, states return logits, states
if __name__=="__main__":
LM = RnnLmModel(500, 2048, 2048, 3, True)
h0 = torch.zeros(3, 1, 2048)
c0 = torch.zeros(3, 1, 2048)
seq = [[0,1,2,3]]
seq_lens = [len(s) for s in seq]
tokens = k2.RaggedTensor(seq)
output1, state = LM.forward_with_state(
tokens,
seq_lens,
1,
1,
0,
state=(h0,c0)
)
seq = [[0,1,2,3,4]]
seq_lens = [len(s) for s in seq]
tokens = k2.RaggedTensor(seq)
output2, _ = LM.forward_with_state(
tokens,
seq_lens,
1,
1,
0,
state=(h0,c0)
)
seq = [[4]]
seq_lens = [len(s) for s in seq]
output3 = LM.score_token(seq, seq_lens, state)
print("Finished")