From b1e9d2186d6771b27220cfc947e5b01cab637498 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Fri, 6 May 2022 22:20:14 +0800
Subject: [PATCH] change model.device to next(model.parameters()).device for
 decoding

---
 .../ASR/pruned_transducer_stateless2/beam_search.py | 10 +++++-----
 .../ASR/pruned_transducer_stateless4/decode.py      |  3 +--
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index ad492aaa5..fc1285dc7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -276,7 +276,7 @@ def greedy_search(
     context_size = model.decoder.context_size
     unk_id = getattr(model, "unk_id", blank_id)
 
-    device = model.device
+    device = next(model.parameters()).device
 
     decoder_input = torch.tensor(
         [blank_id] * context_size, device=device, dtype=torch.int64
@@ -350,7 +350,7 @@ def greedy_search_batch(
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
 
-    device = model.device
+    device = next(model.parameters()).device
 
     batch_size = encoder_out.size(0)
     T = encoder_out.size(1)
@@ -580,7 +580,7 @@ def modified_beam_search(
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     B = [HypothesisList() for _ in range(batch_size)]
     for i in range(batch_size):
         B[i].add(
@@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
-    device = model.device
+    device = next(model.parameters()).device
 
     T = encoder_out.size(1)
 
@@ -813,7 +813,7 @@ def beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
-    device = model.device
+    device = next(model.parameters()).device
 
     decoder_input = torch.tensor(
         [blank_id] * context_size,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index e06662905..025ebd7bc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -250,7 +250,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format
       of the returned dict.
     """
-    device = model.device
+    device = next(model.parameters()).device
 
     feature = batch["inputs"]
     assert feature.ndim == 3
@@ -560,7 +560,6 @@ def main():
 
     model.to(device)
     model.eval()
-    model.device = device
 
     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
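
Rationale: torch.nn.Module does not define a .device attribute, so the old code only worked because main() attached one by hand (model.device = device), which the last hunk removes. Reading the device of the first parameter avoids that manual step. Below is a minimal, standalone sketch of the new lookup, using a plain nn.Linear as a stand-in for the transducer model (the stand-in is illustrative only, not part of the patch):

    import torch
    import torch.nn as nn

    # Stand-in module; any nn.Module that owns parameters behaves the same way.
    model = nn.Linear(4, 4)

    # nn.Module has no built-in .device attribute, so query the device of the
    # first parameter; it always reflects where the module currently lives.
    device = next(model.parameters()).device
    print(device)  # cpu, or cuda:0 after model.to("cuda")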