Merge remote-tracking branch 'origin/modified-conformer-with-multi-datasets' into modified-conformer-with-multi-datasets

This commit is contained in:
Fangjun Kuang 2022-04-20 17:22:29 +08:00
commit e9f0975868

View File

@ -520,6 +520,11 @@ def main():
model.eval()
model.device = device
# In beam_search.py, we are using model.decoder() and model.joiner(),
# so we have to switch to the branch for the GigaSpeech dataset.
model.decoder = model.decoder_giga
model.joiner = model.joiner_giga
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else: