mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
add libricss decoding
This commit is contained in:
parent
f8acb2533e
commit
601de98eb3
@ -334,69 +334,39 @@ class LibrimixAsrDataModule:
|
||||
@lru_cache()
|
||||
def train_cuts(self, reverberated: bool = False) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
if reverberated:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz")
|
||||
else:
|
||||
rvb_affix = "_rvb" if reverberated else "_norvb"
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_norvb.jsonl.gz"
|
||||
self.args.manifest_dir / f"cuts_train{rvb_affix}.jsonl.gz"
|
||||
)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
|
||||
# Trim to supervision groups
|
||||
cs = cs.trim_to_supervision_groups(max_pause=1.0)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self, reverberated: bool = False) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
if reverberated:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
|
||||
else:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_norvb.jsonl.gz")
|
||||
rvb_affix = "_rvb" if reverberated else "_norvb"
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_dev{rvb_affix}.jsonl.gz"
|
||||
)
|
||||
cs = cs.filter(lambda c: c.duration >= 0.1)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_2spk(self, reverberated: bool = False) -> CutSet:
|
||||
logging.info("About to get 2-spk train cuts")
|
||||
if reverberated:
|
||||
rvb_affix = "_rvb" if reverberated else "_norvb"
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_2spk_reverb.jsonl.gz"
|
||||
self.args.manifest_dir / f"cuts_train_2spk{rvb_affix}.jsonl.gz"
|
||||
)
|
||||
else:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_2spk.jsonl.gz")
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts_2spk(self, reverberated: bool = False) -> CutSet:
|
||||
logging.info("About to get 2-spk dev cuts")
|
||||
if reverberated:
|
||||
def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
|
||||
logging.info(f"About to get LibriCSS {split} {type} cuts")
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_dev_2spk_reverb.jsonl.gz"
|
||||
self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
|
||||
)
|
||||
else:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_2spk.jsonl.gz")
|
||||
cs = cs.filter(lambda c: c.duration >= 0.1)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_1spk(self, reverberated: bool = False) -> CutSet:
|
||||
logging.info("About to get 2-spk train cuts")
|
||||
if reverberated:
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_1spk_reverb.jsonl.gz"
|
||||
)
|
||||
else:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_1spk.jsonl.gz")
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts_1spk(self, reverberated: bool = False) -> CutSet:
|
||||
logging.info("About to get 1-spk dev cuts")
|
||||
if reverberated:
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_dev_1spk_reverb.jsonl.gz"
|
||||
)
|
||||
else:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_1spk.jsonl.gz")
|
||||
cs = cs.filter(lambda c: c.duration >= 0.1)
|
||||
return cs
|
||||
|
||||
@ -94,6 +94,8 @@ from icefall.utils import (
|
||||
write_surt_error_stats,
|
||||
)
|
||||
|
||||
OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -256,6 +258,13 @@ def get_parser():
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-masks",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If true, save masks generated by unmixing module.""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -319,6 +328,22 @@ def decode_one_batch(
|
||||
h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
|
||||
|
||||
def _group_channels(hyps: List[str]) -> List[List[str]]:
|
||||
"""
|
||||
Currently we have a batch of size M*B, where M is the number of
|
||||
channels and B is the batch size. We need to group the hypotheses
|
||||
into B groups, each of which contains M hypotheses.
|
||||
|
||||
Example:
|
||||
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
|
||||
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
|
||||
"""
|
||||
assert len(hyps) == B * params.num_channels
|
||||
out_hyps = []
|
||||
for i in range(B):
|
||||
out_hyps.append(hyps[i::B])
|
||||
return out_hyps
|
||||
|
||||
hyps = []
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
@ -331,7 +356,7 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
@ -339,7 +364,7 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
@ -348,7 +373,7 @@ def decode_one_batch(
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
hyps.append(hyp)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
@ -372,10 +397,10 @@ def decode_one_batch(
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
hyps.append(sp.decode(hyp))
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
return {"greedy_search": _group_channels(hyps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
@ -386,9 +411,9 @@ def decode_one_batch(
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
return {key: _group_channels(hyps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -437,14 +462,8 @@ def decode_dataset(
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
# The dataloader returns text as a list of cuts, each of which is a list of channel
|
||||
# text. We flatten this to a list where all channels are together, i.e., it looks like
|
||||
# [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2].
|
||||
texts = [val for tup in zip(*batch["text"]) for val in tup]
|
||||
cut_ids = [cut.id for cut in batch["cuts"]]
|
||||
|
||||
# Repeat cut_ids list N times, where N is the number of channels.
|
||||
cut_ids = list(chain.from_iterable(repeat(cut_ids, params.num_channels)))
|
||||
cuts_batch = batch["cuts"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -457,14 +476,19 @@ def decode_dataset(
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts), f"{len(hyps)} vs {len(texts)}"
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
for cut_id, hyp_words in zip(cut_ids, hyps):
|
||||
# Reference is a list of supervision texts sorted by start time.
|
||||
ref_words = [
|
||||
s.text.strip()
|
||||
for s in sorted(
|
||||
cuts_batch[cut_id].supervisions, key=lambda s: s.start
|
||||
)
|
||||
]
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
num_cuts += len(cut_ids)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
@ -484,19 +508,7 @@ def save_results(
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
# Combine results by cut_id. This means that we combine different channels for
|
||||
# ref and hyp of the same cut into list. Example:
|
||||
# (cut1, ref1, hyp1), (cut1, ref2, hyp2), (cut2, ref3, hyp3) ->
|
||||
# (cut1, [ref1, ref2], [hyp1, hyp2]), (cut2, [ref3], [hyp3])
|
||||
# Also, each ref and hyp is currently a list of words. We join them into a string.
|
||||
results_grouped = []
|
||||
for cut_id, items in groupby(results, lambda x: x[0]):
|
||||
items = list(items)
|
||||
refs = [" ".join(item[1]) for item in items]
|
||||
hyps = [" ".join(item[2]) for item in items]
|
||||
results_grouped.append((cut_id, refs, hyps))
|
||||
|
||||
store_transcripts(filename=recog_path, texts=results_grouped)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -506,7 +518,7 @@ def save_results(
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_surt_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_grouped, enable_log=True
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
@ -529,6 +541,16 @@ def save_results(
|
||||
logging.info(s)
|
||||
|
||||
|
||||
def save_masks(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
masks: List[torch.Tensor],
|
||||
):
|
||||
masks_path = params.res_dir / f"masks-{test_set_name}.txt"
|
||||
torch.save(masks, masks_path)
|
||||
logging.info(f"The masks are stored in {masks_path}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
@ -703,15 +725,39 @@ def main():
|
||||
args.return_cuts = True
|
||||
librimix = LibrimixAsrDataModule(args)
|
||||
|
||||
dev_2spk_cuts = librimix.dev_cuts_2spk()
|
||||
dev_2spk_dl = librimix.test_dataloaders(dev_2spk_cuts)
|
||||
train_2spk_cuts = librimix.train_cuts_2spk(sp=None)
|
||||
train_2spk_dl = librimix.test_dataloaders(train_2spk_cuts)
|
||||
dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix").to_eager()
|
||||
dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS]
|
||||
test_cuts = librimix.libricss_cuts(split="test", type="ihm-mix").to_eager()
|
||||
test_cuts_grouped = [
|
||||
test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS
|
||||
]
|
||||
|
||||
test_sets = ["dev_2spk", "train_2spk"]
|
||||
test_dl = [dev_2spk_dl, train_2spk_dl]
|
||||
for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS):
|
||||
dev_dl = librimix.test_dataloaders(dev_set)
|
||||
results_dict = decode_dataset(
|
||||
dl=dev_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=f"dev_{ol}",
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
# if params.save_masks:
|
||||
# save_masks(
|
||||
# params=params,
|
||||
# test_set_name=f"dev_{ol}",
|
||||
# masks=masks,
|
||||
# )
|
||||
|
||||
for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
|
||||
test_dl = librimix.test_dataloaders(test_set)
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
@ -723,10 +769,17 @@ def main():
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
test_set_name=f"test_{ol}",
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
# if params.save_masks:
|
||||
# save_masks(
|
||||
# params=params,
|
||||
# test_set_name=f"test_{ol}",
|
||||
# masks=masks,
|
||||
# )
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
|
||||
@ -236,7 +236,6 @@ class DPRNN(nn.Module):
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@ -276,6 +275,9 @@ class DPRNN(nn.Module):
|
||||
output = output.transpose(1, 2)
|
||||
output = self.out_embed(output)
|
||||
|
||||
# Apply ReLU to the output
|
||||
output = torch.relu(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
@ -303,7 +303,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--base-lr", type=float, default=0.04, help="The base learning rate."
|
||||
"--base-lr", type=float, default=0.004, help="The base learning rate."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -747,6 +747,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
reduction="none",
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
simple_loss_is_finite = torch.isfinite(simple_loss)
|
||||
pruned_loss_is_finite = torch.isfinite(pruned_loss)
|
||||
@ -1136,9 +1137,8 @@ def run(rank, world_size, args):
|
||||
|
||||
if checkpoints is None and params.encoder_init_ckpt is not None:
|
||||
logging.info("Initializing encoder with checkpoint")
|
||||
model.encoder.load_state_dict(
|
||||
torch.load(params.encoder_init_ckpt, map_location=device)
|
||||
)
|
||||
init_ckpt = torch.load(params.encoder_init_ckpt, map_location=device)
|
||||
model.load_state_dict(init_ckpt["model"], strict=False)
|
||||
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
@ -1176,8 +1176,8 @@ def run(rank, world_size, args):
|
||||
|
||||
train_cuts = librimix.train_cuts(reverberated=False)
|
||||
train_cuts_2spk = librimix.train_cuts_2spk(reverberated=False)
|
||||
# train_cuts_1spk = librimix.train_cuts_1spk(sp=sp)
|
||||
dev_cuts = librimix.dev_cuts(reverberated=False)
|
||||
# dev_cuts = librimix.dev_cuts(reverberated=False)
|
||||
dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix")
|
||||
|
||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||
# We only load the sampler's state dict when it loads a checkpoint
|
||||
@ -1191,7 +1191,6 @@ def run(rank, world_size, args):
|
||||
sampler_state_dict=sampler_state_dict,
|
||||
)
|
||||
train_dl_2spk = librimix.train_dataloaders(train_cuts_2spk)
|
||||
# train_dl = librimix.train_dataloaders(train_cuts_1spk)
|
||||
valid_dl = librimix.valid_dataloaders(dev_cuts)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
|
||||
@ -120,8 +120,9 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
|
||||
sim_cmd="queue.pl --mem 16G -l 'num_proc=4,h_rt=600:00:00'"
|
||||
|
||||
# gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
|
||||
# grep -v "0L" | gzip -c > data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz
|
||||
gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
|
||||
grep -v "0L" | grep -v "OV10" | grep -v "OV20" |\
|
||||
gzip -c > data/manifests/libricss-sdm_supervisions_all_v2.jsonl.gz
|
||||
|
||||
# 2-speaker anechoic
|
||||
# log "Generating 2-speaker anechoic training set"
|
||||
@ -152,7 +153,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
# data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
|
||||
# data/manifests/libri-mix_cuts_train_2spk_rvb.jsonl.gz
|
||||
|
||||
# Full training set (2,3,4 speakers) anechoic
|
||||
# Full training set (2,3 speakers) anechoic
|
||||
for part in train; do
|
||||
if [ $part == "dev" ]; then
|
||||
num_jobs=1
|
||||
@ -162,11 +163,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Generating anechoic ${part} set (full)"
|
||||
$sim_cmd exp/sim_${part}.log lhotse workflows simulate-meetings \
|
||||
--method conversational \
|
||||
--fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz \
|
||||
--num-repeats 1 \
|
||||
--num-speakers-per-meeting 2,3,4 \
|
||||
--max-duration-per-speaker 20.0 \
|
||||
--max-utterances-per-speaker 4 \
|
||||
--same-spk-pause 0.5 \
|
||||
--diff-spk-pause 0.5 \
|
||||
--diff-spk-overlap 2 \
|
||||
--prob-diff-spk-overlap 0.75 \
|
||||
--num-speakers-per-meeting 2,3 \
|
||||
--max-duration-per-speaker 15.0 \
|
||||
--max-utterances-per-speaker 3 \
|
||||
--seed 1234 \
|
||||
--num-jobs ${num_jobs} \
|
||||
data/manifests/librispeech_cuts_${part}_trimmed.jsonl.gz \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user