#!/usr/bin/env python3 # Copyright 2022 Xiaomi Corporation (Author: Liyong Guo) # # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import logging from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple import torch from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, write_error_stats, ) def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( "--exp-dir", type=Path, default="pruned_transducer_stateless6/exp/", help="The experiment dir", ) return parser def decode_dataset( dl: torch.utils.data.DataLoader, hubert_model: HubertXlargeFineTuned, params: AttributeDict, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: dl: PyTorch's dataloader containing the dataset to decode. model: The neural model. Returns: Return a dict, whose key is decoding method "ctc_greedy_search". Its value is a list of tuples. Each tuple contains two elements: The first is the reference transcript, and the second is the predicted result. """ results = [] num_cuts = 0 try: num_batches = len(dl) except TypeError: num_batches = "?" results = defaultdict(list) for batch_idx, batch in enumerate(dl): hyps = hubert_model.ctc_greedy_search(batch) texts = batch["supervisions"]["text"] assert len(hyps) == len(texts) this_batch = [] for hyp_text, ref_text in zip(hyps, texts): ref_words = ref_text.split() hyp_words = hyp_text.split() this_batch.append((ref_words, hyp_words)) results["ctc_greedy_search"].extend(this_batch) num_cuts += len(texts) if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" logging.info( f"batch {batch_str}, cuts processed until now is {num_cuts}" ) return results def save_results( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{key}.txt" store_transcripts(filename=recog_path, texts=results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{key}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer logging.info( "Wrote detailed error stats to {}".format(errs_filename) ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: s += "{}\t{}{}\n".format(key, val, note) note = "" logging.info(s) @torch.no_grad() def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) HubertXlargeFineTuned.add_arguments(parser) args = parser.parse_args() params = AttributeDict() params.update(vars(args)) # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) params.res_dir = ( params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" ) setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") logging.info(params) device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) logging.info(f"device: {device}") params.device = device hubert_model = HubertXlargeFineTuned(params) librispeech = LibriSpeechAsrDataModule(params) test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts) test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( dl=test_dl, hubert_model=hubert_model, params=params, ) save_results( params=params, test_set_name=test_set, results_dict=results_dict ) logging.info("Done!") torch.set_num_threads(1) torch.set_num_interop_threads(1) if __name__ == "__main__": main()