mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Associate a cut with token alignment (without repeats) (#125)
* WIP: Associate a cut with token alignment (without repeats) * Save framewise alignments with/without repeats. * Minor fixes.
This commit is contained in:
parent
243fb9723c
commit
ec591698b0
@ -15,15 +15,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./conformer_ctc/ali.py \
|
||||
--exp-dir ./conformer_ctc/exp \
|
||||
--lang-dir ./data/lang_bpe_500 \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--max-duration 300 \
|
||||
--dataset train-clean-100 \
|
||||
--out-dir data/ali
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from lhotse import CutSet
|
||||
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
@ -34,7 +48,6 @@ from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
get_alignments,
|
||||
save_alignments,
|
||||
setup_logger,
|
||||
)
|
||||
|
||||
@ -75,10 +88,42 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ali-dir",
|
||||
"--out-dir",
|
||||
type=str,
|
||||
default="data/ali_500",
|
||||
help="The experiment dir",
|
||||
required=True,
|
||||
help="""Output directory.
|
||||
It contains 3 generated files:
|
||||
|
||||
- labels_xxx.h5
|
||||
- aux_labels_xxx.h5
|
||||
- cuts_xxx.json.gz
|
||||
|
||||
where xxx is the value of `--dataset`. For instance, if
|
||||
`--dataset` is `train-clean-100`, it will contain 3 files:
|
||||
|
||||
- `labels_train-clean-100.h5`
|
||||
- `aux_labels_train-clean-100.h5`
|
||||
- `cuts_train-clean-100.json.gz`
|
||||
|
||||
Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
|
||||
alignment. The difference is that labels_xxx.h5 contains repeats.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="""The name of the dataset to compute alignments for.
|
||||
Possible values are:
|
||||
- test-clean.
|
||||
- test-other
|
||||
- train-clean-100
|
||||
- train-clean-360
|
||||
- train-other-500
|
||||
- dev-clean
|
||||
- dev-other
|
||||
""",
|
||||
)
|
||||
return parser
|
||||
|
||||
@ -91,7 +136,9 @@ def get_params() -> AttributeDict:
|
||||
"nhead": 8,
|
||||
"attention_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"num_decoder_layers": 6,
|
||||
# Set it to 0 since attention decoder
|
||||
# is not used for computing alignments
|
||||
"num_decoder_layers": 0,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
"output_beam": 10,
|
||||
@ -105,9 +152,11 @@ def get_params() -> AttributeDict:
|
||||
def compute_alignments(
|
||||
model: torch.nn.Module,
|
||||
dl: torch.utils.data.DataLoader,
|
||||
labels_writer: FeaturesWriter,
|
||||
aux_labels_writer: FeaturesWriter,
|
||||
params: AttributeDict,
|
||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||
) -> List[Tuple[str, List[int]]]:
|
||||
) -> CutSet:
|
||||
"""Compute the framewise alignments of a dataset.
|
||||
|
||||
Args:
|
||||
@ -120,9 +169,10 @@ def compute_alignments(
|
||||
graph_compiler:
|
||||
It converts token IDs to decoding graphs.
|
||||
Returns:
|
||||
Return a list of tuples. Each tuple contains two entries:
|
||||
- Utterance ID
|
||||
- Framewise alignments (token IDs) after subsampling
|
||||
Return a CutSet. Each cut has two custom fields: labels_alignment
|
||||
and aux_labels_alignment, containing framewise alignments information.
|
||||
Both are of type `lhotse.array.TemporalArray`. The difference between
|
||||
the two alignments is that `labels_alignment` contain repeats.
|
||||
"""
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
@ -131,7 +181,7 @@ def compute_alignments(
|
||||
num_cuts = 0
|
||||
|
||||
device = graph_compiler.device
|
||||
ans = []
|
||||
cuts = []
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
feature = batch["inputs"]
|
||||
|
||||
@ -140,11 +190,10 @@ def compute_alignments(
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
cut_list = supervisions["cut"]
|
||||
|
||||
cut_ids = []
|
||||
for cut in supervisions["cut"]:
|
||||
assert len(cut.supervisions) == 1
|
||||
cut_ids.append(cut.id)
|
||||
for cut in cut_list:
|
||||
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
|
||||
|
||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
||||
# nnet_output is [N, T, C]
|
||||
@ -156,7 +205,8 @@ def compute_alignments(
|
||||
# In general, new2old is an identity map since lhotse sorts the returned
|
||||
# cuts by duration in descending order
|
||||
new2old = supervision_segments[:, 0].tolist()
|
||||
cut_ids = [cut_ids[i] for i in new2old]
|
||||
|
||||
cut_list = [cut_list[i] for i in new2old]
|
||||
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
decoding_graph = graph_compiler.compile(token_ids)
|
||||
@ -178,11 +228,32 @@ def compute_alignments(
|
||||
use_double_scores=params.use_double_scores,
|
||||
)
|
||||
|
||||
ali_ids = get_alignments(best_path)
|
||||
assert len(ali_ids) == len(cut_ids)
|
||||
ans += list(zip(cut_ids, ali_ids))
|
||||
labels_ali = get_alignments(best_path, kind="labels")
|
||||
aux_labels_ali = get_alignments(best_path, kind="aux_labels")
|
||||
assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
|
||||
for cut, labels, aux_labels in zip(
|
||||
cut_list, labels_ali, aux_labels_ali
|
||||
):
|
||||
cut.labels_alignment = labels_writer.store_array(
|
||||
key=cut.id,
|
||||
value=np.asarray(labels, dtype=np.int32),
|
||||
# frame shift is 0.01s, subsampling_factor is 4
|
||||
frame_shift=0.04,
|
||||
temporal_dim=0,
|
||||
start=0,
|
||||
)
|
||||
cut.aux_labels_alignment = aux_labels_writer.store_array(
|
||||
key=cut.id,
|
||||
value=np.asarray(aux_labels, dtype=np.int32),
|
||||
# frame shift is 0.01s, subsampling_factor is 4
|
||||
frame_shift=0.04,
|
||||
temporal_dim=0,
|
||||
start=0,
|
||||
)
|
||||
|
||||
num_cuts += len(ali_ids)
|
||||
cuts += cut_list
|
||||
|
||||
num_cuts += len(cut_list)
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
@ -191,7 +262,7 @@ def compute_alignments(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
|
||||
return ans
|
||||
return CutSet.from_cuts(cuts)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -200,20 +271,35 @@ def main():
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.return_cuts is True
|
||||
assert args.concatenate_cuts is False
|
||||
if args.full_libri is False:
|
||||
print("Changing --full-libri to True")
|
||||
args.full_libri = True
|
||||
args.enable_spec_aug = False
|
||||
args.enable_musan = False
|
||||
args.return_cuts = True
|
||||
args.concatenate_cuts = False
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/ali")
|
||||
setup_logger(f"{params.exp_dir}/log-ali")
|
||||
|
||||
logging.info("Computing alignment - started")
|
||||
logging.info(f"Computing alignments for {params.dataset} - started")
|
||||
logging.info(params)
|
||||
|
||||
out_dir = Path(params.out_dir)
|
||||
out_dir.mkdir(exist_ok=True)
|
||||
|
||||
out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
|
||||
out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
|
||||
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
|
||||
|
||||
for f in (
|
||||
out_labels_ali_filename,
|
||||
out_aux_labels_ali_filename,
|
||||
out_manifest_filename,
|
||||
):
|
||||
if f.exists():
|
||||
logging.info(f"{f} exists - skipping")
|
||||
return
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
@ -221,6 +307,7 @@ def main():
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
@ -240,9 +327,12 @@ def main():
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
load_checkpoint(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
|
||||
)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
@ -250,60 +340,55 @@ def main():
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
if params.dataset == "test-clean":
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
elif params.dataset == "test-other":
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
elif params.dataset == "train-clean-100":
|
||||
train_clean_100_cuts = librispeech.train_clean_100_cuts()
|
||||
dl = librispeech.train_dataloaders(train_clean_100_cuts)
|
||||
elif params.dataset == "train-clean-360":
|
||||
train_clean_360_cuts = librispeech.train_clean_360_cuts()
|
||||
dl = librispeech.train_dataloaders(train_clean_360_cuts)
|
||||
elif params.dataset == "train-other-500":
|
||||
train_other_500_cuts = librispeech.train_other_500_cuts()
|
||||
dl = librispeech.train_dataloaders(train_other_500_cuts)
|
||||
elif params.dataset == "dev-clean":
|
||||
dev_clean_cuts = librispeech.dev_clean_cuts()
|
||||
dl = librispeech.valid_dataloaders(dev_clean_cuts)
|
||||
else:
|
||||
assert params.dataset == "dev-other", f"{params.dataset}"
|
||||
dev_other_cuts = librispeech.dev_other_cuts()
|
||||
dl = librispeech.valid_dataloaders(dev_other_cuts)
|
||||
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
test_dl = librispeech.test_dataloaders() # a list
|
||||
|
||||
ali_dir = Path(params.ali_dir)
|
||||
ali_dir.mkdir(exist_ok=True)
|
||||
|
||||
enabled_datasets = {
|
||||
"test_clean": test_dl[0],
|
||||
"test_other": test_dl[1],
|
||||
"train-960": train_dl,
|
||||
"valid": valid_dl,
|
||||
}
|
||||
# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
|
||||
# compute the alignments if you use --max-duration=500
|
||||
#
|
||||
# There are 960 * 3 = 2880 hours data and it takes only
|
||||
# 3 hours 40 minutes to get the alignment.
|
||||
# The RTF is roughly: 3.67 / 2880 = 0.0012743
|
||||
#
|
||||
# At the end, you would see
|
||||
# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
|
||||
# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
|
||||
for name, dl in enabled_datasets.items():
|
||||
logging.info(f"Processing {name}")
|
||||
if name == "train-960":
|
||||
logging.info(
|
||||
f"It will take about 3 hours 40 minutes for {name}, "
|
||||
"which contains 960 * 3 = 2880 hours of data"
|
||||
)
|
||||
alignments = compute_alignments(
|
||||
logging.info(f"Processing {params.dataset}")
|
||||
with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
|
||||
with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
|
||||
cut_set = compute_alignments(
|
||||
model=model,
|
||||
dl=dl,
|
||||
labels_writer=labels_writer,
|
||||
aux_labels_writer=aux_labels_writer,
|
||||
params=params,
|
||||
graph_compiler=graph_compiler,
|
||||
)
|
||||
num_utt = len(alignments)
|
||||
alignments = dict(alignments)
|
||||
assert num_utt == len(alignments)
|
||||
filename = ali_dir / f"{name}.pt"
|
||||
save_alignments(
|
||||
alignments=alignments,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
cut_set.to_file(out_manifest_filename)
|
||||
|
||||
logging.info(
|
||||
f"For dataset {name}, its alignments are saved to {filename}"
|
||||
f"For dataset {params.dataset}, its alignments with repeats are "
|
||||
f"saved to {out_labels_ali_filename}, the alignments without repeats "
|
||||
f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
|
||||
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
|
||||
)
|
||||
|
||||
|
||||
|
@ -665,14 +665,17 @@ def main():
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
# CAUTION: `test_sets` is for displaying only.
|
||||
# If you want to skip test-clean, you have to skip
|
||||
# it inside the for loop. That is, use
|
||||
#
|
||||
# if test_set == 'test-clean': continue
|
||||
#
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
|
@ -618,8 +618,16 @@ def run(rank, world_size, args):
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
|
@ -19,7 +19,6 @@ import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
@ -34,11 +33,10 @@ from lhotse.dataset import (
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule(DataModule):
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
@ -56,9 +54,11 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
super().add_arguments(parser)
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
@ -74,7 +74,7 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
"Otherwise, use 100h subset.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--feature-dir",
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
@ -154,17 +154,48 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
def train_dataloaders(self) -> DataLoader:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = self.train_cuts()
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "cuts_musan.json.gz"
|
||||
)
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
transforms = [
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
]
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
@ -179,15 +210,25 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = [
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=2,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
@ -243,10 +284,7 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self) -> DataLoader:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = self.valid_cuts()
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
@ -285,25 +323,16 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
|
||||
cuts = self.test_cuts()
|
||||
is_list = isinstance(cuts, list)
|
||||
test_loaders = []
|
||||
if not is_list:
|
||||
cuts = [cuts]
|
||||
|
||||
for cuts_test in cuts:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
)
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = BucketingSampler(
|
||||
cuts_test, max_duration=self.args.max_duration, shuffle=False
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
@ -312,48 +341,45 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
test_loaders.append(test_dl)
|
||||
|
||||
if is_list:
|
||||
return test_loaders
|
||||
else:
|
||||
return test_loaders[0]
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = load_manifest(
|
||||
self.args.feature_dir / "cuts_train-clean-100.json.gz"
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest(
|
||||
self.args.manifest_dir / "cuts_train-clean-100.json.gz"
|
||||
)
|
||||
if self.args.full_libri:
|
||||
cuts_train = (
|
||||
cuts_train
|
||||
+ load_manifest(
|
||||
self.args.feature_dir / "cuts_train-clean-360.json.gz"
|
||||
)
|
||||
+ load_manifest(
|
||||
self.args.feature_dir / "cuts_train-other-500.json.gz"
|
||||
)
|
||||
)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest(
|
||||
self.args.feature_dir / "cuts_dev-clean.json.gz"
|
||||
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
|
||||
return cuts_valid
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest(
|
||||
self.args.manifest_dir / "cuts_train-clean-360.json.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> List[CutSet]:
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
cuts = []
|
||||
for test_set in test_sets:
|
||||
logging.debug("About to get test cuts")
|
||||
cuts.append(
|
||||
load_manifest(
|
||||
self.args.feature_dir / f"cuts_{test_set}.json.gz"
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest(
|
||||
self.args.manifest_dir / "cuts_train-other-500.json.gz"
|
||||
)
|
||||
)
|
||||
return cuts
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")
|
||||
|
@ -474,14 +474,17 @@ def main():
|
||||
model.eval()
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
# CAUTION: `test_sets` is for displaying only.
|
||||
# If you want to skip test-clean, you have to skip
|
||||
# it inside the for loop. That is, use
|
||||
#
|
||||
# if test_set == 'test-clean': continue
|
||||
#
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
|
@ -532,8 +532,16 @@ def run(rank, world_size, args):
|
||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
@ -85,6 +85,7 @@ def load_checkpoint(
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[_LRScheduler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
strict: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
TODO: document it
|
||||
@ -101,9 +102,9 @@ def load_checkpoint(
|
||||
src_key = "{}.{}".format("module", key)
|
||||
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||
assert len(src_state_dict) == 0
|
||||
model.load_state_dict(dst_state_dict, strict=False)
|
||||
model.load_state_dict(dst_state_dict, strict=strict)
|
||||
else:
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=strict)
|
||||
|
||||
checkpoint.pop("model")
|
||||
|
||||
|
@ -224,8 +224,8 @@ def get_texts(
|
||||
return aux_labels.tolist()
|
||||
|
||||
|
||||
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
|
||||
"""Extract the token IDs (from best_paths.labels) from the best-path FSAs.
|
||||
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
|
||||
"""Extract labels or aux_labels from the best-path FSAs.
|
||||
|
||||
Args:
|
||||
best_paths:
|
||||
@ -233,17 +233,34 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
|
||||
containing multiple FSAs, which is expected to be the result
|
||||
of k2.shortest_path (otherwise the returned values won't
|
||||
be meaningful).
|
||||
kind:
|
||||
Possible values are: "labels" and "aux_labels". Caution: When it is
|
||||
"labels", the resulting alignments contain repeats.
|
||||
Returns:
|
||||
Returns a list of lists of int, containing the token sequences we
|
||||
decoded. For `ans[i]`, its length equals to the number of frames
|
||||
after subsampling of the i-th utterance in the batch.
|
||||
|
||||
Example:
|
||||
When `kind` is `labels`, one possible alignment example is (with
|
||||
repeats)::
|
||||
|
||||
c c c blk a a blk blk t t t blk blk
|
||||
|
||||
If `kind` is `aux_labels`, the above example changes to::
|
||||
|
||||
c blk blk blk a blk blk blk t blk blk blk blk
|
||||
|
||||
"""
|
||||
assert kind in ("labels", "aux_labels")
|
||||
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
|
||||
label_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
# label_shape has axes [fsa][arc]
|
||||
labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
|
||||
labels = labels.remove_values_eq(-1)
|
||||
return labels.tolist()
|
||||
token_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
# token_shape has axes [fsa][arc]
|
||||
tokens = k2.RaggedTensor(
|
||||
token_shape, getattr(best_paths, kind).contiguous()
|
||||
)
|
||||
tokens = tokens.remove_values_eq(-1)
|
||||
return tokens.tolist()
|
||||
|
||||
|
||||
def save_alignments(
|
||||
|
184
test/test_ali.py
184
test/test_ali.py
@ -25,199 +25,65 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from lhotse import load_manifest
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from lhotse.dataset.collation import collate_custom_field
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
load_alignments,
|
||||
lookup_alignments,
|
||||
)
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import get_texts
|
||||
|
||||
ICEFALL_DIR = Path(__file__).resolve().parent.parent
|
||||
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
|
||||
lang_dir = egs_dir / "data/lang_bpe_500"
|
||||
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
|
||||
# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
|
||||
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
|
||||
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
|
||||
|
||||
cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
|
||||
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
|
||||
cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"
|
||||
|
||||
|
||||
def data_exists():
|
||||
return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
|
||||
return cuts_json.exists() and lang_dir.exists()
|
||||
|
||||
|
||||
def get_dataloader():
|
||||
cuts_train = load_manifest(cut_json)
|
||||
cuts_train = cuts_train.with_features_path_prefix(egs_dir)
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=40,
|
||||
cuts = load_manifest(cuts_json)
|
||||
print(cuts[0])
|
||||
cuts = cuts.with_features_path_prefix(egs_dir)
|
||||
sampler = SingleCutSampler(
|
||||
cuts,
|
||||
max_duration=10,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
train = K2SpeechRecognitionDataset(return_cuts=True)
|
||||
dataset = K2SpeechRecognitionDataset(return_cuts=True)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
dl = DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=None,
|
||||
num_workers=1,
|
||||
persistent_workers=False,
|
||||
)
|
||||
return train_dl
|
||||
|
||||
|
||||
def test_one_hot():
|
||||
a = [1, 3, 2]
|
||||
b = [1, 0, 4, 2]
|
||||
c = [torch.tensor(a), torch.tensor(b)]
|
||||
d = pad_sequence(c, batch_first=True, padding_value=0)
|
||||
f = torch.nn.functional.one_hot(d, num_classes=5)
|
||||
e = (1 - f) * -10.0
|
||||
expected = torch.tensor(
|
||||
[
|
||||
[
|
||||
[-10, 0, -10, -10, -10],
|
||||
[-10, -10, -10, 0, -10],
|
||||
[-10, -10, 0, -10, -10],
|
||||
[0, -10, -10, -10, -10],
|
||||
],
|
||||
[
|
||||
[-10, 0, -10, -10, -10],
|
||||
[0, -10, -10, -10, -10],
|
||||
[-10, -10, -10, -10, 0],
|
||||
[-10, -10, 0, -10, -10],
|
||||
],
|
||||
]
|
||||
).to(e.dtype)
|
||||
assert torch.all(torch.eq(e, expected))
|
||||
return dl
|
||||
|
||||
|
||||
def test():
|
||||
"""
|
||||
The purpose of this test is to show that we can use pre-computed
|
||||
alignments to construct a mask, adding it to a randomly generated
|
||||
nnet_output, to decode the correct transcript from the resulting
|
||||
nnet_output.
|
||||
"""
|
||||
if not data_exists():
|
||||
return
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
dl = get_dataloader()
|
||||
|
||||
subsampling_factor, ali = load_alignments(ali_filename)
|
||||
ali = convert_alignments_to_tensor(ali, device=device)
|
||||
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
word_table = lexicon.word_table
|
||||
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
|
||||
for batch in dl:
|
||||
features = batch["inputs"]
|
||||
supervisions = batch["supervisions"]
|
||||
N = features.shape[0]
|
||||
T = features.shape[1] // subsampling_factor
|
||||
nnet_output = (
|
||||
torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
|
||||
.softmax(dim=-1)
|
||||
.log()
|
||||
cuts = supervisions["cut"]
|
||||
labels_alignment, labels_alignment_length = collate_custom_field(
|
||||
CutSet.from_cuts(cuts), "labels_alignment"
|
||||
)
|
||||
cut_ids = [cut.id for cut in supervisions["cut"]]
|
||||
mask = lookup_alignments(
|
||||
cut_ids=cut_ids, alignments=ali, num_classes=num_classes
|
||||
)
|
||||
min_len = min(nnet_output.shape[1], mask.shape[1])
|
||||
ali_model_scale = 0.8
|
||||
|
||||
nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // subsampling_factor,
|
||||
supervisions["num_frames"] // subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
aux_labels_alignment,
|
||||
aux_labels_alignment_length,
|
||||
) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=20,
|
||||
output_beam=8,
|
||||
min_active_states=30,
|
||||
max_active_states=10000,
|
||||
subsampling_factor=subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
hyps = [" ".join(s) for s in hyps]
|
||||
print(hyps)
|
||||
print(supervisions["text"])
|
||||
print(labels_alignment)
|
||||
print(aux_labels_alignment)
|
||||
print(labels_alignment_length)
|
||||
print(aux_labels_alignment_length)
|
||||
break
|
||||
|
||||
|
||||
def show_cut_ids():
|
||||
# The purpose of this function is to check that
|
||||
# for each utterance in the training set, there is
|
||||
# a corresponding alignment.
|
||||
#
|
||||
# After generating a1.txt and b1.txt
|
||||
# You can use
|
||||
# wc -l a1.txt b1.txt
|
||||
# which should show the same number of lines.
|
||||
#
|
||||
# cat a1.txt | sort | uniq > a11.txt
|
||||
# cat b1.txt | sort | uniq > b11.txt
|
||||
#
|
||||
# md5sum a11.txt b11.txt
|
||||
# which should show the identical hash
|
||||
#
|
||||
# diff a11.txt b11.txt
|
||||
# should print nothing
|
||||
|
||||
subsampling_factor, ali = load_alignments(ali_filename)
|
||||
with open("a1.txt", "w") as f:
|
||||
for key in ali:
|
||||
f.write(f"{key}\n")
|
||||
|
||||
# dl = get_dataloader()
|
||||
cuts_train = (
|
||||
load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
|
||||
+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
|
||||
+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
|
||||
)
|
||||
|
||||
ans = []
|
||||
for cut in cuts_train:
|
||||
ans.append(cut.id)
|
||||
with open("b1.txt", "w") as f:
|
||||
for line in ans:
|
||||
f.write(f"{line}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
Loading…
x
Reference in New Issue
Block a user