mirror of https://github.com/k2-fsa/icefall.git
minor fixes
This commit is contained in: parent e102088fed, commit c195a12a36
@@ -107,7 +107,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from text_normalization import remove_punc_to_upper,
+from text_normalization import remove_punc_to_upper
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
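
Note on the hunk above: the dropped trailing comma is not just style. Outside parentheses, Python rejects a trailing comma in a from-import at compile time, so the old line could not even be parsed; roughly:

    # from text_normalization import remove_punc_to_upper,
    # SyntaxError: trailing comma not allowed without surrounding parentheses
    # The parenthesized form does allow a trailing comma:
    from text_normalization import (
        remove_punc_to_upper,
    )
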
@@ -174,10 +174,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="zipformer/exp",
-        help="The experiment dir",
+        "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
     )
 
     parser.add_argument(
@@ -285,7 +282,7 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""Set to True, if the model was trained on texts with casing
-        and punctuation."""
+        and punctuation.""",
     )
 
     parser.add_argument(
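
For reference, str2bool above is the boolean argparse helper icefall ships in its utilities; a minimal sketch of the usual pattern (assumed here, not necessarily the repo's exact code):

    import argparse

    def str2bool(v) -> bool:
        # Accept common spellings so true/false, yes/no, 1/0 all parse.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")
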
@@ -352,9 +349,7 @@ def decode_one_batch(
         pad_len = 30
         feature_lens += pad_len
         feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, pad_len),
-            value=LOG_EPS,
+            feature, pad=(0, 0, 0, pad_len), value=LOG_EPS,
         )
 
     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
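
In the padding hunk, torch.nn.functional.pad reads its pad tuple from the last dimension backwards, so pad=(0, 0, 0, pad_len) leaves the feature axis alone and appends pad_len frames at the end of the time axis. A minimal sketch, assuming a (batch, time, feature) layout and a stand-in value for LOG_EPS:

    import torch
    import torch.nn.functional as F

    LOG_EPS = -23.0  # stand-in; the script defines its own log-epsilon constant
    feature = torch.randn(2, 100, 80)  # (batch, time, feature) -- assumed layout
    padded = F.pad(feature, pad=(0, 0, 0, 30), value=LOG_EPS)
    assert padded.shape == (2, 130, 80)  # 30 extra frames filled with LOG_EPS
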
@@ -404,9 +399,7 @@ def decode_one_batch(
         hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
+            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -434,9 +427,7 @@ def decode_one_batch(
             )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                beam=params.beam_size,
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size,
             )
         else:
             raise ValueError(
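
Both search hunks above are pure reformatting, but they also show the API split: greedy_search_batch consumes the whole padded batch (encoder_out plus encoder_out_lens), while beam_search runs per utterance on a single encoder_out_i. Sketched usage with the names as in the diff (shapes assumed):

    # Batched greedy search: one call per batch; lengths mark the valid frames.
    hyp_tokens = greedy_search_batch(
        model=model,
        encoder_out=encoder_out,            # (batch, time, dim), assumed
        encoder_out_lens=encoder_out_lens,  # (batch,), valid frames per utterance
    )
    # Per-utterance beam search: the loop over the batch happens outside the call.
    hyp = beam_search(model=model, encoder_out=encoder_out_i, beam=params.beam_size)
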
@@ -505,9 +496,6 @@ def decode_dataset(
         warnings.simplefilter("ignore")
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        texts = [
-            simple_normalization(t) for t in texts
-        ]  # Do a simple normalization, as this is done during training
 
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
@@ -537,7 +525,6 @@ def decode_dataset(
 
             results[f"{name}_norm"].extend(this_batch)
 
-
         num_cuts += len(texts)
 
         if batch_idx % log_interval == 0:
@@ -786,9 +773,7 @@ def main():
         )
 
         save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
+            params=params, test_set_name=test_set, results_dict=results_dict,
         )
 
     logging.info("Done!")
@@ -43,8 +43,6 @@ def text_normalization(text: str) -> str:
 
-
 if __name__ == "__main__":
-
     assert simple_cleanup("I like this 《book>") == "I like this <book>"
     assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
     assert (
         text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
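
The asserts above pin down the intended behavior: simple_cleanup only normalizes odd punctuation (《 becomes <), while remove_punc_to_upper drops punctuation entirely and upper-cases. A sketch that satisfies the remove_punc_to_upper assertion (illustrative only, not the repo's implementation):

    import re

    def remove_punc_to_upper_sketch(text: str) -> str:
        # Keep letters, digits, apostrophes and spaces; drop the rest, then upper-case.
        kept = re.sub(r"[^A-Za-z0-9' ]", "", text)
        return re.sub(r" +", " ", kept).strip().upper()

    assert remove_punc_to_upper_sketch("I like this 《book>") == "I LIKE THIS BOOK"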