diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py
index d05dac0f2..d797c8f61 100755
--- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py
+++ b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py
@@ -518,7 +518,11 @@ def save_results(
     )
     with open(errs_filename, "w") as f:
         wer = write_surt_error_stats(
-            f, f"{test_set_name}-{key}", results, enable_log=True
+            f,
+            f"{test_set_name}-{key}",
+            results,
+            enable_log=True,
+            num_channels=params.num_channels,
         )
         test_set_wers[key] = wer
 
@@ -749,12 +753,12 @@ def main():
             results_dict=results_dict,
         )
 
-    # if params.save_masks:
-    #     save_masks(
-    #         params=params,
-    #         test_set_name=f"dev_{ol}",
-    #         masks=masks,
-    #     )
+        # if params.save_masks:
+        #     save_masks(
+        #         params=params,
+        #         test_set_name=f"dev_{ol}",
+        #         masks=masks,
+        #     )
 
     for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
         test_dl = librimix.test_dataloaders(test_set)
@@ -773,12 +777,12 @@ def main():
             results_dict=results_dict,
         )
 
-    # if params.save_masks:
-    #     save_masks(
-    #         params=params,
-    #         test_set_name=f"test_{ol}",
-    #         masks=masks,
-    #     )
+        # if params.save_masks:
+        #     save_masks(
+        #         params=params,
+        #         test_set_name=f"test_{ol}",
+        #         masks=masks,
+        #     )
 
     logging.info("Done!")
 
diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py
index eeb7cb698..361b1b385 100644
--- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py
+++ b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py
@@ -1,3 +1,4 @@
+import random
 from typing import Optional, Tuple
 
 import torch
@@ -190,6 +191,7 @@ class DPRNN(nn.Module):
         dropout=0.1,
         num_blocks=1,
         segment_size=50,
+        chunk_width_randomization=False,
     ):
         super().__init__()
 
@@ -198,6 +200,7 @@ class DPRNN(nn.Module):
 
         self.hidden_size = hidden_size
         self.segment_size = segment_size
+        self.chunk_width_randomization = chunk_width_randomization
 
         self.input_embed = nn.Sequential(
             ScaledLinear(feature_dim, input_size),
@@ -243,7 +246,11 @@ class DPRNN(nn.Module):
         input = self.input_embed(input)
         B, T, D = input.shape
 
-        input, rest = split_feature(input.transpose(1, 2), self.segment_size)
+        if self.chunk_width_randomization and self.training:
+            segment_size = random.randint(self.segment_size // 2, self.segment_size)
+        else:
+            segment_size = self.segment_size
+        input, rest = split_feature(input.transpose(1, 2), segment_size)
         # input shape: batch, N, dim1, dim2
         # apply RNN on dim1 first and then dim2
         # output shape: B, output_size, dim1, dim2
@@ -291,6 +298,7 @@ if __name__ == "__main__":
         dropout=0.1,
         num_blocks=3,
         segment_size=20,
+        chunk_width_randomization=True,
     )
     input = torch.randn(2, 1002, 80)
     print(model(input).shape)
diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py
index 53710f79c..670ade470 100755
--- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py
+++ b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py
@@ -104,24 +104,31 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--num-mask-encoder-layers",
         type=int,
         default=4,
-        help="Number of layers in the SkiM based mask encoder.",
+        help="Number of layers in the DPRNN-based mask encoder.",
     )
 
     parser.add_argument(
         "--mask-encoder-dim",
         type=int,
         default=256,
-        help="Hidden dimension of the LSTM blocks in SkiM.",
+        help="Hidden dimension of the LSTM blocks in DPRNN.",
     )
 
     parser.add_argument(
         "--mask-encoder-segment-size",
         type=int,
         default=32,
-        help="Segment size of the SegLSTM in SkiM. Ideally, this should be equal to the "
+        help="Segment size of the intra-chunk LSTM in DPRNN. Ideally, this should be equal to the "
         "decode-chunk-length of the zipformer encoder.",
     )
 
+    parser.add_argument(
+        "--chunk-width-randomization",
+        type=str2bool,
+        default=False,
+        help="Whether to randomize the chunk width in DPRNN.",
+    )
+
     # Zipformer config is based on:
     # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740
     parser.add_argument(
@@ -508,6 +515,7 @@ def get_mask_encoder_model(params: AttributeDict) -> nn.Module:
         output_size=params.feature_dim * params.num_channels,
         segment_size=params.mask_encoder_segment_size,
         num_blocks=params.num_mask_encoder_layers,
+        chunk_width_randomization=params.chunk_width_randomization,
     )
     return mask_encoder
 
diff --git a/icefall/utils.py b/icefall/utils.py
index 76f575377..49c590709 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -827,6 +827,7 @@ def write_surt_error_stats(
     test_set_name: str,
     results: List[Tuple[str, str]],
     enable_log: bool = True,
+    num_channels: int = 2,
 ) -> float:
     """Write statistics based on predicted results and reference transcripts for SURT
     multi-talker ASR systems. The difference between this and the `write_error_stats`
@@ -867,20 +868,137 @@
     """
 
     from meeteval.wer import wer
 
-    wers = []
-    assignments = []
-    for cut_id, ref, hyp in results:
-        orc_wer = wer.orc_word_error_rate(ref, hyp)
-        wers.append(orc_wer)
-        assignments.append(orc_wer.assignment)
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+    ref_lens: List[int] = []
 
-    orc_wer = wer.combine_error_rates(*wers)
-    tot_errs = orc_wer.errors
-    ref_len = orc_wer.length
-    tot_err_rate = orc_wer.error_rate
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+    for cut_id, ref, hyp in results:
+        # First compute the optimal assignment of references to output channels
+        orc_wer = wer.orc_word_error_rate(ref, hyp)
+        assignment = orc_wer.assignment
+        refs = [[] for _ in range(num_channels)]
+        # Assign references to channels
+        for i, ref_text in zip(assignment, ref):
+            refs[i] += ref_text.split()
+        hyps = [hyp_text.split() for hyp_text in hyp]
+        # Now compute the WER for each channel
+        for ref_c, hyp_c in zip(refs, hyps):
+            ref_lens.append(len(ref_c))
+            ali = kaldialign.align(ref_c, hyp_c, ERR)
+            for ref_word, hyp_word in ali:
+                if ref_word == ERR:
+                    ins[hyp_word] += 1
+                    words[hyp_word][3] += 1
+                elif hyp_word == ERR:
+                    dels[ref_word] += 1
+                    words[ref_word][4] += 1
+                elif hyp_word != ref_word:
+                    subs[(ref_word, hyp_word)] += 1
+                    words[ref_word][1] += 1
+                    words[hyp_word][2] += 1
+                else:
+                    words[ref_word][0] += 1
+                    num_corr += 1
+            combine_successive_errors = True
+            if combine_successive_errors:
+                ali = [[[x], [y]] for x, y in ali]
+                for i in range(len(ali) - 1):
+                    if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                        ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                        ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                        ali[i] = [[], []]
+                ali = [
+                    [
+                        list(filter(lambda a: a != ERR, x)),
+                        list(filter(lambda a: a != ERR, y)),
+                    ]
+                    for x, y in ali
+                ]
+                ali = list(filter(lambda x: x != [[], []], ali))
+                ali = [
+                    [
+                        ERR if x == [] else " ".join(x),
+                        ERR if y == [] else " ".join(y),
+                    ]
+                    for x, y in ali
+                ]
+
+            print(
+                f"{cut_id}:\t"
+                + " ".join(
+                    (
+                        ref_word
+                        if ref_word == hyp_word
+                        else f"({ref_word}->{hyp_word})"
+                        for ref_word, hyp_word in ali
+                    )
+                ),
+                file=f,
+            )
+    ref_len = sum(ref_lens)
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
 
     if enable_log:
-        logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%}")
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
+
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+
+    print("", file=f)
+    print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+        print(f"{count} {ref} -> {hyp}", file=f)
+
+    print("", file=f)
+    print("DELETIONS: count ref", file=f)
+    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+        print(f"{count} {ref}", file=f)
+
+    print("", file=f)
+    print("INSERTIONS: count hyp", file=f)
+    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+        print(f"{count} {hyp}", file=f)
+
+    print("", file=f)
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+    for _, word, counts in sorted(
+        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+    ):
+        (corr, ref_sub, hyp_sub, ins, dels) = counts
+        tot_errs = ref_sub + hyp_sub + ins + dels
+        ref_count = corr + ref_sub + dels
+        hyp_count = corr + hyp_sub + ins
+
+        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
 
     print(f"%WER = {tot_err_rate}", file=f)
     return float(tot_err_rate)
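
Reviewer note: the two sketches below illustrate the behavior this patch introduces. They are not part of the diff; the toy inputs and helper names are made up for illustration.

The rewritten write_surt_error_stats uses meeteval only for the ORC step: wer.orc_word_error_rate finds the reference-to-channel assignment that minimizes overall WER, and the detailed per-channel error breakdown is then computed with kaldialign, as in write_error_stats. A minimal sketch of the assignment step, assuming meeteval is installed and two output channels:

    from meeteval.wer import wer

    ref = ["a b", "e f g", "c d"]  # one string per reference utterance
    hyp = ["a b c d", "e f g"]  # one string per SURT output channel
    orc = wer.orc_word_error_rate(ref, hyp)
    # orc.assignment holds one channel index per reference utterance,
    # here (0, 1, 0): utterances 0 and 2 map to channel 0, utterance 1 to channel 1.
    refs = [[] for _ in range(2)]
    for ch, ref_text in zip(orc.assignment, ref):
        refs[ch] += ref_text.split()
    print(orc.error_rate, refs)

Chunk width randomization in dprnn.py only changes the segment size passed to split_feature while self.training is True. A standalone sketch of the sampling rule (the helper name and values are illustrative):

    import random

    segment_size = 32  # --mask-encoder-segment-size

    def effective_segment_size(training: bool, randomize: bool) -> int:
        # During training, sample a chunk width in [segment_size // 2, segment_size];
        # at inference, keep the configured value so that it matches the
        # decode-chunk-length of the zipformer encoder.
        if randomize and training:
            return random.randint(segment_size // 2, segment_size)
        return segment_size

    print([effective_segment_size(True, True) for _ in range(5)])  # e.g. [18, 32, 21, 29, 16]
    print(effective_segment_size(False, True))  # always 32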