mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
remove redundant test lines
commit b62fd917ae
parent fb45b95c90
@@ -18,7 +18,6 @@ import logging
 
 import torch
 import torch.nn.functional as F
-import k2
 
 from icefall.utils import add_eos, add_sos, make_pad_mask
 
@@ -122,9 +121,6 @@ class RnnLmModel(torch.nn.Module):
         return nll_loss
 
-    def get_init_states(self, sos):
-        p = next(self.parameters())
-
     def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
         device = next(self.parameters()).device
         batch_size = len(token_lens)
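The removed get_init_states stub only fetched a model parameter (presumably for its device); the surviving methods construct zero-valued LSTM states inline instead, as the following hunks show. A minimal standalone sketch of that construction (make_zero_states is a hypothetical name, not icefall API):

import torch

# Hypothetical standalone equivalent of the inline zero-state construction
# used below (not icefall API). LSTM states are (num_layers, batch, hidden_size);
# the diff uses self.rnn.input_size instead, which coincides only when the
# embedding and hidden dimensions are equal.
def make_zero_states(rnn: torch.nn.LSTM, batch_size: int, device: torch.device):
    h = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size, device=device)
    c = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size, device=device)
    return h, c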
@@ -154,7 +150,7 @@ class RnnLmModel(torch.nn.Module):
             mask[i, token_lens[i], :] = True
         logits = logits[mask].reshape(batch_size, -1)
 
-        return logits[:,:].log_softmax(-1), states
+        return logits[:, :].log_softmax(-1), states
 
     def clean_cache(self):
         self.cache = {}
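Only spacing changes in this hunk: logits[:, :] is a no-op full slice of the logits, and log_softmax(-1) converts the last dimension to log-probabilities. A quick self-contained check of that semantics (the shapes here are illustrative assumptions):

import torch

logits = torch.randn(2, 5)                # assumed (batch, vocab) shape
log_probs = logits[:, :].log_softmax(-1)  # [:, :] leaves the tensor unchanged
# log_softmax(-1) makes each row a log-probability distribution:
assert torch.allclose(log_probs.exp().sum(-1), torch.ones(2))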
@@ -163,24 +159,34 @@ class RnnLmModel(torch.nn.Module):
         device = next(self.parameters()).device
         batch_size = tokens.size(0)
         if state:
-            h,c = state
+            h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device)
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
 
         embedding = self.input_embedding(tokens)
-        rnn_out, states = self.rnn(embedding, (h,c))
+        rnn_out, states = self.rnn(embedding, (h, c))
         logits = self.output_linear(rnn_out)
 
-        return logits[:,0].log_softmax(-1), states
+        return logits[:, 0].log_softmax(-1), states
 
-    def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None):
+    def forward_with_state(
+        self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
+    ):
         batch_size = len(token_lens)
         if state:
-            h,c = state
+            h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
 
         device = next(self.parameters()).device
 
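The methods in this hunk accept an optional (h, c) state so scoring can resume where a previous call left off; absent a state, the zero tensors above are used. A hedged sketch of that threading pattern, with a plain torch.nn.LSTM standing in for RnnLmModel (all sizes below are illustrative assumptions):

import torch

# Stand-in for RnnLmModel: embedding -> LSTM, batch_first, sizes assumed.
emb = torch.nn.Embedding(10, 16)
rnn = torch.nn.LSTM(input_size=16, hidden_size=16, num_layers=2, batch_first=True)

h = torch.zeros(rnn.num_layers, 1, rnn.hidden_size)
c = torch.zeros(rnn.num_layers, 1, rnn.hidden_size)

# Feed a prefix and keep the returned state, then continue with the next token:
out, (h, c) = rnn(emb(torch.tensor([[1, 2, 3]])), (h, c))
out, (h, c) = rnn(emb(torch.tensor([[4]])), (h, c))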
@@ -202,43 +208,7 @@ class RnnLmModel(torch.nn.Module):
         embedding = self.input_embedding(x_tokens)
 
         # Note: We use batch_first==True
-        rnn_out, states = self.rnn(embedding, (h,c))
+        rnn_out, states = self.rnn(embedding, (h, c))
         logits = self.output_linear(rnn_out)
 
         return logits, states
-
-
-if __name__=="__main__":
-    LM = RnnLmModel(500, 2048, 2048, 3, True)
-    h0 = torch.zeros(3, 1, 2048)
-    c0 = torch.zeros(3, 1, 2048)
-    seq = [[0,1,2,3]]
-    seq_lens = [len(s) for s in seq]
-    tokens = k2.RaggedTensor(seq)
-    output1, state = LM.forward_with_state(
-        tokens,
-        seq_lens,
-        1,
-        1,
-        0,
-        state=(h0,c0)
-    )
-    seq = [[0,1,2,3,4]]
-    seq_lens = [len(s) for s in seq]
-    tokens = k2.RaggedTensor(seq)
-    output2, _ = LM.forward_with_state(
-        tokens,
-        seq_lens,
-        1,
-        1,
-        0,
-        state=(h0,c0)
-    )
-
-    seq = [[4]]
-    seq_lens = [len(s) for s in seq]
-    output3 = LM.score_token(seq, seq_lens, state)
-
-    print("Finished")
-
-
-
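The deleted block was an ad-hoc __main__ smoke test. For reference, an equivalent standalone script would look roughly like the sketch below; the import path is an assumption, and the constructor arguments and sos/eos/blank ids are copied from the removed lines rather than documented anywhere in this diff.

import k2
import torch

from icefall.rnn_lm.model import RnnLmModel  # assumed import path

# Arguments copied verbatim from the removed test: (500, 2048, 2048, 3, True);
# their meaning is not documented in this diff.
LM = RnnLmModel(500, 2048, 2048, 3, True)
h0 = torch.zeros(3, 1, 2048)
c0 = torch.zeros(3, 1, 2048)

seq = [[0, 1, 2, 3]]
tokens = k2.RaggedTensor(seq)
output, state = LM.forward_with_state(
    tokens, [len(s) for s in seq], 1, 1, 0, state=(h0, c0)
)
print("Finished")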