deactivate beam search temporarily for speed

This commit is contained in:
marcoyang 2024-03-28 16:17:05 +08:00
parent ebc0f3b052
commit 360f208037

View File

@ -3,6 +3,7 @@
# Fangjun Kuang,
# Wei Kang)
# 2024 Yuekai Zhang
# 2024 Xiaomi Corporation Xiaoyu Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -145,26 +146,6 @@ def remove_punctuation(text: str or List[str]):
raise Exception(f"Not support type {type(text)}")
def to_simple(text: str or List[str]):
"""Convert traditional Chinese to simplified Chinese.
Args:
text: It can be a string or a list of strings.
Returns:
Return a string or a list of strings converted to simplified Chinese.
"""
if isinstance(text, str):
text = convert(text, "zh-cn")
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = convert(t, "zh-cn")
result_text.append(t)
return result_text
else:
raise Exception(f"Not support type{type(text)}")
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -417,8 +398,8 @@ def main():
options = whisper.DecodingOptions(
task="transcribe",
language="en",
# without_timestamps=True,
# beam_size=params.beam_size,
without_timestamps=True,
#beam_size=params.beam_size,
)
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
@ -481,12 +462,17 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
def remove_short_and_long_utt(c):
if c.duration < 1.0 or c.duration > 30.0:
return False
return True
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts().subset(first=200)
test_other_cuts = librispeech.test_other_cuts().subset(first=200)
test_clean_cuts = librispeech.test_clean_cuts().filter(remove_short_and_long_utt)
test_other_cuts = librispeech.test_other_cuts().filter(remove_short_and_long_utt)
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)