Fix decode_stream.py (#1208)

* FIx decode_stream.py

* Update decode_stream.py
This commit is contained in:
Yifan Yang 2023-08-09 09:40:58 +08:00 committed by GitHub
parent 1ee251c8b3
commit 00256a7669
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -79,12 +79,12 @@ class DecodeStream(object):
self.pad_length = 7 + 2 * 3 self.pad_length = 7 + 2 * 3
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size self.hyp = [-1] * (params.context_size - 1) + [params.blank_id]
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList() self.hyps = HypothesisList()
self.hyps.add( self.hyps.add(
Hypothesis( Hypothesis(
ys=[params.blank_id] * params.context_size, ys=[-1] * (params.context_size - 1) + [params.blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
) )
) )