minor fixes

pkufool 2023-11-21 17:51:38 +08:00
parent e102088fed
commit c195a12a36
2 changed files with 7 additions and 24 deletions

View File

@@ -107,7 +107,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from text_normalization import remove_punc_to_upper,
+from text_normalization import remove_punc_to_upper
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
@@ -174,10 +174,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="zipformer/exp",
-        help="The experiment dir",
+        "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
     )
 
     parser.add_argument(
@@ -285,7 +282,7 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""Set to True, if the model was trained on texts with casing
-        and punctuation."""
+        and punctuation.""",
     )
 
     parser.add_argument(
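For context, the `str2bool` type above lets argparse accept values like `--flag true` or `--flag False`. A minimal sketch of such a helper, assuming the usual semantics (the actual implementation is imported from icefall.utils and may differ):

```python
import argparse

def str2bool(v):
    """Convert an argparse string value to a bool (illustrative sketch)."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")
```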
@@ -352,9 +349,7 @@ def decode_one_batch(
         pad_len = 30
         feature_lens += pad_len
         feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, pad_len),
-            value=LOG_EPS,
+            feature, pad=(0, 0, 0, pad_len), value=LOG_EPS,
         )
 
     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
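A note on the pad call collapsed above: `torch.nn.functional.pad` reads pad widths from the last dimension backwards, so `pad=(0, 0, 0, pad_len)` leaves the feature dimension alone and appends `pad_len` frames on the right of the time axis. A minimal sketch (the shapes and the `LOG_EPS` value here are illustrative, not taken from the script):

```python
import torch

LOG_EPS = -23.0  # illustrative fill value; the decode script defines its own constant
feature = torch.zeros(2, 100, 80)  # (batch, time, feature) fbank features
padded = torch.nn.functional.pad(
    feature, pad=(0, 0, 0, 30), value=LOG_EPS  # pad only dim -2 (time), on the right
)
assert padded.shape == (2, 130, 80)           # 30 extra frames on the time axis
assert (padded[:, 100:, :] == LOG_EPS).all()  # new frames hold the fill value
```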
@@ -404,9 +399,7 @@ def decode_one_batch(
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
+            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -434,9 +427,7 @@ def decode_one_batch(
                 )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                beam=params.beam_size,
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size,
             )
         else:
             raise ValueError(
@@ -505,9 +496,6 @@ def decode_dataset(
         warnings.simplefilter("ignore")
 
     for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
-        texts = [
-            simple_normalization(t) for t in texts
-        ]  # Do a simple normalization, as this is done during training
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
@@ -537,7 +525,6 @@ def decode_dataset(
             results[f"{name}_norm"].extend(this_batch)
-
         num_cuts += len(texts)
 
         if batch_idx % log_interval == 0:
@@ -786,9 +773,7 @@ def main():
         )
         save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
+            params=params, test_set_name=test_set, results_dict=results_dict,
         )
 
     logging.info("Done!")

View File

@@ -43,8 +43,6 @@ def text_normalization(text: str) -> str:
 
 
 if __name__ == "__main__":
-    assert simple_cleanup("I like this 《book>") == "I like this <book>"
-
     assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
     assert (
         text_normalization("Hello Mrs st 21st world 3rd she 99th MR")