add libricss decoding

This commit is contained in:
Desh Raj 2023-02-17 21:33:09 -05:00
parent f8acb2533e
commit 601de98eb3
5 changed files with 134 additions and 106 deletions

View File

@ -334,69 +334,39 @@ class LibrimixAsrDataModule:
@lru_cache() @lru_cache()
def train_cuts(self, reverberated: bool = False) -> CutSet: def train_cuts(self, reverberated: bool = False) -> CutSet:
logging.info("About to get train cuts") logging.info("About to get train cuts")
if reverberated: rvb_affix = "_rvb" if reverberated else "_norvb"
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz") cs = load_manifest_lazy(
else: self.args.manifest_dir / f"cuts_train{rvb_affix}.jsonl.gz"
cs = load_manifest_lazy( )
self.args.manifest_dir / "cuts_train_norvb.jsonl.gz" # Trim to supervision groups
) cs = cs.trim_to_supervision_groups(max_pause=1.0)
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0) cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
return cs return cs
@lru_cache() @lru_cache()
def dev_cuts(self, reverberated: bool = False) -> CutSet: def dev_cuts(self, reverberated: bool = False) -> CutSet:
logging.info("About to get dev cuts") logging.info("About to get dev cuts")
if reverberated: rvb_affix = "_rvb" if reverberated else "_norvb"
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") cs = load_manifest_lazy(
else: self.args.manifest_dir / f"cuts_dev{rvb_affix}.jsonl.gz"
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_norvb.jsonl.gz") )
cs = cs.filter(lambda c: c.duration >= 0.1) cs = cs.filter(lambda c: c.duration >= 0.1)
return cs return cs
@lru_cache() @lru_cache()
def train_cuts_2spk(self, reverberated: bool = False) -> CutSet: def train_cuts_2spk(self, reverberated: bool = False) -> CutSet:
logging.info("About to get 2-spk train cuts") logging.info("About to get 2-spk train cuts")
if reverberated: rvb_affix = "_rvb" if reverberated else "_norvb"
cs = load_manifest_lazy( cs = load_manifest_lazy(
self.args.manifest_dir / "cuts_train_2spk_reverb.jsonl.gz" self.args.manifest_dir / f"cuts_train_2spk{rvb_affix}.jsonl.gz"
) )
else: cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_2spk.jsonl.gz")
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
return cs return cs
@lru_cache() @lru_cache()
def dev_cuts_2spk(self, reverberated: bool = False) -> CutSet: def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
logging.info("About to get 2-spk dev cuts") logging.info(f"About to get LibriCSS {split} {type} cuts")
if reverberated: cs = load_manifest_lazy(
cs = load_manifest_lazy( self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
self.args.manifest_dir / "cuts_dev_2spk_reverb.jsonl.gz" )
)
else:
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_2spk.jsonl.gz")
cs = cs.filter(lambda c: c.duration >= 0.1)
return cs
@lru_cache()
def train_cuts_1spk(self, reverberated: bool = False) -> CutSet:
logging.info("About to get 2-spk train cuts")
if reverberated:
cs = load_manifest_lazy(
self.args.manifest_dir / "cuts_train_1spk_reverb.jsonl.gz"
)
else:
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_1spk.jsonl.gz")
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
return cs
@lru_cache()
def dev_cuts_1spk(self, reverberated: bool = False) -> CutSet:
logging.info("About to get 1-spk dev cuts")
if reverberated:
cs = load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_1spk_reverb.jsonl.gz"
)
else:
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_1spk.jsonl.gz")
cs = cs.filter(lambda c: c.duration >= 0.1)
return cs return cs

View File

