diff --git a/egs/slu/transducer/decode.py b/egs/slu/transducer/decode.py index a272babd2..661de96ca 100755 --- a/egs/slu/transducer/decode.py +++ b/egs/slu/transducer/decode.py @@ -22,7 +22,7 @@ from typing import List, Tuple import torch import torch.nn as nn -from asr_datamodule import SluDataModule +from transducer.asr_datamodule import SluDataModule from transducer.beam_search import greedy_search from transducer.decoder import Decoder from transducer.encoder import Tdnn @@ -336,7 +336,8 @@ def main(): model=model, ) - save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) + test_set_name=str(args.feature_dir).split('/')[-2] + save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results) logging.info("Done!")