Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
remove redundant test lines
This commit is contained in:
parent fb45b95c90
commit b62fd917ae
@@ -18,7 +18,6 @@ import logging

import torch
import torch.nn.functional as F
import k2

from icefall.utils import add_eos, add_sos, make_pad_mask
@@ -122,9 +121,6 @@ class RnnLmModel(torch.nn.Module):

        return nll_loss

    def get_init_states(self, sos):
        p = next(self.parameters())

    def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
        device = next(self.parameters()).device
        batch_size = len(token_lens)
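For context, predict_batch consumes a ragged batch of token ids. A minimal sketch of the usual preparation step, using the add_sos/add_eos helpers imported above; the k2.RaggedTensor.pad call and the exact padding value are assumptions, not this file's code:

import k2
import torch
from icefall.utils import add_eos, add_sos

# Two variable-length utterances as one ragged batch.
tokens = k2.RaggedTensor([[0, 1, 2, 3], [5, 6]])

sos_tokens = add_sos(tokens, sos_id=1)  # prepend <sos> to every row
tokens_eos = add_eos(tokens, eos_id=1)  # append <eos> to every row

# Pad the ragged rows into rectangular tensors; blank_id (0 here) fills the gaps.
x = sos_tokens.pad(mode="constant", padding_value=0).to(torch.int64)
y = tokens_eos.pad(mode="constant", padding_value=0).to(torch.int64)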
@@ -165,8 +161,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)

         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
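The reformatted initialization keeps the shape that torch.nn.LSTM expects for its initial states, (num_layers, batch, hidden_size); in the removed test at the end of this file embedding_dim and hidden_dim are both 2048, so sizing the zeros with self.rnn.input_size gives the same number. A standalone sketch of that contract (the 3-layer, 2048-unit configuration mirrors the test and is otherwise arbitrary):

import torch

# (h_0, c_0) for torch.nn.LSTM must have shape (num_layers, batch, hidden_size).
rnn = torch.nn.LSTM(input_size=2048, hidden_size=2048, num_layers=3, batch_first=True)

batch_size = 1
x = torch.randn(batch_size, 4, 2048)   # (batch, time, feature)
h0 = torch.zeros(3, batch_size, 2048)
c0 = torch.zeros(3, batch_size, 2048)

out, (h1, c1) = rnn(x, (h0, c0))       # out: (1, 4, 2048)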
@@ -174,13 +174,19 @@ class RnnLmModel(torch.nn.Module):

         return logits[:, 0].log_softmax(-1), states

-    def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None):
+    def forward_with_state(
+        self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
+    ):
         batch_size = len(token_lens)
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )

         device = next(self.parameters()).device

@@ -206,39 +212,3 @@ class RnnLmModel(torch.nn.Module):
         logits = self.output_linear(rnn_out)

         return logits, states
-
-
-if __name__=="__main__":
-    LM = RnnLmModel(500, 2048, 2048, 3, True)
-    h0 = torch.zeros(3, 1, 2048)
-    c0 = torch.zeros(3, 1, 2048)
-    seq = [[0,1,2,3]]
-    seq_lens = [len(s) for s in seq]
-    tokens = k2.RaggedTensor(seq)
-    output1, state = LM.forward_with_state(
-        tokens,
-        seq_lens,
-        1,
-        1,
-        0,
-        state=(h0,c0)
-    )
-    seq = [[0,1,2,3,4]]
-    seq_lens = [len(s) for s in seq]
-    tokens = k2.RaggedTensor(seq)
-    output2, _ = LM.forward_with_state(
-        tokens,
-        seq_lens,
-        1,
-        1,
-        0,
-        state=(h0,c0)
-    )
-
-    seq = [[4]]
-    seq_lens = [len(s) for s in seq]
-    output3 = LM.score_token(seq, seq_lens, state)
-
-    print("Finished")
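With the inline test removed, an equivalent smoke test can live in a separate script. A sketch that mirrors the deleted lines; the import path of RnnLmModel is an assumption and depends on where the recipe keeps model.py:

import k2
import torch

from model import RnnLmModel  # assumed path; adjust to the recipe layout

# Same constructor arguments as the removed test
# (likely vocab_size, embedding_dim, hidden_dim, num_layers, tie_weights).
lm = RnnLmModel(500, 2048, 2048, 3, True)

h0 = torch.zeros(3, 1, 2048)
c0 = torch.zeros(3, 1, 2048)

seq = [[0, 1, 2, 3]]
seq_lens = [len(s) for s in seq]
tokens = k2.RaggedTensor(seq)

logits, state = lm.forward_with_state(
    tokens, seq_lens, 1, 1, 0, state=(h0, c0)
)
print("Finished")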