Associate a cut with token alignment (without repeats) (#125)

* WIP: Associate a cut with token alignment (without repeats)

* Save framewise alignments with/without repeats.

* Minor fixes.
This commit is contained in:
Fangjun Kuang 2021-11-29 18:50:54 +08:00 committed by GitHub
parent 243fb9723c
commit ec591698b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 374 additions and 357 deletions

View File

@ -15,15 +15,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./conformer_ctc/ali.py \
--exp-dir ./conformer_ctc/exp \
--lang-dir ./data/lang_bpe_500 \
--epoch 20 \
--avg 10 \
--max-duration 300 \
--dataset train-clean-100 \
--out-dir data/ali
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import numpy as np
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse import CutSet
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -34,7 +48,6 @@ from icefall.utils import (
AttributeDict,
encode_supervisions,
get_alignments,
save_alignments,
setup_logger,
)
@ -75,10 +88,42 @@ def get_parser():
)
parser.add_argument(
"--ali-dir",
"--out-dir",
type=str,
default="data/ali_500",
help="The experiment dir",
required=True,
help="""Output directory.
It contains 3 generated files:
- labels_xxx.h5
- aux_labels_xxx.h5
- cuts_xxx.json.gz
where xxx is the value of `--dataset`. For instance, if
`--dataset` is `train-clean-100`, it will contain 3 files:
- `labels_train-clean-100.h5`
- `aux_labels_train-clean-100.h5`
- `cuts_train-clean-100.json.gz`
Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
alignment. The difference is that labels_xxx.h5 contains repeats.
""",
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="""The name of the dataset to compute alignments for.
Possible values are:
- test-clean.
- test-other
- train-clean-100
- train-clean-360
- train-other-500
- dev-clean
- dev-other
""",
)
return parser
@ -91,7 +136,9 @@ def get_params() -> AttributeDict:
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
# Set it to 0 since attention decoder
# is not used for computing alignments
"num_decoder_layers": 0,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
@ -105,9 +152,11 @@ def get_params() -> AttributeDict:
def compute_alignments(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
labels_writer: FeaturesWriter,
aux_labels_writer: FeaturesWriter,
params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler,
) -> List[Tuple[str, List[int]]]:
) -> CutSet:
"""Compute the framewise alignments of a dataset.
Args:
@ -120,9 +169,10 @@ def compute_alignments(
graph_compiler:
It converts token IDs to decoding graphs.
Returns:
Return a list of tuples. Each tuple contains two entries:
- Utterance ID
- Framewise alignments (token IDs) after subsampling
Return a CutSet. Each cut has two custom fields: labels_alignment
and aux_labels_alignment, containing framewise alignments information.
Both are of type `lhotse.array.TemporalArray`. The difference between
the two alignments is that `labels_alignment` contain repeats.
"""
try:
num_batches = len(dl)
@ -131,7 +181,7 @@ def compute_alignments(
num_cuts = 0
device = graph_compiler.device
ans = []
cuts = []
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
@ -140,11 +190,10 @@ def compute_alignments(
feature = feature.to(device)
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
cut_ids = []
for cut in supervisions["cut"]:
assert len(cut.supervisions) == 1
cut_ids.append(cut.id)
for cut in cut_list:
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
@ -156,7 +205,8 @@ def compute_alignments(
# In general, new2old is an identity map since lhotse sorts the returned
# cuts by duration in descending order
new2old = supervision_segments[:, 0].tolist()
cut_ids = [cut_ids[i] for i in new2old]
cut_list = [cut_list[i] for i in new2old]
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
@ -178,11 +228,32 @@ def compute_alignments(
use_double_scores=params.use_double_scores,
)
ali_ids = get_alignments(best_path)
assert len(ali_ids) == len(cut_ids)
ans += list(zip(cut_ids, ali_ids))
labels_ali = get_alignments(best_path, kind="labels")
aux_labels_ali = get_alignments(best_path, kind="aux_labels")
assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
for cut, labels, aux_labels in zip(
cut_list, labels_ali, aux_labels_ali
):
cut.labels_alignment = labels_writer.store_array(
key=cut.id,
value=np.asarray(labels, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
cut.aux_labels_alignment = aux_labels_writer.store_array(
key=cut.id,
value=np.asarray(aux_labels, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
num_cuts += len(ali_ids)
cuts += cut_list
num_cuts += len(cut_list)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
@ -191,7 +262,7 @@ def compute_alignments(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return ans
return CutSet.from_cuts(cuts)
@torch.no_grad()
@ -200,20 +271,35 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.return_cuts is True
assert args.concatenate_cuts is False
if args.full_libri is False:
print("Changing --full-libri to True")
args.full_libri = True
args.enable_spec_aug = False
args.enable_musan = False
args.return_cuts = True
args.concatenate_cuts = False
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/ali")
setup_logger(f"{params.exp_dir}/log-ali")
logging.info("Computing alignment - started")
logging.info(f"Computing alignments for {params.dataset} - started")
logging.info(params)
out_dir = Path(params.out_dir)
out_dir.mkdir(exist_ok=True)
out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
for f in (
out_labels_ali_filename,
out_aux_labels_ali_filename,
out_manifest_filename,
):
if f.exists():
logging.info(f"{f} exists - skipping")
return
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
@ -221,6 +307,7 @@ def main():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
@ -240,9 +327,12 @@ def main():
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
load_checkpoint(
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
)
else:
start = params.epoch - params.avg + 1
filenames = []
@ -250,60 +340,55 @@ def main():
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
if params.dataset == "test-clean":
test_clean_cuts = librispeech.test_clean_cuts()
dl = librispeech.test_dataloaders(test_clean_cuts)
elif params.dataset == "test-other":
test_other_cuts = librispeech.test_other_cuts()
dl = librispeech.test_dataloaders(test_other_cuts)
elif params.dataset == "train-clean-100":
train_clean_100_cuts = librispeech.train_clean_100_cuts()
dl = librispeech.train_dataloaders(train_clean_100_cuts)
elif params.dataset == "train-clean-360":
train_clean_360_cuts = librispeech.train_clean_360_cuts()
dl = librispeech.train_dataloaders(train_clean_360_cuts)
elif params.dataset == "train-other-500":
train_other_500_cuts = librispeech.train_other_500_cuts()
dl = librispeech.train_dataloaders(train_other_500_cuts)
elif params.dataset == "dev-clean":
dev_clean_cuts = librispeech.dev_clean_cuts()
dl = librispeech.valid_dataloaders(dev_clean_cuts)
else:
assert params.dataset == "dev-other", f"{params.dataset}"
dev_other_cuts = librispeech.dev_other_cuts()
dl = librispeech.valid_dataloaders(dev_other_cuts)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
test_dl = librispeech.test_dataloaders() # a list
ali_dir = Path(params.ali_dir)
ali_dir.mkdir(exist_ok=True)
enabled_datasets = {
"test_clean": test_dl[0],
"test_other": test_dl[1],
"train-960": train_dl,
"valid": valid_dl,
}
# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
# compute the alignments if you use --max-duration=500
#
# There are 960 * 3 = 2880 hours data and it takes only
# 3 hours 40 minutes to get the alignment.
# The RTF is roughly: 3.67 / 2880 = 0.0012743
#
# At the end, you would see
# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
for name, dl in enabled_datasets.items():
logging.info(f"Processing {name}")
if name == "train-960":
logging.info(
f"It will take about 3 hours 40 minutes for {name}, "
"which contains 960 * 3 = 2880 hours of data"
)
alignments = compute_alignments(
logging.info(f"Processing {params.dataset}")
with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
cut_set = compute_alignments(
model=model,
dl=dl,
labels_writer=labels_writer,
aux_labels_writer=aux_labels_writer,
params=params,
graph_compiler=graph_compiler,
)
num_utt = len(alignments)
alignments = dict(alignments)
assert num_utt == len(alignments)
filename = ali_dir / f"{name}.pt"
save_alignments(
alignments=alignments,
subsampling_factor=params.subsampling_factor,
filename=filename,
)
cut_set.to_file(out_manifest_filename)
logging.info(
f"For dataset {name}, its alignments are saved to {filename}"
f"For dataset {params.dataset}, its alignments with repeats are "
f"saved to {out_labels_ali_filename}, the alignments without repeats "
f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
)

View File

@ -665,14 +665,17 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -618,8 +618,16 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,

View File

@ -19,7 +19,6 @@ import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
@ -34,11 +33,10 @@ from lhotse.dataset import (
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule):
class LibriSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
@ -56,9 +54,11 @@ class LibriSpeechAsrDataModule(DataModule):
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
@ -74,7 +74,7 @@ class LibriSpeechAsrDataModule(DataModule):
"Otherwise, use 100h subset.",
)
group.add_argument(
"--feature-dir",
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
@ -154,17 +154,48 @@ class LibriSpeechAsrDataModule(DataModule):
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.json.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
logging.info("About to create train dataset")
transforms = [
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
@ -179,15 +210,25 @@ class LibriSpeechAsrDataModule(DataModule):
)
] + transforms
input_transforms = [
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
@ -243,10 +284,7 @@ class LibriSpeechAsrDataModule(DataModule):
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@ -285,25 +323,16 @@ class LibriSpeechAsrDataModule(DataModule):
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
if not is_list:
cuts = [cuts]
for cuts_test in cuts:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
)
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration, shuffle=False
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
@ -312,48 +341,45 @@ class LibriSpeechAsrDataModule(DataModule):
sampler=sampler,
num_workers=self.args.num_workers,
)
test_loaders.append(test_dl)
if is_list:
return test_loaders
else:
return test_loaders[0]
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-clean-360.json.gz"
)
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-other-500.json.gz"
)
)
return cuts
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")

View File

@ -474,14 +474,17 @@ def main():
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -532,8 +532,16 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)

View File

@ -85,6 +85,7 @@ def load_checkpoint(
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
strict: bool = False,
) -> Dict[str, Any]:
"""
TODO: document it
@ -101,9 +102,9 @@ def load_checkpoint(
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=False)
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=False)
model.load_state_dict(checkpoint["model"], strict=strict)
checkpoint.pop("model")

View File

@ -224,8 +224,8 @@ def get_texts(
return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
"""Extract the token IDs (from best_paths.labels) from the best-path FSAs.
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
"""Extract labels or aux_labels from the best-path FSAs.
Args:
best_paths:
@ -233,17 +233,34 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
kind:
Possible values are: "labels" and "aux_labels". Caution: When it is
"labels", the resulting alignments contain repeats.
Returns:
Returns a list of lists of int, containing the token sequences we
decoded. For `ans[i]`, its length equals to the number of frames
after subsampling of the i-th utterance in the batch.
Example:
When `kind` is `labels`, one possible alignment example is (with
repeats)::
c c c blk a a blk blk t t t blk blk
If `kind` is `aux_labels`, the above example changes to::
c blk blk blk a blk blk blk t blk blk blk blk
"""
assert kind in ("labels", "aux_labels")
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
label_shape = best_paths.arcs.shape().remove_axis(1)
# label_shape has axes [fsa][arc]
labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
labels = labels.remove_values_eq(-1)
return labels.tolist()
token_shape = best_paths.arcs.shape().remove_axis(1)
# token_shape has axes [fsa][arc]
tokens = k2.RaggedTensor(
token_shape, getattr(best_paths, kind).contiguous()
)
tokens = tokens.remove_values_eq(-1)
return tokens.tolist()
def save_alignments(

View File

@ -25,199 +25,65 @@
from pathlib import Path
import k2
import torch
from lhotse import load_manifest
from lhotse import CutSet, load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from torch.nn.utils.rnn import pad_sequence
from lhotse.dataset.collation import collate_custom_field
from torch.utils.data import DataLoader
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import get_texts
ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"
def data_exists():
return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
return cuts_json.exists() and lang_dir.exists()
def get_dataloader():
cuts_train = load_manifest(cut_json)
cuts_train = cuts_train.with_features_path_prefix(egs_dir)
train_sampler = SingleCutSampler(
cuts_train,
max_duration=40,
cuts = load_manifest(cuts_json)
print(cuts[0])
cuts = cuts.with_features_path_prefix(egs_dir)
sampler = SingleCutSampler(
cuts,
max_duration=10,
shuffle=False,
)
train = K2SpeechRecognitionDataset(return_cuts=True)
dataset = K2SpeechRecognitionDataset(return_cuts=True)
train_dl = DataLoader(
train,
sampler=train_sampler,
dl = DataLoader(
dataset,
sampler=sampler,
batch_size=None,
num_workers=1,
persistent_workers=False,
)
return train_dl
def test_one_hot():
a = [1, 3, 2]
b = [1, 0, 4, 2]
c = [torch.tensor(a), torch.tensor(b)]
d = pad_sequence(c, batch_first=True, padding_value=0)
f = torch.nn.functional.one_hot(d, num_classes=5)
e = (1 - f) * -10.0
expected = torch.tensor(
[
[
[-10, 0, -10, -10, -10],
[-10, -10, -10, 0, -10],
[-10, -10, 0, -10, -10],
[0, -10, -10, -10, -10],
],
[
[-10, 0, -10, -10, -10],
[0, -10, -10, -10, -10],
[-10, -10, -10, -10, 0],
[-10, -10, 0, -10, -10],
],
]
).to(e.dtype)
assert torch.all(torch.eq(e, expected))
return dl
def test():
"""
The purpose of this test is to show that we can use pre-computed
alignments to construct a mask, adding it to a randomly generated
nnet_output, to decode the correct transcript from the resulting
nnet_output.
"""
if not data_exists():
return
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
dl = get_dataloader()
subsampling_factor, ali = load_alignments(ali_filename)
ali = convert_alignments_to_tensor(ali, device=device)
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
word_table = lexicon.word_table
HLG = k2.Fsa.from_dict(
torch.load(f"{lang_dir}/HLG.pt", map_location=device)
)
for batch in dl:
features = batch["inputs"]
supervisions = batch["supervisions"]
N = features.shape[0]
T = features.shape[1] // subsampling_factor
nnet_output = (
torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
.softmax(dim=-1)
.log()
cuts = supervisions["cut"]
labels_alignment, labels_alignment_length = collate_custom_field(
CutSet.from_cuts(cuts), "labels_alignment"
)
cut_ids = [cut.id for cut in supervisions["cut"]]
mask = lookup_alignments(
cut_ids=cut_ids, alignments=ali, num_classes=num_classes
)
min_len = min(nnet_output.shape[1], mask.shape[1])
ali_model_scale = 0.8
nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // subsampling_factor,
supervisions["num_frames"] // subsampling_factor,
),
1,
).to(torch.int32)
aux_labels_alignment,
aux_labels_alignment_length,
) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=20,
output_beam=8,
min_active_states=30,
max_active_states=10000,
subsampling_factor=subsampling_factor,
)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
hyps = [" ".join(s) for s in hyps]
print(hyps)
print(supervisions["text"])
print(labels_alignment)
print(aux_labels_alignment)
print(labels_alignment_length)
print(aux_labels_alignment_length)
break
def show_cut_ids():
# The purpose of this function is to check that
# for each utterance in the training set, there is
# a corresponding alignment.
#
# After generating a1.txt and b1.txt
# You can use
# wc -l a1.txt b1.txt
# which should show the same number of lines.
#
# cat a1.txt | sort | uniq > a11.txt
# cat b1.txt | sort | uniq > b11.txt
#
# md5sum a11.txt b11.txt
# which should show the identical hash
#
# diff a11.txt b11.txt
# should print nothing
subsampling_factor, ali = load_alignments(ali_filename)
with open("a1.txt", "w") as f:
for key in ali:
f.write(f"{key}\n")
# dl = get_dataloader()
cuts_train = (
load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
)
ans = []
for cut in cuts_train:
ans.append(cut.id)
with open("b1.txt", "w") as f:
for line in ans:
f.write(f"{line}\n")
if __name__ == "__main__":
test()