From 601de98eb3afbdd769ceeddfd8e9f70ef271ad03 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Fri, 17 Feb 2023 21:33:09 -0500 Subject: [PATCH] add libricss decoding --- .../asr_datamodule.py | 72 +++------- .../decode_libricss.py | 133 ++++++++++++------ .../dprnn.py | 4 +- .../train.py | 13 +- egs/libricss/SURT/prepare.sh | 18 ++- 5 files changed, 134 insertions(+), 106 deletions(-) diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py index e8a2cfb05..a824cda45 100644 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py @@ -334,69 +334,39 @@ class LibrimixAsrDataModule: @lru_cache() def train_cuts(self, reverberated: bool = False) -> CutSet: logging.info("About to get train cuts") - if reverberated: - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz") - else: - cs = load_manifest_lazy( - self.args.manifest_dir / "cuts_train_norvb.jsonl.gz" - ) - cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0) + rvb_affix = "_rvb" if reverberated else "_norvb" + cs = load_manifest_lazy( + self.args.manifest_dir / f"cuts_train{rvb_affix}.jsonl.gz" + ) + # Trim to supervision groups + cs = cs.trim_to_supervision_groups(max_pause=1.0) + cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0) return cs @lru_cache() def dev_cuts(self, reverberated: bool = False) -> CutSet: logging.info("About to get dev cuts") - if reverberated: - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") - else: - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_norvb.jsonl.gz") + rvb_affix = "_rvb" if reverberated else "_norvb" + cs = load_manifest_lazy( + self.args.manifest_dir / f"cuts_dev{rvb_affix}.jsonl.gz" + ) cs = cs.filter(lambda c: c.duration >= 0.1) return cs @lru_cache() def train_cuts_2spk(self, reverberated: bool = False) -> CutSet: logging.info("About to get 2-spk train cuts") - if reverberated: - cs = load_manifest_lazy( - self.args.manifest_dir / "cuts_train_2spk_reverb.jsonl.gz" - ) - else: - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_2spk.jsonl.gz") - cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0) + rvb_affix = "_rvb" if reverberated else "_norvb" + cs = load_manifest_lazy( + self.args.manifest_dir / f"cuts_train_2spk{rvb_affix}.jsonl.gz" + ) + cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0) return cs @lru_cache() - def dev_cuts_2spk(self, reverberated: bool = False) -> CutSet: - logging.info("About to get 2-spk dev cuts") - if reverberated: - cs = load_manifest_lazy( - self.args.manifest_dir / "cuts_dev_2spk_reverb.jsonl.gz" - ) - else: - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_2spk.jsonl.gz") - cs = cs.filter(lambda c: c.duration >= 0.1) - return cs - - @lru_cache() - def train_cuts_1spk(self, reverberated: bool = False) -> CutSet: - logging.info("About to get 2-spk train cuts") - if reverberated: - cs = load_manifest_lazy( - self.args.manifest_dir / "cuts_train_1spk_reverb.jsonl.gz" - ) - else: - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_1spk.jsonl.gz") - cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0) - return cs - - @lru_cache() - def dev_cuts_1spk(self, reverberated: bool = False) -> CutSet: - logging.info("About to get 1-spk dev cuts") - if reverberated: - cs = load_manifest_lazy( - self.args.manifest_dir / "cuts_dev_1spk_reverb.jsonl.gz" - ) - else: - cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_1spk.jsonl.gz") - cs = cs.filter(lambda c: c.duration >= 0.1) + def libricss_cuts(self, split="dev", type="sdm") -> CutSet: + logging.info(f"About to get LibriCSS {split} {type} cuts") + cs = load_manifest_lazy( + self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz" + ) return cs diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py index df707bbc3..d05dac0f2 100755 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py +++ b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py @@ -94,6 +94,8 @@ from icefall.utils import ( write_surt_error_stats, ) +OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"] + def get_parser(): parser = argparse.ArgumentParser( @@ -256,6 +258,13 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--save-masks", + type=str2bool, + default=False, + help="""If true, save masks generated by unmixing module.""", + ) + add_model_arguments(parser) return parser @@ -319,6 +328,22 @@ def decode_one_batch( h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0) encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) + def _group_channels(hyps: List[str]) -> List[List[str]]: + """ + Currently we have a batch of size M*B, where M is the number of + channels and B is the batch size. We need to group the hypotheses + into B groups, each of which contains M hypotheses. + + Example: + hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2'] + _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']] + """ + assert len(hyps) == B * params.num_channels + out_hyps = [] + for i in range(B): + out_hyps.append(hyps[i::B]) + return out_hyps + hyps = [] if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -331,7 +356,7 @@ def decode_one_batch( max_states=params.max_states, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(hyp) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, @@ -339,7 +364,7 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(hyp) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -348,7 +373,7 @@ def decode_one_batch( beam=params.beam_size, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(hyp) else: batch_size = encoder_out.size(0) @@ -372,10 +397,10 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + hyps.append(sp.decode(hyp)) if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search": _group_channels(hyps)} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -386,9 +411,9 @@ def decode_one_batch( if "LG" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - return {key: hyps} + return {key: _group_channels(hyps)} else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": _group_channels(hyps)} def decode_dataset( @@ -437,14 +462,8 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): - # The dataloader returns text as a list of cuts, each of which is a list of channel - # text. We flatten this to a list where all channels are together, i.e., it looks like - # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. - texts = [val for tup in zip(*batch["text"]) for val in tup] cut_ids = [cut.id for cut in batch["cuts"]] - - # Repeat cut_ids list N times, where N is the number of channels. - cut_ids = list(chain.from_iterable(repeat(cut_ids, params.num_channels))) + cuts_batch = batch["cuts"] hyps_dict = decode_one_batch( params=params, @@ -457,14 +476,19 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts), f"{len(hyps)} vs {len(texts)}" - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() + for cut_id, hyp_words in zip(cut_ids, hyps): + # Reference is a list of supervision texts sorted by start time. + ref_words = [ + s.text.strip() + for s in sorted( + cuts_batch[cut_id].supervisions, key=lambda s: s.start + ) + ] this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) - num_cuts += len(texts) + num_cuts += len(cut_ids) if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" @@ -484,19 +508,7 @@ def save_results( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - # Combine results by cut_id. This means that we combine different channels for - # ref and hyp of the same cut into list. Example: - # (cut1, ref1, hyp1), (cut1, ref2, hyp2), (cut2, ref3, hyp3) -> - # (cut1, [ref1, ref2], [hyp1, hyp2]), (cut2, [ref3], [hyp3]) - # Also, each ref and hyp is currently a list of words. We join them into a string. - results_grouped = [] - for cut_id, items in groupby(results, lambda x: x[0]): - items = list(items) - refs = [" ".join(item[1]) for item in items] - hyps = [" ".join(item[2]) for item in items] - results_grouped.append((cut_id, refs, hyps)) - - store_transcripts(filename=recog_path, texts=results_grouped) + store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -506,7 +518,7 @@ def save_results( ) with open(errs_filename, "w") as f: wer = write_surt_error_stats( - f, f"{test_set_name}-{key}", results_grouped, enable_log=True + f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer @@ -529,6 +541,16 @@ def save_results( logging.info(s) +def save_masks( + params: AttributeDict, + test_set_name: str, + masks: List[torch.Tensor], +): + masks_path = params.res_dir / f"masks-{test_set_name}.txt" + torch.save(masks, masks_path) + logging.info(f"The masks are stored in {masks_path}") + + @torch.no_grad() def main(): parser = get_parser() @@ -703,15 +725,39 @@ def main(): args.return_cuts = True librimix = LibrimixAsrDataModule(args) - dev_2spk_cuts = librimix.dev_cuts_2spk() - dev_2spk_dl = librimix.test_dataloaders(dev_2spk_cuts) - train_2spk_cuts = librimix.train_cuts_2spk(sp=None) - train_2spk_dl = librimix.test_dataloaders(train_2spk_cuts) + dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix").to_eager() + dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS] + test_cuts = librimix.libricss_cuts(split="test", type="ihm-mix").to_eager() + test_cuts_grouped = [ + test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS + ] - test_sets = ["dev_2spk", "train_2spk"] - test_dl = [dev_2spk_dl, train_2spk_dl] + for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS): + dev_dl = librimix.test_dataloaders(dev_set) + results_dict = decode_dataset( + dl=dev_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) - for test_set, test_dl in zip(test_sets, test_dl): + save_results( + params=params, + test_set_name=f"dev_{ol}", + results_dict=results_dict, + ) + + # if params.save_masks: + # save_masks( + # params=params, + # test_set_name=f"dev_{ol}", + # masks=masks, + # ) + + for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS): + test_dl = librimix.test_dataloaders(test_set) results_dict = decode_dataset( dl=test_dl, params=params, @@ -723,10 +769,17 @@ def main(): save_results( params=params, - test_set_name=test_set, + test_set_name=f"test_{ol}", results_dict=results_dict, ) + # if params.save_masks: + # save_masks( + # params=params, + # test_set_name=f"test_{ol}", + # masks=masks, + # ) + logging.info("Done!") diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py index 31180e687..eeb7cb698 100644 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py +++ b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py @@ -236,7 +236,6 @@ class DPRNN(nn.Module): min_positive=0.45, max_positive=0.55, ), - nn.ReLU(inplace=True), ) def forward(self, input): @@ -276,6 +275,9 @@ class DPRNN(nn.Module): output = output.transpose(1, 2) output = self.out_embed(output) + # Apply ReLU to the output + output = torch.relu(output) + return output diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py index 875d8875d..53710f79c 100755 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py +++ b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py @@ -303,7 +303,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.04, help="The base learning rate." + "--base-lr", type=float, default=0.004, help="The base learning rate." ) parser.add_argument( @@ -747,6 +747,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, reduction="none", + subsampling_factor=params.subsampling_factor, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) @@ -1136,9 +1137,8 @@ def run(rank, world_size, args): if checkpoints is None and params.encoder_init_ckpt is not None: logging.info("Initializing encoder with checkpoint") - model.encoder.load_state_dict( - torch.load(params.encoder_init_ckpt, map_location=device) - ) + init_ckpt = torch.load(params.encoder_init_ckpt, map_location=device) + model.load_state_dict(init_ckpt["model"], strict=False) if world_size > 1: logging.info("Using DDP") @@ -1176,8 +1176,8 @@ def run(rank, world_size, args): train_cuts = librimix.train_cuts(reverberated=False) train_cuts_2spk = librimix.train_cuts_2spk(reverberated=False) - # train_cuts_1spk = librimix.train_cuts_1spk(sp=sp) - dev_cuts = librimix.dev_cuts(reverberated=False) + # dev_cuts = librimix.dev_cuts(reverberated=False) + dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1191,7 +1191,6 @@ def run(rank, world_size, args): sampler_state_dict=sampler_state_dict, ) train_dl_2spk = librimix.train_dataloaders(train_cuts_2spk) - # train_dl = librimix.train_dataloaders(train_cuts_1spk) valid_dl = librimix.valid_dataloaders(dev_cuts) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) diff --git a/egs/libricss/SURT/prepare.sh b/egs/libricss/SURT/prepare.sh index cfa3ce056..482b18a21 100755 --- a/egs/libricss/SURT/prepare.sh +++ b/egs/libricss/SURT/prepare.sh @@ -120,8 +120,9 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then sim_cmd="queue.pl --mem 16G -l 'num_proc=4,h_rt=600:00:00'" - # gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\ - # grep -v "0L" | gzip -c > data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz + gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\ + grep -v "0L" | grep -v "OV10" | grep -v "OV20" |\ + gzip -c > data/manifests/libricss-sdm_supervisions_all_v2.jsonl.gz # 2-speaker anechoic # log "Generating 2-speaker anechoic training set" @@ -152,7 +153,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then # data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ # data/manifests/libri-mix_cuts_train_2spk_rvb.jsonl.gz - # Full training set (2,3,4 speakers) anechoic + # Full training set (2,3 speakers) anechoic for part in train; do if [ $part == "dev" ]; then num_jobs=1 @@ -162,11 +163,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Generating anechoic ${part} set (full)" $sim_cmd exp/sim_${part}.log lhotse workflows simulate-meetings \ --method conversational \ - --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz \ --num-repeats 1 \ - --num-speakers-per-meeting 2,3,4 \ - --max-duration-per-speaker 20.0 \ - --max-utterances-per-speaker 4 \ + --same-spk-pause 0.5 \ + --diff-spk-pause 0.5 \ + --diff-spk-overlap 2 \ + --prob-diff-spk-overlap 0.75 \ + --num-speakers-per-meeting 2,3 \ + --max-duration-per-speaker 15.0 \ + --max-utterances-per-speaker 3 \ --seed 1234 \ --num-jobs ${num_jobs} \ data/manifests/librispeech_cuts_${part}_trimmed.jsonl.gz \