Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
add chunk width randomization

Commit e9931b7896 (parent 601de98eb3)
@@ -518,7 +518,11 @@ def save_results(
         )
         with open(errs_filename, "w") as f:
             wer = write_surt_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                num_channels=params.num_channels,
             )
             test_set_wers[key] = wer
 
@@ -1,3 +1,4 @@
+import random
 from typing import Optional, Tuple
 
 import torch
@@ -190,6 +191,7 @@ class DPRNN(nn.Module):
         dropout=0.1,
         num_blocks=1,
         segment_size=50,
+        chunk_width_randomization=False,
     ):
         super().__init__()
 
@@ -198,6 +200,7 @@ class DPRNN(nn.Module):
         self.hidden_size = hidden_size
 
         self.segment_size = segment_size
+        self.chunk_width_randomization = chunk_width_randomization
 
         self.input_embed = nn.Sequential(
             ScaledLinear(feature_dim, input_size),
@@ -243,7 +246,11 @@ class DPRNN(nn.Module):
         input = self.input_embed(input)
         B, T, D = input.shape
 
-        input, rest = split_feature(input.transpose(1, 2), self.segment_size)
+        if self.chunk_width_randomization and self.training:
+            segment_size = random.randint(self.segment_size // 2, self.segment_size)
+        else:
+            segment_size = self.segment_size
+        input, rest = split_feature(input.transpose(1, 2), segment_size)
         # input shape: batch, N, dim1, dim2
         # apply RNN on dim1 first and then dim2
         # output shape: B, output_size, dim1, dim2
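The hunk above is the core of this commit: during training, every forward pass samples a fresh chunk (segment) width uniformly from [segment_size // 2, segment_size], so the DPRNN sees varying chunk sizes; at eval time the configured width is kept. A minimal sketch of just the sampling behavior, isolated from the model (the helper name sample_segment_size is hypothetical, not part of the patch):

    import random

    def sample_segment_size(segment_size: int, training: bool) -> int:
        # random.randint is inclusive on both ends, so training widths
        # range over segment_size // 2 .. segment_size.
        if training:
            return random.randint(segment_size // 2, segment_size)
        return segment_size

    # With segment_size=32: training may draw any width in 16..32;
    # inference always uses 32.
    print(sorted({sample_segment_size(32, training=True) for _ in range(1000)}))
    print(sample_segment_size(32, training=False))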
@@ -291,6 +298,7 @@ if __name__ == "__main__":
         dropout=0.1,
         num_blocks=3,
         segment_size=20,
+        chunk_width_randomization=True,
     )
     input = torch.randn(2, 1002, 80)
     print(model(input).shape)
@@ -104,24 +104,31 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--num-mask-encoder-layers",
         type=int,
         default=4,
-        help="Number of layers in the SkiM based mask encoder.",
+        help="Number of layers in the DPRNN based mask encoder.",
     )
 
     parser.add_argument(
         "--mask-encoder-dim",
         type=int,
         default=256,
-        help="Hidden dimension of the LSTM blocks in SkiM.",
+        help="Hidden dimension of the LSTM blocks in DPRNN.",
     )
 
     parser.add_argument(
         "--mask-encoder-segment-size",
         type=int,
         default=32,
-        help="Segment size of the SegLSTM in SkiM. Ideally, this should be equal to the "
+        help="Segment size of the SegLSTM in DPRNN. Ideally, this should be equal to the "
         "decode-chunk-length of the zipformer encoder.",
     )
 
+    parser.add_argument(
+        "--chunk-width-randomization",
+        type=bool,
+        default=False,
+        help="Whether to randomize the chunk width in DPRNN.",
+    )
+
     # Zipformer config is based on:
     # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740
     parser.add_argument(
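One caveat worth flagging on the new flag: argparse's type=bool does not parse boolean strings. bool("False") is True, so any explicit value on the command line enables the feature, and it can only stay off by omitting the flag entirely. icefall's other boolean options avoid this with its str2bool helper; a sketch of the safer definition, assuming str2bool from icefall.utils:

    import argparse

    from icefall.utils import str2bool  # accepts "true"/"false", "yes"/"no", "1"/"0"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--chunk-width-randomization",
        type=str2bool,  # type=bool would treat the string "False" as True
        default=False,
        help="Whether to randomize the chunk width in DPRNN.",
    )
    args = parser.parse_args(["--chunk-width-randomization", "false"])
    print(args.chunk_width_randomization)  # False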
@@ -508,6 +515,7 @@ def get_mask_encoder_model(params: AttributeDict) -> nn.Module:
         output_size=params.feature_dim * params.num_channels,
         segment_size=params.mask_encoder_segment_size,
         num_blocks=params.num_mask_encoder_layers,
+        chunk_width_randomization=params.chunk_width_randomization,
     )
     return mask_encoder
 
icefall/utils.py (140 changed lines)
@@ -827,6 +827,7 @@ def write_surt_error_stats(
     test_set_name: str,
     results: List[Tuple[str, str]],
     enable_log: bool = True,
+    num_channels: int = 2,
 ) -> float:
     """Write statistics based on predicted results and reference transcripts for SURT
     multi-talker ASR systems. The difference between this and the `write_error_stats`
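For readers unfamiliar with ORC WER: meeteval's orc_word_error_rate computes the "optimal reference combination", i.e. the assignment of reference utterances to output channels that minimizes the total WER. The rewritten body below keeps using it only for that assignment and then computes detailed per-channel statistics itself with kaldialign. A small sketch of the API surface this diff relies on (made-up transcripts; the attributes assignment and error_rate are the ones exercised in the code):

    from meeteval.wer import wer

    ref = ["hello there", "how are you"]  # one string per reference utterance
    hyp = ["how are you", "hello there"]  # one string per output channel
    orc = wer.orc_word_error_rate(ref, hyp)
    print(orc.error_rate)  # 0.0 -- every word matches after reassignment
    print(orc.assignment)  # e.g. (1, 0): ref 0 -> channel 1, ref 1 -> channel 0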
@@ -867,20 +868,137 @@ def write_surt_error_stats(
     """
     from meeteval.wer import wer
 
-    wers = []
-    assignments = []
-    for cut_id, ref, hyp in results:
-        orc_wer = wer.orc_word_error_rate(ref, hyp)
-        wers.append(orc_wer)
-        assignments.append(orc_wer.assignment)
-
-    orc_wer = wer.combine_error_rates(*wers)
-    tot_errs = orc_wer.errors
-    ref_len = orc_wer.length
-    tot_err_rate = orc_wer.error_rate
-
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+    ref_lens: List[int] = []
+
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+    for cut_id, ref, hyp in results:
+        # First compute the optimal assignment of references to output channels
+        orc_wer = wer.orc_word_error_rate(ref, hyp)
+        assignment = orc_wer.assignment
+        refs = [[] for _ in range(num_channels)]
+        # Assign references to channels
+        for i, ref_text in zip(assignment, ref):
+            refs[i] += ref_text.split()
+        hyps = [hyp_text.split() for hyp_text in hyp]
+        # Now compute the WER for each channel
+        for ref_c, hyp_c in zip(refs, hyps):
+            ref_lens.append(len(ref_c))
+            ali = kaldialign.align(ref_c, hyp_c, ERR)
+            for ref_word, hyp_word in ali:
+                if ref_word == ERR:
+                    ins[hyp_word] += 1
+                    words[hyp_word][3] += 1
+                elif hyp_word == ERR:
+                    dels[ref_word] += 1
+                    words[ref_word][4] += 1
+                elif hyp_word != ref_word:
+                    subs[(ref_word, hyp_word)] += 1
+                    words[ref_word][1] += 1
+                    words[hyp_word][2] += 1
+                else:
+                    words[ref_word][0] += 1
+                    num_corr += 1
+            combine_successive_errors = True
+            if combine_successive_errors:
+                ali = [[[x], [y]] for x, y in ali]
+                for i in range(len(ali) - 1):
+                    if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                        ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                        ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                        ali[i] = [[], []]
+                ali = [
+                    [
+                        list(filter(lambda a: a != ERR, x)),
+                        list(filter(lambda a: a != ERR, y)),
+                    ]
+                    for x, y in ali
+                ]
+                ali = list(filter(lambda x: x != [[], []], ali))
+                ali = [
+                    [
+                        ERR if x == [] else " ".join(x),
+                        ERR if y == [] else " ".join(y),
+                    ]
+                    for x, y in ali
+                ]
+
+            print(
+                f"{cut_id}:\t"
+                + " ".join(
+                    (
+                        ref_word
+                        if ref_word == hyp_word
+                        else f"({ref_word}->{hyp_word})"
+                        for ref_word, hyp_word in ali
+                    )
+                ),
+                file=f,
+            )
+
+    ref_len = sum(ref_lens)
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
     if enable_log:
-        logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%}")
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
 
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+
+    print("", file=f)
+    print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+        print(f"{count} {ref} -> {hyp}", file=f)
+
+    print("", file=f)
+    print("DELETIONS: count ref", file=f)
+    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+        print(f"{count} {ref}", file=f)
+
+    print("", file=f)
+    print("INSERTIONS: count hyp", file=f)
+    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+        print(f"{count} {hyp}", file=f)
+
+    print("", file=f)
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+    for _, word, counts in sorted(
+        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+    ):
+        (corr, ref_sub, hyp_sub, ins, dels) = counts
+        tot_errs = ref_sub + hyp_sub + ins + dels
+        ref_count = corr + ref_sub + dels
+        hyp_count = corr + hyp_sub + ins
+
+        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+
     print(f"%WER = {tot_err_rate}", file=f)
     return float(tot_err_rate)
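Putting the pieces together, a minimal sketch of calling the updated function with its new num_channels argument; the file name, cut IDs, and transcripts below are invented, and results follows the (cut_id, ref, hyp) convention from the loop above, with one string per reference utterance in ref and one per output channel in hyp:

    from icefall.utils import write_surt_error_stats

    # Hypothetical 2-channel SURT decoding output.
    results = [
        (
            "session1-cut0",
            ["hello there", "how are you"],        # reference utterances
            ["how are you today", "hello there"],  # per-channel hypotheses
        ),
    ]

    with open("errs-test-greedy.txt", "w") as f:
        wer = write_surt_error_stats(
            f,
            "test-greedy",
            results,
            enable_log=True,
            num_channels=2,  # must match the model's number of output channels
        )
    print(wer)  # overall WER as a float, parsed back from the formatted percentage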