change model.device to next(model.parameters()).device for decoding

This commit is contained in:
yaozengwei 2022-05-06 22:20:14 +08:00
parent dd439b1906
commit b1e9d2186d
2 changed files with 6 additions and 7 deletions

View File

@ -276,7 +276,7 @@ def greedy_search(
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64
@ -350,7 +350,7 @@ def greedy_search_batch(
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
device = next(model.parameters()).device
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
@ -580,7 +580,7 @@ def modified_beam_search(
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
T = encoder_out.size(1)
@ -813,7 +813,7 @@ def beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,

View File

@ -250,7 +250,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
@ -560,7 +560,6 @@ def main():
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)