support showing symbol delay in conv emformer

This commit is contained in:
marcoyang 2022-12-29 17:50:38 +08:00
parent e91fbef939
commit d0eb9b1912

View File

@ -98,10 +98,12 @@ from icefall.checkpoint import (
)
from icefall.utils import (
AttributeDict,
DecodingResults,
parse_hyp_and_timestamp,
setup_logger,
store_transcripts,
store_transcripts_and_timestamps,
str2bool,
write_error_stats,
write_error_stats_with_timestamps,
)
LOG_EPS = math.log(1e-10)
@ -237,7 +239,7 @@ def decode_one_batch(
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -287,7 +289,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
res = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -295,63 +297,74 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
return_timestamps=True,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
res = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
return_timestamps=True,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
res = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
return_timestamps=True,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
tokens = []
timestamps = []
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
res = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
return_timestamps=True,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
res = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
return_timestamps=True,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
tokens.extend(res.tokens)
timestamps.extend(res.timestamps)
res = DecodingResults(hyps=tokens, timestamps=timestamps)
hyps, timestamps = parse_hyp_and_timestamp(
decoding_method=params.decoding_method,
res=res,
sp=sp,
subsampling_factor=params.subsampling_factor,
frame_shift_ms=params.frame_shift_ms,
word_table=word_table,
)
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
return {"greedy_search": (hyps, timestamps)}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
): (hyps, timestamps)
}
else:
return {f"beam_size_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
def decode_dataset(
@ -360,7 +373,7 @@ def decode_dataset(
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
) ->Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
"""Decode dataset.
Args:
@ -378,9 +391,12 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
Its value is a list of tuples. Each tuple contains five elements:
- cut_id
- reference transcript
- predicted result
- timestamp of reference transcript
- timestamp of predicted result
"""
num_cuts = 0
@ -390,14 +406,26 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
timestamps_ref = []
for cut in batch["supervisions"]["cut"]:
for s in cut.supervisions:
time = []
if s.alignment is not None and "word" in s.alignment:
time = [
aliword.start
for aliword in s.alignment["word"]
if aliword.symbol != ""
]
timestamps_ref.append(time)
hyps_dict = decode_one_batch(
params=params,
@ -407,12 +435,16 @@ def decode_dataset(
batch=batch,
)
for name, hyps in hyps_dict.items():
for name, (hyps, timestamps_hyp) in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
timestamps_ref
)
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))
results[name].extend(this_batch)
@ -428,15 +460,19 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
results_dict: Dict[
str,
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
],
):
test_set_wers = dict()
test_set_delays = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts_and_timestamps(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -445,10 +481,11 @@ def save_results(
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
test_set_delays[key] = (mean_delay, var_delay)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
@ -461,6 +498,19 @@ def save_results(
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
delays_info = (
params.res_dir
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(delays_info, "w") as f:
print("settings\tsymbol-delay", file=f)
for key, val in test_set_delays:
print(
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
file=f,
)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
@ -468,6 +518,13 @@ def save_results(
note = ""
logging.info(s)
s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_delays:
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
@ -517,7 +574,7 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
@ -586,9 +643,9 @@ def main():
)
)
else:
assert params.avg > 0
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(