Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-06 23:54:17 +00:00)

Commit c195a12a36 ("minor fixes"), parent e102088fed.
@@ -107,7 +107,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from text_normalization import remove_punc_to_upper,
+from text_normalization import remove_punc_to_upper
 from train import add_model_arguments, get_model, get_params

 from icefall.checkpoint import (
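Note that the removed trailing comma is not only cosmetic: a bare from-import may not end in a comma; Python only allows the trailing comma when the imported names are parenthesized. A minimal illustration using a stdlib module in place of the recipe's text_normalization:

# from os import path,   # SyntaxError: trailing comma not allowed
#                         # without surrounding parentheses
from os import (
    path,  # a trailing comma is fine inside parentheses
)
print(path.sep)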
@@ -174,10 +174,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="zipformer/exp",
-        help="The experiment dir",
+        "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
     )

     parser.add_argument(
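The argparse hunks in get_parser() only collapse multi-line add_argument calls onto one line; the resulting parser behaves identically. A self-contained sketch of the pattern, showing just --exp-dir and omitting the recipe's many other arguments:

import argparse
from pathlib import Path

def get_parser() -> argparse.ArgumentParser:
    # Minimal stand-in for the recipe's much larger get_parser().
    parser = argparse.ArgumentParser(description="decoding options (sketch)")
    parser.add_argument(
        "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
    )
    return parser

if __name__ == "__main__":
    args = get_parser().parse_args([])
    print(Path(args.exp_dir))  # zipformer/exp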
@@ -285,7 +282,7 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""Set to True, if the model was trained on texts with casing
-        and punctuation."""
+        and punctuation.""",
     )

     parser.add_argument(
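The type=str2bool above is icefall's helper for parsing boolean flags from the command line. A typical implementation looks like the following sketch (illustrative only; the project ships its own in icefall.utils):

import argparse

def str2bool(v) -> bool:
    # Accept the usual textual spellings of booleans on the command line.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")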
@@ -352,9 +349,7 @@ def decode_one_batch(
     pad_len = 30
     feature_lens += pad_len
     feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, pad_len),
-        value=LOG_EPS,
+        feature, pad=(0, 0, 0, pad_len), value=LOG_EPS,
     )

     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
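torch.nn.functional.pad reads the pad tuple from the last dimension backwards, so pad=(0, 0, 0, pad_len) leaves the feature dimension untouched and appends pad_len frames to the end of the time axis. A small self-contained check; the LOG_EPS value below is an illustrative stand-in (roughly log(1e-10)), not necessarily the script's constant:

import torch

LOG_EPS = -23.025850929940457  # illustrative; approximately log(1e-10)

feature = torch.randn(2, 100, 80)  # (batch, num_frames, feature_dim)
pad_len = 30
padded = torch.nn.functional.pad(feature, pad=(0, 0, 0, pad_len), value=LOG_EPS)
assert padded.shape == (2, 130, 80)             # 30 extra frames on the time axis
assert bool((padded[:, 100:, :] == LOG_EPS).all())  # padding filled with LOG_EPS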
@@ -404,9 +399,7 @@ def decode_one_batch(
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
+            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
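For context, max_sym_per_frame == 1 means the greedy transducer search emits at most one non-blank token per encoder frame. The idea, sketched for a single utterance with a hypothetical decoder/joiner interface (icefall's batched greedy_search_batch is more involved and vectorized across utterances):

import torch

def greedy_search_single(model, encoder_out: torch.Tensor, blank_id: int = 0):
    # encoder_out: (T, encoder_dim) for one utterance.
    # model.decoder(...) and model.joiner(...) are assumed interfaces for this
    # sketch, not icefall's exact API.
    hyp = []
    decoder_out = model.decoder(torch.tensor([[blank_id]]))
    for t in range(encoder_out.size(0)):
        logits = model.joiner(encoder_out[t : t + 1], decoder_out)
        token = int(logits.argmax(dim=-1))
        if token != blank_id:  # emit at most one symbol, then move to next frame
            hyp.append(token)
            decoder_out = model.decoder(torch.tensor([[token]]))
    return hyp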
@@ -434,9 +427,7 @@ def decode_one_batch(
             )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                beam=params.beam_size,
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size,
             )
         else:
             raise ValueError(
@@ -505,9 +496,6 @@ def decode_dataset(
         warnings.simplefilter("ignore")
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        texts = [
-            simple_normalization(t) for t in texts
-        ]  # Do a simple normalization, as this is done during training

         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
@@ -537,7 +525,6 @@ def decode_dataset(

             results[f"{name}_norm"].extend(this_batch)

-
         num_cuts += len(texts)

         if batch_idx % log_interval == 0:
@@ -786,9 +773,7 @@ def main():
         )

         save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
+            params=params, test_set_name=test_set, results_dict=results_dict,
         )

     logging.info("Done!")
@@ -43,8 +43,6 @@ def text_normalization(text: str) -> str:


 if __name__ == "__main__":
-
-    assert simple_cleanup("I like this 《book>") == "I like this <book>"
     assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
     assert (
         text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
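The remaining asserts pin down the contract of remove_punc_to_upper: drop punctuation (including full-width characters such as 《) and uppercase the rest. A self-contained sketch consistent with that behavior; this is not the recipe's actual implementation:

import re

def remove_punc_to_upper_sketch(text: str) -> str:
    # Keep letters, digits, apostrophes and spaces; drop everything else,
    # then squeeze repeated whitespace and uppercase.
    kept = re.sub(r"[^A-Za-z0-9' ]+", " ", text)
    return re.sub(r"\s+", " ", kept).strip().upper()

assert remove_punc_to_upper_sketch("I like this 《book>") == "I LIKE THIS BOOK"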