change model.device to next(model.parameters()).device for decoding

This commit is contained in:
yaozengwei 2022-05-06 22:20:14 +08:00
parent dd439b1906
commit b1e9d2186d
2 changed files with 6 additions and 7 deletions

View File

@ -276,7 +276,7 @@ def greedy_search(
context_size = model.decoder.context_size context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id) unk_id = getattr(model, "unk_id", blank_id)
device = model.device device = next(model.parameters()).device
decoder_input = torch.tensor( decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64 [blank_id] * context_size, device=device, dtype=torch.int64
@ -350,7 +350,7 @@ def greedy_search_batch(
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0) assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device device = next(model.parameters()).device
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
T = encoder_out.size(1) T = encoder_out.size(1)
@ -580,7 +580,7 @@ def modified_beam_search(
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id) unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
B = [HypothesisList() for _ in range(batch_size)] B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size): for i in range(batch_size):
B[i].add( B[i].add(
@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
unk_id = getattr(model, "unk_id", blank_id) unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
T = encoder_out.size(1) T = encoder_out.size(1)
@ -813,7 +813,7 @@ def beam_search(
unk_id = getattr(model, "unk_id", blank_id) unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
decoder_input = torch.tensor( decoder_input = torch.tensor(
[blank_id] * context_size, [blank_id] * context_size,

View File

@ -250,7 +250,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
""" """
device = model.device device = next(model.parameters()).device
feature = batch["inputs"] feature = batch["inputs"]
assert feature.ndim == 3 assert feature.ndim == 3
@ -560,7 +560,6 @@ def main():
model.to(device) model.to(device)
model.eval() model.eval()
model.device = device
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)