mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
deactivate beam search temporarily for speed
This commit is contained in:
parent
ebc0f3b052
commit
360f208037
@ -3,6 +3,7 @@
|
|||||||
# Fangjun Kuang,
|
# Fangjun Kuang,
|
||||||
# Wei Kang)
|
# Wei Kang)
|
||||||
# 2024 Yuekai Zhang
|
# 2024 Yuekai Zhang
|
||||||
|
# 2024 Xiaomi Corporation Xiaoyu Yang
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -145,26 +146,6 @@ def remove_punctuation(text: str or List[str]):
|
|||||||
raise Exception(f"Not support type {type(text)}")
|
raise Exception(f"Not support type {type(text)}")
|
||||||
|
|
||||||
|
|
||||||
def to_simple(text: str or List[str]):
|
|
||||||
"""Convert traditional Chinese to simplified Chinese.
|
|
||||||
Args:
|
|
||||||
text: It can be a string or a list of strings.
|
|
||||||
Returns:
|
|
||||||
Return a string or a list of strings converted to simplified Chinese.
|
|
||||||
"""
|
|
||||||
if isinstance(text, str):
|
|
||||||
text = convert(text, "zh-cn")
|
|
||||||
return text
|
|
||||||
elif isinstance(text, list):
|
|
||||||
result_text = []
|
|
||||||
for t in text:
|
|
||||||
t = convert(t, "zh-cn")
|
|
||||||
result_text.append(t)
|
|
||||||
return result_text
|
|
||||||
else:
|
|
||||||
raise Exception(f"Not support type{type(text)}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -417,7 +398,7 @@ def main():
|
|||||||
options = whisper.DecodingOptions(
|
options = whisper.DecodingOptions(
|
||||||
task="transcribe",
|
task="transcribe",
|
||||||
language="en",
|
language="en",
|
||||||
# without_timestamps=True,
|
without_timestamps=True,
|
||||||
#beam_size=params.beam_size,
|
#beam_size=params.beam_size,
|
||||||
)
|
)
|
||||||
params.decoding_options = options
|
params.decoding_options = options
|
||||||
@ -481,12 +462,17 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
def remove_short_and_long_utt(c):
|
||||||
|
if c.duration < 1.0 or c.duration > 30.0:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts().subset(first=200)
|
test_clean_cuts = librispeech.test_clean_cuts().filter(remove_short_and_long_utt)
|
||||||
test_other_cuts = librispeech.test_other_cuts().subset(first=200)
|
test_other_cuts = librispeech.test_other_cuts().filter(remove_short_and_long_utt)
|
||||||
|
|
||||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user