support showing symbol delay in conv emformer

marcoyang 2022-12-29 17:50:38 +08:00
parent e91fbef939
commit d0eb9b1912
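This change threads timestamps through the whole decode script: every search function now returns per-symbol emission times, reference times are read from the word alignments attached to each cut, and save_results reports the mean and variance of the resulting symbol delay for each decoding setting. As a sketch of the metric itself (illustrative code, not part of this commit; in the real script the pairing of hypothesis and reference words comes out of the WER alignment inside write_error_stats_with_timestamps):

# A sketch of the symbol-delay metric, not code from this commit.
# `matched` pairs each hypothesis word time with its reference word time.
from typing import List, Tuple


def symbol_delay_stats(matched: List[Tuple[float, float]]) -> Tuple[float, float]:
    """Return (mean, variance) of hyp_time - ref_time over matched words."""
    delays = [hyp_t - ref_t for hyp_t, ref_t in matched]
    mean = sum(delays) / len(delays)
    var = sum((d - mean) ** 2 for d in delays) / len(delays)
    return mean, var


# Hypothesis words emitted ~0.28 s after the reference on average:
print(symbol_delay_stats([(1.30, 1.02), (2.51, 2.20), (3.85, 3.60)]))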


@@ -98,10 +98,12 @@ from icefall.checkpoint import (
 )
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )

 LOG_EPS = math.log(1e-10)
@@ -237,7 +239,7 @@ def decode_one_batch(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

@@ -287,7 +289,7 @@ def decode_one_batch(
     hyps = []

     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -295,63 +297,74 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
             elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)
+
+    hyps, timestamps = parse_hyp_and_timestamp(
+        decoding_method=params.decoding_method,
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+        word_table=word_table,
+    )

     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, timestamps)}
     elif params.decoding_method == "fast_beam_search":
         return {
             (
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
-            ): hyps
+            ): (hyps, timestamps)
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}


 def decode_dataset(
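A note on the new parse_hyp_and_timestamp call above: it takes subsampling_factor and frame_shift_ms because the search functions report timestamps as encoder output frame indexes, and converting an index to seconds needs both values. A minimal sketch of that conversion, assuming the usual 10 ms frame shift and 4x subsampling (the helper below is illustrative, not icefall's implementation):

# Sketch of the frame-index -> seconds conversion; values are assumptions.
from typing import List


def frames_to_seconds(
    frame_indexes: List[int],
    subsampling_factor: int = 4,
    frame_shift_ms: float = 10.0,
) -> List[float]:
    # One encoder frame spans `subsampling_factor` input frames, each of
    # which advances the signal by `frame_shift_ms` milliseconds.
    return [i * subsampling_factor * frame_shift_ms / 1000 for i in frame_indexes]


print(frames_to_seconds([0, 25, 60]))  # [0.0, 1.0, 2.4]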
@@ -360,7 +373,7 @@ def decode_dataset(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
     """Decode dataset.

     Args:
@@ -378,9 +391,12 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Its value is a list of tuples. Each tuple contains five elements:
+      - cut_id
+      - reference transcript
+      - predicted result
+      - timestamp of reference transcript
+      - timestamp of predicted result
     """
     num_cuts = 0
@@ -390,15 +406,27 @@ def decode_dataset(
         num_batches = "?"

     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
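The timestamps_ref loop added above relies on lhotse-style word alignments: a supervision may carry s.alignment["word"], a list of entries that each hold a symbol and a start time, with empty symbols marking silence. A self-contained stand-in showing the same filter (the namedtuple only mirrors that shape for the demo; it is not lhotse's actual class):

# Self-contained illustration of the reference-timestamp extraction.
from collections import namedtuple

WordAli = namedtuple("WordAli", ["symbol", "start", "duration"])

alignment = {
    "word": [
        WordAli("hello", 0.32, 0.41),
        WordAli("", 0.73, 0.15),  # silence entry: empty symbol, filtered out
        WordAli("world", 0.88, 0.52),
    ]
}

# Same filter as in the diff: keep start times of non-empty word symbols.
time = [w.start for w in alignment["word"] if w.symbol != ""]
print(time)  # [0.32, 0.88]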
@@ -407,12 +435,16 @@ def decode_dataset(
             batch=batch,
         )

-        for name, hyps in hyps_dict.items():
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
                 ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))

             results[name].extend(this_batch)
@@ -428,15 +460,19 @@
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[
+        str,
+        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+    ],
 ):
     test_set_wers = dict()
+    test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -445,10 +481,11 @@ def save_results(
         errs_filename = (
             params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
            )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)

         logging.info("Wrote detailed error stats to {}".format(errs_filename))
@@ -461,6 +498,19 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)

+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
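Sorting test_set_delays on x[1][0] ranks the settings by mean delay, smallest first, so the lowest-latency setting leads the summary file. With made-up numbers, the rows come out as:

# Illustration of the summary ordering and format (fabricated values).
test_set_delays = {
    "beam_size_4": (0.31, 0.012),
    "greedy_search": (0.25, 0.010),
}
for key, val in sorted(test_set_delays.items(), key=lambda x: x[1][0]):
    print("{}\tmean: {}s, variance: {}".format(key, val[0], val[1]))
# greedy_search   mean: 0.25s, variance: 0.01
# beam_size_4     mean: 0.31s, variance: 0.012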
@@ -468,6 +518,13 @@ def save_results(
         note = ""
     logging.info(s)

+    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+

 @torch.no_grad()
 def main():
@@ -517,7 +574,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -586,9 +643,9 @@ def main():
                 )
             )
         else:
-            assert params.avg > 0
+            assert params.avg > 0, params.avg
             start = params.epoch - params.avg
-            assert start >= 1
+            assert start >= 1, start
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
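The strengthened assertions guard the checkpoint-averaging window: assuming the start checkpoint is excluded and the end included (the convention the surrounding logging suggests), exactly params.avg epochs are averaged, and start must still name an existing epoch file. A quick check of the arithmetic:

# Sketch of the window arithmetic; the excluded-start convention is an
# assumption, not confirmed by this diff.
epoch, avg = 30, 9
start = epoch - avg
assert avg > 0, avg
assert start >= 1, start
print(f"averaging epochs {start + 1}..{epoch}")  # 9 checkpoints: 22..30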