mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
add blank penalty
This commit is contained in:
parent
42513d2e98
commit
4f28e15a1d
@ -626,6 +626,7 @@ def greedy_search_batch(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
blank_penalty: float = 0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
@ -701,6 +702,7 @@ def greedy_search_batch(
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||
assert logits.ndim == 2, logits.shape
|
||||
logits[:, 0] -= blank_penalty
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
|
||||
@ -292,7 +292,7 @@ class WenetSpeechAsrDataModule:
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=30000,
|
||||
buffer_size=300000,
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@ -125,6 +125,7 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from lhotse.cut import Cut
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -302,6 +303,13 @@ def get_parser():
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -414,9 +422,7 @@ def decode_one_batch(
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, blank_penalty=params.blank_penalty,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
@ -444,9 +450,7 @@ def decode_one_batch(
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -625,6 +629,7 @@ def main():
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
@ -751,17 +756,30 @@ def main():
|
||||
args.return_cuts = True
|
||||
wenetspeech = WenetSpeechAsrDataModule(args)
|
||||
|
||||
def remove_short_utt(c: Cut):
|
||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||
if T <= 0:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
|
||||
)
|
||||
return T > 0
|
||||
|
||||
dev_cuts = wenetspeech.valid_cuts()
|
||||
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||
dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
|
||||
|
||||
test_net_cuts = wenetspeech.test_net_cuts()
|
||||
test_net_cuts = test_net_cuts.filter(remove_short_utt)
|
||||
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
|
||||
|
||||
test_meeting_cuts = wenetspeech.test_meeting_cuts()
|
||||
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
|
||||
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
|
||||
|
||||
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
|
||||
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
|
||||
# test_sets = ["TEST_MEETING"]
|
||||
# test_dl = [test_meeting_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
@ -774,9 +792,7 @@ def main():
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
params=params, test_set_name=test_set, results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user