change model.device to next(model.parameters()).device for decoding
commit b1e9d2186d
parent dd439b1906
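Why the new idiom works: a torch.nn.Module has no `device` attribute of its own, so `model.device` only resolves if someone manually attaches one (which the decode script did, and which this commit removes). Reading the device off the first parameter needs no such bookkeeping. A minimal sketch, using an nn.Linear as a stand-in for the actual transducer model:

import torch
import torch.nn as nn

# Stand-in module; any nn.Module with at least one parameter behaves the same.
model = nn.Linear(80, 500)
model.to(torch.device("cpu"))

# nn.Module has no built-in `device` attribute, so `model.device` only works
# if the caller monkey-patches one on (as the old decode script did).
assert not hasattr(model, "device")

# Every registered parameter knows its device, so the first parameter is a
# reliable proxy for the whole model, provided all parameters live on one
# device, which holds after a single model.to(device) call.
device = next(model.parameters()).device
print(device)  # cpu

# Caveat: next(model.parameters()) raises StopIteration on a module that
# has no parameters at all.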
@@ -276,7 +276,7 @@ def greedy_search(
     context_size = model.decoder.context_size
     unk_id = getattr(model, "unk_id", blank_id)

-    device = model.device
+    device = next(model.parameters()).device

     decoder_input = torch.tensor(
         [blank_id] * context_size, device=device, dtype=torch.int64
@@ -350,7 +350,7 @@ def greedy_search_batch(
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)

-    device = model.device
+    device = next(model.parameters()).device

     batch_size = encoder_out.size(0)
     T = encoder_out.size(1)
@@ -580,7 +580,7 @@ def modified_beam_search(
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     B = [HypothesisList() for _ in range(batch_size)]
     for i in range(batch_size):
         B[i].add(
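The truncated `B[i].add(` call above seeds each utterance's hypothesis list with one all-blank hypothesis before the beam search loop runs. The sketch below is a simplified, hypothetical stand-in for icefall's `Hypothesis`/`HypothesisList` (the real classes in beam_search.py carry considerably more machinery); it only illustrates the initialization pattern visible in this hunk:

import torch
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Hypothesis:
    ys: List[int]            # decoded token IDs, padded with blanks
    log_prob: torch.Tensor   # total log-probability of this hypothesis

class HypothesisList:
    # Simplified: keyed by the token sequence so duplicates can merge.
    def __init__(self) -> None:
        self._data: Dict[str, Hypothesis] = {}

    def add(self, hyp: Hypothesis) -> None:
        key = "_".join(map(str, hyp.ys))
        if key in self._data:
            # Merge duplicates by summing probabilities in log space.
            old = self._data[key]
            old.log_prob = torch.logaddexp(old.log_prob, hyp.log_prob)
        else:
            self._data[key] = hyp

# Initialization as in the hunk above; blank_id, context_size, batch_size,
# and device normally come from the surrounding function.
blank_id, context_size, batch_size = 0, 2, 4
device = torch.device("cpu")
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
    B[i].add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, device=device),
        )
    )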
@@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size

-    device = model.device
+    device = next(model.parameters()).device

     T = encoder_out.size(1)

@@ -813,7 +813,7 @@ def beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size

-    device = model.device
+    device = next(model.parameters()).device

     decoder_input = torch.tensor(
         [blank_id] * context_size,
@@ -250,7 +250,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict.
     """
-    device = model.device
+    device = next(model.parameters()).device
     feature = batch["inputs"]
     assert feature.ndim == 3

@@ -560,7 +560,6 @@ def main():

     model.to(device)
     model.eval()
-    model.device = device

     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
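With the decoding functions now deriving the device from the parameters, main() no longer has to monkey-patch a `device` attribute onto the model after moving it. A minimal sketch of the resulting setup (the nn.Linear is a placeholder for the real model; argument parsing and checkpoint loading are omitted):

import torch
import torch.nn as nn

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

model = nn.Linear(80, 500)  # placeholder for the transducer model
model.to(device)
model.eval()

# The old `model.device = device` line is gone: greedy_search,
# modified_beam_search, beam_search, and decode_one_batch now recover the
# device via next(model.parameters()).device themselves.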