add libricss decoding
Commit 601de98eb3 (parent f8acb2533e) in k2-fsa/icefall (https://github.com/k2-fsa/icefall.git)
@@ -334,69 +334,39 @@ class LibrimixAsrDataModule:
     @lru_cache()
     def train_cuts(self, reverberated: bool = False) -> CutSet:
         logging.info("About to get train cuts")
-        if reverberated:
-            cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz")
-        else:
-            cs = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_train_norvb.jsonl.gz"
-            )
-        cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
+        rvb_affix = "_rvb" if reverberated else "_norvb"
+        cs = load_manifest_lazy(
+            self.args.manifest_dir / f"cuts_train{rvb_affix}.jsonl.gz"
+        )
+        # Trim to supervision groups
+        cs = cs.trim_to_supervision_groups(max_pause=1.0)
+        cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
         return cs
 
     @lru_cache()
     def dev_cuts(self, reverberated: bool = False) -> CutSet:
         logging.info("About to get dev cuts")
-        if reverberated:
-            cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
-        else:
-            cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_norvb.jsonl.gz")
+        rvb_affix = "_rvb" if reverberated else "_norvb"
+        cs = load_manifest_lazy(
+            self.args.manifest_dir / f"cuts_dev{rvb_affix}.jsonl.gz"
+        )
         cs = cs.filter(lambda c: c.duration >= 0.1)
         return cs
 
     @lru_cache()
     def train_cuts_2spk(self, reverberated: bool = False) -> CutSet:
         logging.info("About to get 2-spk train cuts")
-        if reverberated:
-            cs = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_train_2spk_reverb.jsonl.gz"
-            )
-        else:
-            cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_2spk.jsonl.gz")
-        cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
+        rvb_affix = "_rvb" if reverberated else "_norvb"
+        cs = load_manifest_lazy(
+            self.args.manifest_dir / f"cuts_train_2spk{rvb_affix}.jsonl.gz"
+        )
+        cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
         return cs
 
     @lru_cache()
-    def dev_cuts_2spk(self, reverberated: bool = False) -> CutSet:
-        logging.info("About to get 2-spk dev cuts")
-        if reverberated:
-            cs = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_dev_2spk_reverb.jsonl.gz"
-            )
-        else:
-            cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_2spk.jsonl.gz")
-        cs = cs.filter(lambda c: c.duration >= 0.1)
-        return cs
-
-    @lru_cache()
-    def train_cuts_1spk(self, reverberated: bool = False) -> CutSet:
-        logging.info("About to get 2-spk train cuts")
-        if reverberated:
-            cs = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_train_1spk_reverb.jsonl.gz"
-            )
-        else:
-            cs = load_manifest_lazy(self.args.manifest_dir / "cuts_train_1spk.jsonl.gz")
-        cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 50.0)
-        return cs
-
-    @lru_cache()
-    def dev_cuts_1spk(self, reverberated: bool = False) -> CutSet:
-        logging.info("About to get 1-spk dev cuts")
-        if reverberated:
-            cs = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_dev_1spk_reverb.jsonl.gz"
-            )
-        else:
-            cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_1spk.jsonl.gz")
-        cs = cs.filter(lambda c: c.duration >= 0.1)
+    def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
+        logging.info(f"About to get LibriCSS {split} {type} cuts")
+        cs = load_manifest_lazy(
+            self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
+        )
         return cs
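
Note: the refactored getters collapse the `if reverberated:` branches into a single f-string on a `rvb_affix` suffix, and the new `libricss_cuts` parameterizes the LibriCSS manifest by split and mic type (lhotse's `trim_to_supervision_groups(max_pause=1.0)` merges supervisions separated by less than 1 s of pause into one group). A minimal standalone sketch of the same pattern — `MANIFEST_DIR` is a placeholder for `args.manifest_dir`, and the manifest names are taken from the hunk above:

```python
from functools import lru_cache
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy

MANIFEST_DIR = Path("data/manifests")  # placeholder; set via --manifest-dir in the recipe


@lru_cache()
def train_cuts_2spk(reverberated: bool = False) -> CutSet:
    # One f-string replaces the old if/else over two manifest names.
    rvb_affix = "_rvb" if reverberated else "_norvb"
    cs = load_manifest_lazy(MANIFEST_DIR / f"cuts_train_2spk{rvb_affix}.jsonl.gz")
    # Keep cuts between 1 s and 30 s, as in the updated getter.
    return cs.filter(lambda c: 1.0 <= c.duration <= 30.0)


@lru_cache()
def libricss_cuts(split: str = "dev", type: str = "sdm") -> CutSet:
    # split is "dev" or "test"; type names the mic condition, e.g. "sdm" or "ihm-mix".
    return load_manifest_lazy(MANIFEST_DIR / f"cuts_{split}_libricss-{type}.jsonl.gz")
```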

@@ -94,6 +94,8 @@ from icefall.utils import (
     write_surt_error_stats,
 )
 
+OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]
+
 
 def get_parser():
     parser = argparse.ArgumentParser(

@@ -256,6 +258,13 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
+    parser.add_argument(
+        "--save-masks",
+        type=str2bool,
+        default=False,
+        help="""If true, save masks generated by unmixing module.""",
+    )
+
     add_model_arguments(parser)
 
     return parser

@@ -319,6 +328,22 @@ def decode_one_batch(
     h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
     encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
 
+    def _group_channels(hyps: List[str]) -> List[List[str]]:
+        """
+        Currently we have a batch of size M*B, where M is the number of
+        channels and B is the batch size. We need to group the hypotheses
+        into B groups, each of which contains M hypotheses.
+
+        Example:
+            hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
+            _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
+        """
+        assert len(hyps) == B * params.num_channels
+        out_hyps = []
+        for i in range(B):
+            out_hyps.append(hyps[i::B])
+        return out_hyps
+
     hyps = []
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
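
Note: the flattening is channel-major, so channel m of utterance b sits at index m*B + b, and the stride slice `hyps[b::B]` collects all channels of utterance b. In the committed helper, `B` comes from the enclosing scope; a self-contained check of the same indexing (the toy hypotheses and `batch_size` parameter are invented for illustration):

```python
from typing import List


def group_channels(hyps: List[str], batch_size: int) -> List[List[str]]:
    # hyps is channel-major: [utt0_ch0, utt1_ch0, ..., utt0_ch1, utt1_ch1, ...],
    # so a stride-B slice starting at b gathers every channel of utterance b.
    assert len(hyps) % batch_size == 0
    return [hyps[b::batch_size] for b in range(batch_size)]


if __name__ == "__main__":
    hyps = ["a1", "b1", "c1", "a2", "b2", "c2"]  # 3 utterances x 2 channels
    assert group_channels(hyps, batch_size=3) == [
        ["a1", "a2"],
        ["b1", "b2"],
        ["c1", "c2"],
    ]
```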

@@ -331,7 +356,7 @@ def decode_one_batch(
             max_states=params.max_states,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(hyp)
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,

@@ -339,7 +364,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(hyp)
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,

@@ -348,7 +373,7 @@ def decode_one_batch(
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(hyp)
     else:
         batch_size = encoder_out.size(0)
 

@@ -372,10 +397,10 @@ def decode_one_batch(
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyps.append(sp.decode(hyp))
 
     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": _group_channels(hyps)}
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"

@@ -386,9 +411,9 @@ def decode_one_batch(
         if "LG" in params.decoding_method:
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
-        return {key: hyps}
+        return {key: _group_channels(hyps)}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": _group_channels(hyps)}
 
 
 def decode_dataset(

@@ -437,14 +462,8 @@ def decode_dataset(
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
-        # The dataloader returns text as a list of cuts, each of which is a list of channel
-        # text. We flatten this to a list where all channels are together, i.e., it looks like
-        # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2].
-        texts = [val for tup in zip(*batch["text"]) for val in tup]
         cut_ids = [cut.id for cut in batch["cuts"]]
-        # Repeat cut_ids list N times, where N is the number of channels.
-        cut_ids = list(chain.from_iterable(repeat(cut_ids, params.num_channels)))
+        cuts_batch = batch["cuts"]
 
         hyps_dict = decode_one_batch(
             params=params,

@@ -457,14 +476,19 @@ def decode_dataset(
 
         for name, hyps in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts), f"{len(hyps)} vs {len(texts)}"
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = ref_text.split()
+            for cut_id, hyp_words in zip(cut_ids, hyps):
+                # Reference is a list of supervision texts sorted by start time.
+                ref_words = [
+                    s.text.strip()
+                    for s in sorted(
+                        cuts_batch[cut_id].supervisions, key=lambda s: s.start
+                    )
+                ]
                 this_batch.append((cut_id, ref_words, hyp_words))
 
             results[name].extend(this_batch)
 
-        num_cuts += len(texts)
+        num_cuts += len(cut_ids)
 
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
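
Note: the reference side now comes straight from the cut's supervisions, ordered by start time, one string per supervision. A minimal sketch of that extraction — the dataclass is a stand-in for lhotse's SupervisionSegment, which is what `cuts_batch[cut_id].supervisions` actually yields:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Supervision:  # stand-in for lhotse.SupervisionSegment
    start: float
    text: str


def reference_texts(supervisions: List[Supervision]) -> List[str]:
    # One reference string per supervision, ordered by start time.
    return [s.text.strip() for s in sorted(supervisions, key=lambda s: s.start)]


sups = [Supervision(3.2, " later utterance "), Supervision(0.5, "earlier utterance")]
assert reference_texts(sups) == ["earlier utterance", "later utterance"]
```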

@@ -484,19 +508,7 @@ def save_results(
         params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
     )
     results = sorted(results)
-    # Combine results by cut_id. This means that we combine different channels for
-    # ref and hyp of the same cut into list. Example:
-    # (cut1, ref1, hyp1), (cut1, ref2, hyp2), (cut2, ref3, hyp3) ->
-    # (cut1, [ref1, ref2], [hyp1, hyp2]), (cut2, [ref3], [hyp3])
-    # Also, each ref and hyp is currently a list of words. We join them into a string.
-    results_grouped = []
-    for cut_id, items in groupby(results, lambda x: x[0]):
-        items = list(items)
-        refs = [" ".join(item[1]) for item in items]
-        hyps = [" ".join(item[2]) for item in items]
-        results_grouped.append((cut_id, refs, hyps))
-
-    store_transcripts(filename=recog_path, texts=results_grouped)
+    store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
 
     # The following prints out WERs, per-word error statistics and aligned
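
Note: the deleted block is no longer needed because, as far as the hunks above show, `decode_one_batch` already returns channel-grouped string hypotheses via `_group_channels` and the references are built per cut, so each `results` entry is already `(cut_id, refs, hyps)`. For reference, the grouping the old code performed, as a standalone snippet with toy tuples (`groupby` requires input sorted by cut id, which the preceding `sorted()` guarantees):

```python
from itertools import groupby

results = [
    ("cut1", ["ref", "one"], ["hyp", "one"]),
    ("cut1", ["ref", "two"], ["hyp", "two"]),
    ("cut2", ["ref", "three"], ["hyp", "three"]),
]  # already sorted by cut id

results_grouped = []
for cut_id, items in groupby(results, key=lambda x: x[0]):
    items = list(items)
    refs = [" ".join(item[1]) for item in items]  # join each ref word list into a string
    hyps = [" ".join(item[2]) for item in items]
    results_grouped.append((cut_id, refs, hyps))

assert results_grouped[0] == ("cut1", ["ref one", "ref two"], ["hyp one", "hyp two"])
```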

@@ -506,7 +518,7 @@ def save_results(
     )
     with open(errs_filename, "w") as f:
         wer = write_surt_error_stats(
-            f, f"{test_set_name}-{key}", results_grouped, enable_log=True
+            f, f"{test_set_name}-{key}", results, enable_log=True
         )
     test_set_wers[key] = wer
 

@@ -529,6 +541,16 @@ def save_results(
             logging.info(s)
 
 
+def save_masks(
+    params: AttributeDict,
+    test_set_name: str,
+    masks: List[torch.Tensor],
+):
+    masks_path = params.res_dir / f"masks-{test_set_name}.txt"
+    torch.save(masks, masks_path)
+    logging.info(f"The masks are stored in {masks_path}")
+
+
 @torch.no_grad()
 def main():
     parser = get_parser()
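
Note: `save_masks` serializes the whole list of mask tensors with `torch.save`; despite the `.txt` suffix, the payload is torch's binary serialization format. A round-trip sketch with made-up shapes and a file name following the `masks-{test_set_name}.txt` pattern from the hunk:

```python
import torch

masks = [torch.rand(2, 100, 80) for _ in range(3)]  # invented (channels, frames, bins)
torch.save(masks, "masks-dev_OV10.txt")  # binary torch archive; the .txt suffix is cosmetic

loaded = torch.load("masks-dev_OV10.txt")
assert len(loaded) == 3 and torch.equal(loaded[0], masks[0])
```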

@@ -703,15 +725,39 @@ def main():
     args.return_cuts = True
     librimix = LibrimixAsrDataModule(args)
 
-    dev_2spk_cuts = librimix.dev_cuts_2spk()
-    dev_2spk_dl = librimix.test_dataloaders(dev_2spk_cuts)
-    train_2spk_cuts = librimix.train_cuts_2spk(sp=None)
-    train_2spk_dl = librimix.test_dataloaders(train_2spk_cuts)
-
-    test_sets = ["dev_2spk", "train_2spk"]
-    test_dl = [dev_2spk_dl, train_2spk_dl]
-
-    for test_set, test_dl in zip(test_sets, test_dl):
+    dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix").to_eager()
+    dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS]
+    test_cuts = librimix.libricss_cuts(split="test", type="ihm-mix").to_eager()
+    test_cuts_grouped = [
+        test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS
+    ]
+
+    for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS):
+        dev_dl = librimix.test_dataloaders(dev_set)
+        results_dict = decode_dataset(
+            dl=dev_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=f"dev_{ol}",
+            results_dict=results_dict,
+        )
+
+        # if params.save_masks:
+        #     save_masks(
+        #         params=params,
+        #         test_set_name=f"dev_{ol}",
+        #         masks=masks,
+        #     )
+
+    for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
+        test_dl = librimix.test_dataloaders(test_set)
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,

@@ -723,10 +769,17 @@ def main():
 
         save_results(
             params=params,
-            test_set_name=test_set,
+            test_set_name=f"test_{ol}",
             results_dict=results_dict,
         )
 
+        # if params.save_masks:
+        #     save_masks(
+        #         params=params,
+        #         test_set_name=f"test_{ol}",
+        #         masks=masks,
+        #     )
+
     logging.info("Done!")
 
 
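
Note: decoding now runs once per overlap condition; LibriCSS cut ids embed the condition tag (0L, 0S, OV10, ...), so filtering on a substring of the id splits the set into the six groups. A toy illustration of that grouping on plain id strings (the ids here are invented; real LibriCSS ids carry the same tags):

```python
OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]

# Invented ids for illustration; real LibriCSS cut ids embed these condition tags.
cut_ids = ["sess0_0L_seg1", "sess2_OV10_seg3", "sess1_OV40_seg2"]

grouped = {ol: [cid for cid in cut_ids if ol in cid] for ol in OVERLAP_RATIOS}
assert grouped["OV10"] == ["sess2_OV10_seg3"]
assert grouped["OV30"] == []
```

One design note: lhotse's `CutSet.filter` is lazy in recent versions, so when the lambda is built in a comprehension it is worth keeping the late binding of `ol` in mind; binding it as a default argument (`lambda x, ol=ol: ol in x.id`) makes the capture explicit.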

@@ -236,7 +236,6 @@ class DPRNN(nn.Module):
                 min_positive=0.45,
                 max_positive=0.55,
             ),
-            nn.ReLU(inplace=True),
         )
 
     def forward(self, input):

@@ -276,6 +275,9 @@ class DPRNN(nn.Module):
         output = output.transpose(1, 2)
         output = self.out_embed(output)
 
+        # Apply ReLU to the output
+        output = torch.relu(output)
+
         return output
 
 
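
Note: moving the ReLU out of the `nn.Sequential` and into `forward` is behavior-preserving (`torch.relu(x)` computes the same elementwise max(x, 0) as `nn.ReLU`), but makes the activation explicit at the call site. A minimal sketch of the two equivalent forms — the toy module is illustrative, not the recipe's DPRNN:

```python
import torch
import torch.nn as nn


class WithSequentialReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.out_embed = nn.Sequential(nn.Linear(8, 4), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.out_embed(x)


class WithFunctionalReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.out_embed = nn.Linear(8, 4)

    def forward(self, x):
        # Apply ReLU to the output, as in the updated DPRNN.forward
        return torch.relu(self.out_embed(x))


a, b = WithSequentialReLU(), WithFunctionalReLU()
b.out_embed.load_state_dict(a.out_embed[0].state_dict())  # share weights for the check
x = torch.randn(2, 8)
assert torch.allclose(a(x), b(x))
```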

@@ -303,7 +303,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--base-lr", type=float, default=0.04, help="The base learning rate."
+        "--base-lr", type=float, default=0.004, help="The base learning rate."
     )
 
     parser.add_argument(

@@ -747,6 +747,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             reduction="none",
+            subsampling_factor=params.subsampling_factor,
         )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)

@@ -1136,9 +1137,8 @@ def run(rank, world_size, args):
 
     if checkpoints is None and params.encoder_init_ckpt is not None:
         logging.info("Initializing encoder with checkpoint")
-        model.encoder.load_state_dict(
-            torch.load(params.encoder_init_ckpt, map_location=device)
-        )
+        init_ckpt = torch.load(params.encoder_init_ckpt, map_location=device)
+        model.load_state_dict(init_ckpt["model"], strict=False)
 
     if world_size > 1:
         logging.info("Using DDP")
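
Note: loading the full checkpoint dict with `strict=False` lets the matching encoder weights initialize the model while keys absent from the checkpoint are simply reported rather than raising; `load_state_dict` returns the lists of missing and unexpected keys for inspection. A small sketch with toy modules (not the recipe's model):

```python
import torch
import torch.nn as nn


class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)


class Big(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 4)  # has no counterpart in the checkpoint


ckpt = {"model": Small().state_dict()}  # mimics init_ckpt["model"]
model = Big()
result = model.load_state_dict(ckpt["model"], strict=False)
# decoder.* keys are reported as missing instead of raising, as strict=True would.
assert all(k.startswith("decoder.") for k in result.missing_keys)
assert result.unexpected_keys == []
```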

@@ -1176,8 +1176,8 @@ def run(rank, world_size, args):
 
     train_cuts = librimix.train_cuts(reverberated=False)
     train_cuts_2spk = librimix.train_cuts_2spk(reverberated=False)
-    # train_cuts_1spk = librimix.train_cuts_1spk(sp=sp)
-    dev_cuts = librimix.dev_cuts(reverberated=False)
+    # dev_cuts = librimix.dev_cuts(reverberated=False)
+    dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix")
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint

@@ -1191,7 +1191,6 @@ def run(rank, world_size, args):
         sampler_state_dict=sampler_state_dict,
     )
     train_dl_2spk = librimix.train_dataloaders(train_cuts_2spk)
-    # train_dl = librimix.train_dataloaders(train_cuts_1spk)
     valid_dl = librimix.valid_dataloaders(dev_cuts)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)

@@ -120,8 +120,9 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 
   sim_cmd="queue.pl --mem 16G -l 'num_proc=4,h_rt=600:00:00'"
 
-  # gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
-  #   grep -v "0L" | gzip -c > data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz
+  gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
+    grep -v "0L" | grep -v "OV10" | grep -v "OV20" |\
+    gzip -c > data/manifests/libricss-sdm_supervisions_all_v2.jsonl.gz
 
   # 2-speaker anechoic
   # log "Generating 2-speaker anechoic training set"

@@ -152,7 +153,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   #   data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
   #   data/manifests/libri-mix_cuts_train_2spk_rvb.jsonl.gz
 
-  # Full training set (2,3,4 speakers) anechoic
+  # Full training set (2,3 speakers) anechoic
   for part in train; do
     if [ $part == "dev" ]; then
       num_jobs=1

@@ -162,11 +163,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     log "Generating anechoic ${part} set (full)"
     $sim_cmd exp/sim_${part}.log lhotse workflows simulate-meetings \
       --method conversational \
-      --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_no0L.jsonl.gz \
       --num-repeats 1 \
-      --num-speakers-per-meeting 2,3,4 \
-      --max-duration-per-speaker 20.0 \
-      --max-utterances-per-speaker 4 \
+      --same-spk-pause 0.5 \
+      --diff-spk-pause 0.5 \
+      --diff-spk-overlap 2 \
+      --prob-diff-spk-overlap 0.75 \
+      --num-speakers-per-meeting 2,3 \
+      --max-duration-per-speaker 15.0 \
+      --max-utterances-per-speaker 3 \
       --seed 1234 \
       --num-jobs ${num_jobs} \
       data/manifests/librispeech_cuts_${part}_trimmed.jsonl.gz \
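
Note: the now-uncommented pipeline drops the 0L, OV10, and OV20 conditions from the SDM supervision manifest before the meeting simulation stage. The same filtering in Python over the gzipped JSONL, in case the shell tools are unavailable (file names as in the script; `grep -v` keeps lines that do not match):

```python
import gzip

SRC = "data/manifests/libricss-sdm_supervisions_all.jsonl.gz"
DST = "data/manifests/libricss-sdm_supervisions_all_v2.jsonl.gz"
DROP = ("0L", "OV10", "OV20")  # overlap conditions excluded from the v2 manifest

with gzip.open(SRC, "rt") as fin, gzip.open(DST, "wt") as fout:
    for line in fin:
        # Mirror the chained `grep -v` calls: keep lines mentioning none of the tags.
        if not any(tag in line for tag in DROP):
            fout.write(line)
```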