From 360f20803731a60e409400784dbf5931af50959e Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 28 Mar 2024 16:17:05 +0800 Subject: [PATCH] deactivate beam search temporarily for speed --- egs/librispeech/ASR/whisper/decode.py | 34 ++++++++------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/whisper/decode.py b/egs/librispeech/ASR/whisper/decode.py index 24f61f17f..83d33418d 100755 --- a/egs/librispeech/ASR/whisper/decode.py +++ b/egs/librispeech/ASR/whisper/decode.py @@ -3,6 +3,7 @@ # Fangjun Kuang, # Wei Kang) # 2024 Yuekai Zhang +# 2024 Xiaomi Corporation Xiaoyu Yang # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -145,26 +146,6 @@ def remove_punctuation(text: str or List[str]): raise Exception(f"Not support type {type(text)}") -def to_simple(text: str or List[str]): - """Convert traditional Chinese to simplified Chinese. - Args: - text: It can be a string or a list of strings. - Returns: - Return a string or a list of strings converted to simplified Chinese. - """ - if isinstance(text, str): - text = convert(text, "zh-cn") - return text - elif isinstance(text, list): - result_text = [] - for t in text: - t = convert(t, "zh-cn") - result_text.append(t) - return result_text - else: - raise Exception(f"Not support type{type(text)}") - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -417,8 +398,8 @@ def main(): options = whisper.DecodingOptions( task="transcribe", language="en", - # without_timestamps=True, - # beam_size=params.beam_size, + without_timestamps=True, + #beam_size=params.beam_size, ) params.decoding_options = options params.cleaner = BasicTextNormalizer() @@ -481,12 +462,17 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + def remove_short_and_long_utt(c): + if c.duration < 1.0 or c.duration > 30.0: + return False + return True + # we need cut ids to display recognition results. args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts().subset(first=200) - test_other_cuts = librispeech.test_other_cuts().subset(first=200) + test_clean_cuts = librispeech.test_clean_cuts().filter(remove_short_and_long_utt) + test_other_cuts = librispeech.test_other_cuts().filter(remove_short_and_long_utt) test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts)