Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

add blank penalty

This commit is contained in:
parent 42513d2e98
commit 4f28e15a1d
@@ -626,6 +626,7 @@ def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    blank_penalty: float = 0,
     return_timestamps: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@@ -701,6 +702,7 @@ def greedy_search_batch(

         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
+        logits[:, 0] -= blank_penalty
         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
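Note: the added line biases greedy decoding away from the blank symbol. Subtracting a positive blank_penalty from the blank logit (index 0) before the argmax makes the decoder more willing to emit non-blank tokens, typically trading deletions for insertions. A minimal standalone sketch of the effect (the tensor values below are illustrative, not from the recipe):

import torch

# Illustrative logits for 2 frames over a 5-token vocab;
# index 0 is the blank symbol, as in greedy_search_batch above.
logits = torch.tensor(
    [
        [2.0, 1.8, 0.1, 0.0, 0.0],  # blank barely wins
        [3.0, 0.5, 0.2, 0.0, 0.0],  # blank wins clearly
    ]
)

blank_penalty = 0.5  # hypothetical value; the new argument defaults to 0

penalized = logits.clone()
penalized[:, 0] -= blank_penalty  # the same operation the commit adds

print(logits.argmax(dim=1).tolist())     # [0, 0] -> both frames emit blank
print(penalized.argmax(dim=1).tolist())  # [1, 0] -> first frame now emits token 1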
@@ -292,7 +292,7 @@ class WenetSpeechAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                buffer_size=30000,
+                buffer_size=300000,
                 drop_last=True,
             )
         else:
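For context, buffer_size here is the number of cuts lhotse's DynamicBucketingSampler keeps in memory while assigning them to duration buckets, so the 30000 -> 300000 bump gives more homogeneous batches at the cost of extra RAM. A hedged sketch of the surrounding call, assuming the sampler and CutSet come from lhotse (the function name and argument values outside this diff are assumptions):

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

def make_train_sampler(cuts: CutSet, max_duration: float = 600.0) -> DynamicBucketingSampler:
    # Larger buffer_size -> better duration bucketing, more memory.
    return DynamicBucketingSampler(
        cuts,
        max_duration=max_duration,
        shuffle=True,
        num_buckets=30,
        buffer_size=300000,  # the value this commit raises it to
        drop_last=True,
    )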
@@ -125,6 +125,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from lhotse.cut import Cut
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
@@ -302,6 +303,13 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="",
+    )
+
     add_model_arguments(parser)

     return parser
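(The commit leaves the help string empty.) A quick self-contained check of how the new flag behaves in isolation, with a hypothetical help text supplied for the sketch; this is not the full get_parser():

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--blank-penalty",
    type=float,
    default=0.0,
    help="Positive value subtracted from the blank logit during decoding.",
)

args = parser.parse_args(["--blank-penalty", "0.5"])
assert args.blank_penalty == 0.5  # argparse maps --blank-penalty to blank_penalty

args = parser.parse_args([])
assert args.blank_penalty == 0.0  # the default leaves decoding unchanged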
@@ -414,9 +422,7 @@ def decode_one_batch(
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
+            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, blank_penalty=params.blank_penalty,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -444,9 +450,7 @@ def decode_one_batch(
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
+                    model=model, encoder_out=encoder_out_i, beam=params.beam_size,
                 )
             else:
                 raise ValueError(
@@ -625,6 +629,7 @@ def main():
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+    params.suffix += f"-blank-penalty-{params.blank_penalty}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -751,17 +756,30 @@ def main():
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)

+    def remove_short_utt(c: Cut):
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        if T <= 0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+            )
+        return T > 0
+
     dev_cuts = wenetspeech.valid_cuts()
+    dev_cuts = dev_cuts.filter(remove_short_utt)
     dev_dl = wenetspeech.valid_dataloaders(dev_cuts)

     test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_cuts = test_net_cuts.filter(remove_short_utt)
     test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)

     test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
     test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    # test_sets = ["TEST_MEETING"]
+    # test_dl = [test_meeting_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
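The remove_short_utt filter added above mirrors the encoder frontend's subsampling: ((num_frames - 7) // 2 + 1) // 2 approximates the number of encoder output frames after the convolutional frontend, so cuts too short to yield even one encoder frame are dropped before decoding. A standalone check of the cutoff (pure arithmetic, no lhotse needed; the exact frontend it models is an assumption):

def encoder_frames(num_frames: int) -> int:
    # Same formula as remove_short_utt in the diff above.
    return ((num_frames - 7) // 2 + 1) // 2

for n in [7, 8, 9, 10, 20]:
    print(n, encoder_frames(n))
# 7 -> 0, 8 -> 0, 9 -> 1, 10 -> 1, 20 -> 3
# Cuts with fewer than 9 input frames are excluded by the filter.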
@@ -774,9 +792,7 @@ def main():
         )

         save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
+            params=params, test_set_name=test_set, results_dict=results_dict,
         )

     logging.info("Done!")