deactivate beam search temporarily for speed

This commit is contained in:
marcoyang 2024-03-28 16:17:05 +08:00
parent ebc0f3b052
commit 360f208037

View File

@ -3,6 +3,7 @@
# Fangjun Kuang, # Fangjun Kuang,
# Wei Kang) # Wei Kang)
# 2024 Yuekai Zhang # 2024 Yuekai Zhang
# 2024 Xiaomi Corporation Xiaoyu Yang
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -145,26 +146,6 @@ def remove_punctuation(text: str or List[str]):
raise Exception(f"Not support type {type(text)}") raise Exception(f"Not support type {type(text)}")
def to_simple(text: str or List[str]):
"""Convert traditional Chinese to simplified Chinese.
Args:
text: It can be a string or a list of strings.
Returns:
Return a string or a list of strings converted to simplified Chinese.
"""
if isinstance(text, str):
text = convert(text, "zh-cn")
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = convert(t, "zh-cn")
result_text.append(t)
return result_text
else:
raise Exception(f"Not support type{type(text)}")
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -417,8 +398,8 @@ def main():
options = whisper.DecodingOptions( options = whisper.DecodingOptions(
task="transcribe", task="transcribe",
language="en", language="en",
# without_timestamps=True, without_timestamps=True,
# beam_size=params.beam_size, #beam_size=params.beam_size,
) )
params.decoding_options = options params.decoding_options = options
params.cleaner = BasicTextNormalizer() params.cleaner = BasicTextNormalizer()
@ -481,12 +462,17 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
def remove_short_and_long_utt(c):
if c.duration < 1.0 or c.duration > 30.0:
return False
return True
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts().subset(first=200) test_clean_cuts = librispeech.test_clean_cuts().filter(remove_short_and_long_utt)
test_other_cuts = librispeech.test_other_cuts().subset(first=200) test_other_cuts = librispeech.test_other_cuts().filter(remove_short_and_long_utt)
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts)