mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
add chunk width randomization
This commit is contained in:
parent
601de98eb3
commit
e9931b7896
@ -518,7 +518,11 @@ def save_results(
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_surt_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
num_channels=params.num_channels,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
@ -749,12 +753,12 @@ def main():
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
# if params.save_masks:
|
||||
# save_masks(
|
||||
# params=params,
|
||||
# test_set_name=f"dev_{ol}",
|
||||
# masks=masks,
|
||||
# )
|
||||
# if params.save_masks:
|
||||
# save_masks(
|
||||
# params=params,
|
||||
# test_set_name=f"dev_{ol}",
|
||||
# masks=masks,
|
||||
# )
|
||||
|
||||
for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
|
||||
test_dl = librimix.test_dataloaders(test_set)
|
||||
@ -773,12 +777,12 @@ def main():
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
# if params.save_masks:
|
||||
# save_masks(
|
||||
# params=params,
|
||||
# test_set_name=f"test_{ol}",
|
||||
# masks=masks,
|
||||
# )
|
||||
# if params.save_masks:
|
||||
# save_masks(
|
||||
# params=params,
|
||||
# test_set_name=f"test_{ol}",
|
||||
# masks=masks,
|
||||
# )
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -190,6 +191,7 @@ class DPRNN(nn.Module):
|
||||
dropout=0.1,
|
||||
num_blocks=1,
|
||||
segment_size=50,
|
||||
chunk_width_randomization=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -198,6 +200,7 @@ class DPRNN(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.segment_size = segment_size
|
||||
self.chunk_width_randomization = chunk_width_randomization
|
||||
|
||||
self.input_embed = nn.Sequential(
|
||||
ScaledLinear(feature_dim, input_size),
|
||||
@ -243,7 +246,11 @@ class DPRNN(nn.Module):
|
||||
input = self.input_embed(input)
|
||||
B, T, D = input.shape
|
||||
|
||||
input, rest = split_feature(input.transpose(1, 2), self.segment_size)
|
||||
if self.chunk_width_randomization and self.training:
|
||||
segment_size = random.randint(self.segment_size // 2, self.segment_size)
|
||||
else:
|
||||
segment_size = self.segment_size
|
||||
input, rest = split_feature(input.transpose(1, 2), segment_size)
|
||||
# input shape: batch, N, dim1, dim2
|
||||
# apply RNN on dim1 first and then dim2
|
||||
# output shape: B, output_size, dim1, dim2
|
||||
@ -291,6 +298,7 @@ if __name__ == "__main__":
|
||||
dropout=0.1,
|
||||
num_blocks=3,
|
||||
segment_size=20,
|
||||
chunk_width_randomization=True,
|
||||
)
|
||||
input = torch.randn(2, 1002, 80)
|
||||
print(model(input).shape)
|
||||
|
||||
@ -104,24 +104,31 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
"--num-mask-encoder-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of layers in the SkiM based mask encoder.",
|
||||
help="Number of layers in the DPRNN based mask encoder.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-encoder-dim",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Hidden dimension of the LSTM blocks in SkiM.",
|
||||
help="Hidden dimension of the LSTM blocks in DPRNN.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-encoder-segment-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Segment size of the SegLSTM in SkiM. Ideally, this should be equal to the "
|
||||
help="Segment size of the SegLSTM in DPRNN. Ideally, this should be equal to the "
|
||||
"decode-chunk-length of the zipformer encoder.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--chunk-width-randomization",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to randomize the chunk width in DPRNN.",
|
||||
)
|
||||
|
||||
# Zipformer config is based on:
|
||||
# https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740
|
||||
parser.add_argument(
|
||||
@ -508,6 +515,7 @@ def get_mask_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
output_size=params.feature_dim * params.num_channels,
|
||||
segment_size=params.mask_encoder_segment_size,
|
||||
num_blocks=params.num_mask_encoder_layers,
|
||||
chunk_width_randomization=params.chunk_width_randomization,
|
||||
)
|
||||
return mask_encoder
|
||||
|
||||
|
||||
140
icefall/utils.py
140
icefall/utils.py
@ -827,6 +827,7 @@ def write_surt_error_stats(
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
num_channels: int = 2,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts for SURT
|
||||
multi-talker ASR systems. The difference between this and the `write_error_stats`
|
||||
@ -867,20 +868,137 @@ def write_surt_error_stats(
|
||||
"""
|
||||
from meeteval.wer import wer
|
||||
|
||||
wers = []
|
||||
assignments = []
|
||||
for cut_id, ref, hyp in results:
|
||||
orc_wer = wer.orc_word_error_rate(ref, hyp)
|
||||
wers.append(orc_wer)
|
||||
assignments.append(orc_wer.assignment)
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
ref_lens: List[int] = []
|
||||
|
||||
orc_wer = wer.combine_error_rates(*wers)
|
||||
tot_errs = orc_wer.errors
|
||||
ref_len = orc_wer.length
|
||||
tot_err_rate = orc_wer.error_rate
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
for cut_id, ref, hyp in results:
|
||||
# First compute the optimal assignment of references to output channels
|
||||
orc_wer = wer.orc_word_error_rate(ref, hyp)
|
||||
assignment = orc_wer.assignment
|
||||
refs = [[] for _ in range(num_channels)]
|
||||
# Assign references to channels
|
||||
for i, ref_text in zip(assignment, ref):
|
||||
refs[i] += ref_text.split()
|
||||
hyps = [hyp_text.split() for hyp_text in hyp]
|
||||
# Now compute the WER for each channel
|
||||
for ref_c, hyp_c in zip(refs, hyps):
|
||||
ref_lens.append(len(ref_c))
|
||||
ali = kaldialign.align(ref_c, hyp_c, ERR)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
elif hyp_word == ERR:
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
elif hyp_word != ref_word:
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word
|
||||
if ref_word == hyp_word
|
||||
else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
ref_len = sum(ref_lens)
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
if enable_log:
|
||||
logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%}")
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
ref_count = corr + ref_sub + dels
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
return float(tot_err_rate)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user