minor fixes

This commit is contained in:
pkufool 2023-11-21 17:51:38 +08:00
parent e102088fed
commit c195a12a36
2 changed files with 7 additions and 24 deletions

View File

@ -107,7 +107,7 @@ from beam_search import (
modified_beam_search,
)
from lhotse.cut import Cut
from text_normalization import remove_punc_to_upper,
from text_normalization import remove_punc_to_upper
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
@ -174,10 +174,7 @@ def get_parser():
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
"--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
)
parser.add_argument(
@ -285,7 +282,7 @@ def get_parser():
type=str2bool,
default=False,
help="""Set to True, if the model was trained on texts with casing
and punctuation."""
and punctuation.""",
)
parser.add_argument(
@ -352,9 +349,7 @@ def decode_one_batch(
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
feature, pad=(0, 0, 0, pad_len), value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
@ -404,9 +399,7 @@ def decode_one_batch(
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -434,9 +427,7 @@ def decode_one_batch(
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
model=model, encoder_out=encoder_out_i, beam=params.beam_size,
)
else:
raise ValueError(
@ -505,9 +496,6 @@ def decode_dataset(
warnings.simplefilter("ignore")
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [
simple_normalization(t) for t in texts
] # Do a simple normalization, as this is done during training
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
@ -537,7 +525,6 @@ def decode_dataset(
results[f"{name}_norm"].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
@ -786,9 +773,7 @@ def main():
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
params=params, test_set_name=test_set, results_dict=results_dict,
)
logging.info("Done!")

View File

@ -43,8 +43,6 @@ def text_normalization(text: str) -> str:
if __name__ == "__main__":
assert simple_cleanup("I like this 《book>") == "I like this <book>"
assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
assert (
text_normalization("Hello Mrs st 21st world 3rd she 99th MR")