@ -94,6 +94,8 @@ from icefall.utils import (
write_surt_error_stats, write_surt_error_stats,
) )
OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -256,6 +258,13 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
) )
parser.add_argument(
"--save-masks",
type=str2bool,
default=False,
help="""If true, save masks generated by unmixing module.""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -319,6 +328,22 @@ def decode_one_batch(
h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0) h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
def _group_channels(hyps: List[str]) -> List[List[str]]:
"""
Currently we have a batch of size M*B, where M is the number of
channels and B is the batch size. We need to group the hypotheses
into B groups, each of which contains M hypotheses.
Example:
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
"""
assert len(hyps) == B * params.num_channels
out_hyps = []
for i in range(B):
out_hyps.append(hyps[i::B])
return out_hyps
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -331,7 +356,7 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp)
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
@ -339,7 +364,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp)
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -348,7 +373,7 @@ def decode_one_batch(
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp)
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
@ -372,10 +397,10 @@ def decode_one_batch(
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyps.append(sp.decode(hyp).split()) hyps.append(sp.decode(hyp))
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": _group_channels(hyps)}
elif "fast_beam_search" in params.decoding_method: elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_" key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_" key += f"max_contexts_{params.max_contexts}_"
@ -386,9 +411,9 @@ def decode_one_batch(
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}" key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps} return {key: _group_channels(hyps)}
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": _group_channels(hyps)}
def decode_dataset( def decode_dataset(
@ -437,14 +462,8 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
# The dataloader returns text as a list of cuts, each of which is a list of channel
# text. We flatten this to a list where all channels are together, i.e., it looks like
# [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2].
texts = [val for tup in zip(*batch["text"]) for val in tup]
cut_ids = [cut.id for cut in batch["cuts"]] cut_ids = [cut.id for cut in batch["cuts"]]
cuts_batch = batch["cuts"]
# Repeat cut_ids list N times, where N is the number of channels.
cut_ids = list(chain.from_iterable(repeat(cut_ids, params.num_channels)))
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
@ -457,14 +476,19 @@ def decode_dataset(
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
this_batch = [] this_batch = []
assert len(hyps) == len(texts), f"{len(hyps)} vs {len(texts)}" for cut_id, hyp_words in zip(cut_ids, hyps):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): # Reference is a list of supervision texts sorted by start time.
ref_words = ref_text.split() ref_words = [
s.text.strip()
for s in sorted(
cuts_batch[cut_id].supervisions, key=lambda s: s.start
)
]
this_batch.append((cut_id, ref_words, hyp_words)) this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch) results[name].extend(this_batch)
num_cuts += len(texts) num_cuts += len(cut_ids)
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
@ -484,19 +508,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
) )
results = sorted(results) results = sorted(results)
# Combine results by cut_id. This means that we combine different channels for store_transcripts(filename=recog_path, texts=results)
# ref and hyp of the same cut into list. Example:
# (cut1, ref1, hyp1), (cut1, ref2, hyp2), (cut2, ref3, hyp3) ->
# (cut1, [ref1, ref2], [hyp1, hyp2]), (cut2, [ref3], [hyp3])
# Also, each ref and hyp is currently a list of words. We join them into a string.
results_grouped = []
for cut_id, items in groupby(results, lambda x: x[0]):
items = list(items)
refs = [" ".join(item[1]) for item in items]
hyps = [" ".join(item[2]) for item in items]
results_grouped.append((cut_id, refs, hyps))
store_transcripts(filename=recog_path, texts=results_grouped)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -506,7 +518,7 @@ def save_results(
) )
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_surt_error_stats( wer = write_surt_error_stats(
f, f"{test_set_name}-{key}", results_grouped, enable_log=True f, f"{test_set_name}-{key}", results, enable_log=True
) )
test_set_wers[key] = wer test_set_wers[key] = wer
@ -529,6 +541,16 @@ def save_results(
logging.info(s) logging.info(s)
def save_masks(
params: AttributeDict,
test_set_name: str,
masks: List[torch.Tensor],
):
masks_path = params.res_dir / f"masks-{test_set_name}.txt"
torch.save(masks, masks_path)
logging.info(f"The masks are stored in {masks_path}")
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
@ -703,15 +725,39 @@ def main():
args.return_cuts = True args.return_cuts = True
librimix = LibrimixAsrDataModule(args) librimix = LibrimixAsrDataModule(args)
dev_2spk_cuts = librimix.dev_cuts_2spk() dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix").to_eager()
dev_2spk_dl = librimix.test_dataloaders(dev_2spk_cuts) dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS]
train_2spk_cuts = librimix.train_cuts_2spk(sp=None) test_cuts = librimix.libricss_cuts(split="test", type="ihm-mix").to_eager()
train_2spk_dl = librimix.test_dataloaders(train_2spk_cuts) test_cuts_grouped = [
test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS
]
test_sets = ["dev_2spk", "train_2spk"] for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS):
test_dl = [dev_2spk_dl, train_2spk_dl] dev_dl = librimix.test_dataloaders(dev_set)
results_dict = decode_dataset(
dl=dev_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
for test_set, test_dl in zip(test_sets, test_dl): save_results(
params=params,
test_set_name=f"dev_{ol}",
results_dict=results_dict,
)
# if params.save_masks:
# save_masks(
# params=params,
# test_set_name=f"dev_{ol}",
# masks=masks,
# )
for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
test_dl = librimix.test_dataloaders(test_set)
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=test_dl,
params=params, params=params,
@ -723,10 +769,17 @@ def main():
save_results( save_results(
params=params, params=params,
test_set_name=test_set, test_set_name=f"test_{ol}",
results_dict=results_dict, results_dict=results_dict,
) )
# if params.save_masks:
# save_masks(
# params=params,
# test_set_name=f"test_{ol}",
# masks=masks,
# )
logging.info("Done!") logging.info("Done!")

View File

@ -236,7 +236,6 @@ class DPRNN(nn.Module):
min_positive=0.45, min_positive=0.45,
max_positive=0.55, max_positive=0.55,
), ),
nn.ReLU(inplace=True),
) )
def forward(self, input): def forward(self, input):
@ -276,6 +275,9 @@ class DPRNN(nn.Module):
output = output.transpose(1, 2) output = output.transpose(1, 2)
output = self.out_embed(output) output = self.out_embed(output)
# Apply ReLU to the output
output = torch.relu(output)
return output return output

View File

@ -303,7 +303,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--base-lr", type=float, default=0.04, help="The base learning rate." "--base-lr", type=float, default=0.004, help="The base learning rate."
) )
parser.add_argument( parser.add_argument(
@ -747,6 +747,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
reduction="none", reduction="none",
subsampling_factor=params.subsampling_factor,
) )
simple_loss_is_finite = torch.isfinite(simple_loss) simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss)
@ -1136,9 +1137,8 @@ def run(rank, world_size, args):
if checkpoints is None and params.encoder_init_ckpt is not None: if checkpoints is None and params.encoder_init_ckpt is not None:
logging.info("Initializing encoder with checkpoint") logging.info("Initializing encoder with checkpoint")
model.encoder.load_state_dict( init_ckpt = torch.load(params.encoder_init_ckpt, map_location=device)
torch.load(params.encoder_init_ckpt, map_location=device) model.load_state_dict(init_ckpt["model"], strict=False)
)
if world_size > 1: if world_size > 1:
logging.info("Using DDP") logging.info("Using DDP")
@ -1176,8 +1176,8 @@ def run(rank, world_size, args):
train_cuts = librimix.train_cuts(reverberated=False) train_cuts = librimix.train_cuts(reverberated=False)
train_cuts_2spk = librimix.train_cuts_2spk(reverberated=False) train_cuts_2spk = librimix.train_cuts_2spk(reverberated=False)
# train_cuts_1spk = librimix.train_cuts_1spk(sp=sp) # dev_cuts = librimix.dev_cuts(reverberated=False)
dev_cuts = librimix.dev_cuts(reverberated=False) dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix")
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint # We only load the sampler's state dict when it loads a checkpoint
@ -1191,7 +1191,6 @@ def run(rank, world_size, args):
sampler_state_dict=sampler_state_dict, sampler_state_dict=sampler_state_dict,
) )
train_dl_2spk = librimix.train_dataloaders(train_cuts_2spk) train_dl_2spk = librimix.train_dataloaders(train_cuts_2spk)
# train_dl = librimix.train_dataloaders(train_cuts_1spk)
valid_dl = librimix.valid_dataloaders(dev_cuts) valid_dl = librimix.valid_dataloaders(dev_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)

View File

@ -120,8 +120,9 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
sim_cmd="queue.pl --mem 16G -l 'num_proc=4,h_rt=600:00:00'" sim_cmd="queue.pl --mem 16G -l 'num_proc=4,h_rt=600:00:00'"
# gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\ gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
# grep -v "0L" | gzip -c > data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz grep -v "0L" | grep -v "OV10" | grep -v "OV20" |\
gzip -c > data/manifests/libricss-sdm_supervisions_all_v2.jsonl.gz
# 2-speaker anechoic # 2-speaker anechoic
# log "Generating 2-speaker anechoic training set" # log "Generating 2-speaker anechoic training set"
@ -152,7 +153,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
# data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ # data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
# data/manifests/libri-mix_cuts_train_2spk_rvb.jsonl.gz # data/manifests/libri-mix_cuts_train_2spk_rvb.jsonl.gz
# Full training set (2,3,4 speakers) anechoic # Full training set (2,3 speakers) anechoic
for part in train; do for part in train; do
if [ $part == "dev" ]; then if [ $part == "dev" ]; then
num_jobs=1 num_jobs=1
@ -162,11 +163,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Generating anechoic ${part} set (full)" log "Generating anechoic ${part} set (full)"
$sim_cmd exp/sim_${part}.log lhotse workflows simulate-meetings \ $sim_cmd exp/sim_${part}.log lhotse workflows simulate-meetings \
--method conversational \ --method conversational \
--fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz \
--num-repeats 1 \ --num-repeats 1 \
--num-speakers-per-meeting 2,3,4 \ --same-spk-pause 0.5 \
--max-duration-per-speaker 20.0 \ --diff-spk-pause 0.5 \
--max-utterances-per-speaker 4 \ --diff-spk-overlap 2 \
--prob-diff-spk-overlap 0.75 \
--num-speakers-per-meeting 2,3 \
--max-duration-per-speaker 15.0 \
--max-utterances-per-speaker 3 \
--seed 1234 \ --seed 1234 \
--num-jobs ${num_jobs} \ --num-jobs ${num_jobs} \
data/manifests/librispeech_cuts_${part}_trimmed.jsonl.gz \ data/manifests/librispeech_cuts_${part}_trimmed.jsonl.gz \