add chunk width randomization

Desh Raj 2023-03-03 16:27:17 -05:00
parent 601de98eb3
commit e9931b7896
4 changed files with 166 additions and 28 deletions
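For context, the chunk width randomization added here (see the DPRNN change below) samples the dual-path segment size uniformly from [segment_size // 2, segment_size] on every training forward pass, while evaluation keeps the fixed segment_size. A minimal standalone sketch of that sampling rule; the helper name is illustrative and not part of the commit:

import random


def sample_segment_size(segment_size: int, training: bool) -> int:
    # Mirrors the logic added to DPRNN.forward(): during training, pick a
    # chunk width uniformly from [segment_size // 2, segment_size];
    # otherwise keep the configured value.
    if training:
        return random.randint(segment_size // 2, segment_size)
    return segment_size


if __name__ == "__main__":
    random.seed(0)
    # With the default --mask-encoder-segment-size of 32, training-time
    # chunk widths land anywhere in [16, 32]; eval always uses 32.
    print(sorted({sample_segment_size(32, training=True) for _ in range(1000)}))
    print(sample_segment_size(32, training=False))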

View File

@@ -518,7 +518,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
wer = write_surt_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
num_channels=params.num_channels,
)
test_set_wers[key] = wer
@@ -749,12 +753,12 @@ def main():
results_dict=results_dict,
)
# if params.save_masks:
# save_masks(
# params=params,
# test_set_name=f"dev_{ol}",
# masks=masks,
# )
# if params.save_masks:
# save_masks(
# params=params,
# test_set_name=f"dev_{ol}",
# masks=masks,
# )
for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
test_dl = librimix.test_dataloaders(test_set)
@@ -773,12 +777,12 @@ def main():
results_dict=results_dict,
)
# if params.save_masks:
# save_masks(
# params=params,
# test_set_name=f"test_{ol}",
# masks=masks,
# )
# if params.save_masks:
# save_masks(
# params=params,
# test_set_name=f"test_{ol}",
# masks=masks,
# )
logging.info("Done!")

View File

@@ -1,3 +1,4 @@
import random
from typing import Optional, Tuple
import torch
@@ -190,6 +191,7 @@ class DPRNN(nn.Module):
dropout=0.1,
num_blocks=1,
segment_size=50,
chunk_width_randomization=False,
):
super().__init__()
@@ -198,6 +200,7 @@ class DPRNN(nn.Module):
self.hidden_size = hidden_size
self.segment_size = segment_size
self.chunk_width_randomization = chunk_width_randomization
self.input_embed = nn.Sequential(
ScaledLinear(feature_dim, input_size),
@@ -243,7 +246,11 @@ class DPRNN(nn.Module):
input = self.input_embed(input)
B, T, D = input.shape
input, rest = split_feature(input.transpose(1, 2), self.segment_size)
if self.chunk_width_randomization and self.training:
segment_size = random.randint(self.segment_size // 2, self.segment_size)
else:
segment_size = self.segment_size
input, rest = split_feature(input.transpose(1, 2), segment_size)
# input shape: batch, N, dim1, dim2
# apply RNN on dim1 first and then dim2
# output shape: B, output_size, dim1, dim2
@@ -291,6 +298,7 @@ if __name__ == "__main__":
dropout=0.1,
num_blocks=3,
segment_size=20,
chunk_width_randomization=True,
)
input = torch.randn(2, 1002, 80)
print(model(input).shape)

View File

@@ -104,24 +104,31 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--num-mask-encoder-layers",
type=int,
default=4,
help="Number of layers in the SkiM based mask encoder.",
help="Number of layers in the DPRNN based mask encoder.",
)
parser.add_argument(
"--mask-encoder-dim",
type=int,
default=256,
help="Hidden dimension of the LSTM blocks in SkiM.",
help="Hidden dimension of the LSTM blocks in DPRNN.",
)
parser.add_argument(
"--mask-encoder-segment-size",
type=int,
default=32,
help="Segment size of the SegLSTM in SkiM. Ideally, this should be equal to the "
help="Segment size of the SegLSTM in DPRNN. Ideally, this should be equal to the "
"decode-chunk-length of the zipformer encoder.",
)
parser.add_argument(
"--chunk-width-randomization",
type=str2bool,
default=False,
help="Whether to randomize the chunk width in DPRNN.",
)
# Zipformer config is based on:
# https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740
parser.add_argument(
@@ -508,6 +515,7 @@ def get_mask_encoder_model(params: AttributeDict) -> nn.Module:
output_size=params.feature_dim * params.num_channels,
segment_size=params.mask_encoder_segment_size,
num_blocks=params.num_mask_encoder_layers,
chunk_width_randomization=params.chunk_width_randomization,
)
return mask_encoder

View File

@@ -827,6 +827,7 @@ def write_surt_error_stats(
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
num_channels: int = 2,
) -> float:
"""Write statistics based on predicted results and reference transcripts for SURT
multi-talker ASR systems. The difference between this and the `write_error_stats`
@@ -867,20 +868,137 @@
"""
from meeteval.wer import wer
wers = []
assignments = []
for cut_id, ref, hyp in results:
orc_wer = wer.orc_word_error_rate(ref, hyp)
wers.append(orc_wer)
assignments.append(orc_wer.assignment)
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
ref_lens: List[int] = []
orc_wer = wer.combine_error_rates(*wers)
tot_errs = orc_wer.errors
ref_len = orc_wer.length
tot_err_rate = orc_wer.error_rate
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for cut_id, ref, hyp in results:
# First compute the optimal assignment of references to output channels
orc_wer = wer.orc_word_error_rate(ref, hyp)
assignment = orc_wer.assignment
refs = [[] for _ in range(num_channels)]
# Assign references to channels
for i, ref_text in zip(assignment, ref):
refs[i] += ref_text.split()
hyps = [hyp_text.split() for hyp_text in hyp]
# Now compute the WER for each channel
for ref_c, hyp_c in zip(refs, hyps):
ref_lens.append(len(ref_c))
ali = kaldialign.align(ref_c, hyp_c, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
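# Optionally merge runs of adjacent errors into a single (ref->hyp) group
# so that each PER-UTT DETAILS line stays compact.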
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
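# Aggregate corpus-level totals over all utterances and channels.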
ref_len = sum(ref_lens)
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%}")
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
print(f"%WER = {tot_err_rate}", file=f)
return float(tot_err_rate)
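
To make the per-channel accounting above concrete, here is a small hedged example of the ORC assignment step that write_surt_error_stats relies on; the utterance texts are invented and the printed values assume the meeteval API used in this diff:

from meeteval.wer import wer

# Two reference utterances and two SURT output channels (invented example).
ref = ["good morning everyone", "please take a seat"]
hyp = ["good morning everyone", "please take a seat"]

orc_wer = wer.orc_word_error_rate(ref, hyp)
# assignment maps each reference utterance to the output channel that
# minimizes the total word error; here it should be (0, 1) with 0 errors.
# write_surt_error_stats uses this assignment to group the references per
# channel before aligning each channel against its hypothesis with
# kaldialign.align().
print(orc_wer.assignment, orc_wer.errors, orc_wer.length, orc_wer.error_rate)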