Fix decoding the gigaspeech dataset.

We have to use the decoder/joiner networks for the GigaSpeech dataset.
This commit is contained in:
Fangjun Kuang 2022-04-18 15:23:07 +08:00
parent e32641d1df
commit a31207f5b3

View File

@ -520,6 +520,11 @@ def main():
model.eval()
model.device = device
# In beam_search.py, we are using model.decoder() and model.joiner(),
# so we have to switch to the branch for the GigaSpeech dataset.
model.decoder = model.decoder_giga
model.joiner = model.joiner_giga
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else: