WIP: Associate a cut with token alignment (without repeats)

This commit is contained in:
Fangjun Kuang 2021-11-18 17:35:08 +08:00
parent 30c43b7f69
commit 62ada37d4e
9 changed files with 315 additions and 332 deletions

View File

@ -15,15 +15,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./conformer_ctc/ali.py \
--exp-dir ./conformer_ctc/exp \
--lang-dir ./data/lang_bpe_500 \
--epoch 20 \
--avg 10 \
--max-duration 300 \
--dataset train-clean-100 \
--out-dir data/token-ali
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import numpy as np
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse import CutSet
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -75,10 +90,37 @@ def get_parser():
)
parser.add_argument(
"--ali-dir",
"--out-dir",
type=str,
default="data/ali_500",
help="The experiment dir",
required=True,
help="""Output directory.
It contains the following generated files:
- xxx.h5
- cuts_xxx.json.gz
where xxx is the value of `--dataset`. For instance, if
`--dataset` is `train-clean-100`, it will contain two files:
- `train-clean-100.h5`
- `cuts_train-clean-100.json.gz`
""",
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="""The name of the dataset to compute alignments for.
Possible values are:
- test-clean.
- test-other
- train-clean-100
- train-clean-360
- train-other-500
- dev-clean
- dev-other
""",
)
return parser
@ -91,7 +133,9 @@ def get_params() -> AttributeDict:
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
# Set it to 0 since attention decoder
# is not used for computing alignments
"num_decoder_layers": 0,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
@ -105,9 +149,10 @@ def get_params() -> AttributeDict:
def compute_alignments(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
writer: FeaturesWriter,
params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler,
) -> List[Tuple[str, List[int]]]:
) -> CutSet:
"""Compute the framewise alignments of a dataset.
Args:
@ -120,9 +165,8 @@ def compute_alignments(
graph_compiler:
It converts token IDs to decoding graphs.
Returns:
Return a list of tuples. Each tuple contains two entries:
- Utterance ID
- Framewise alignments (token IDs) after subsampling
Return a CutSet. Each cut has a custom field `token_alignment`
of type `lhotse.array.TemporalArray`.
"""
try:
num_batches = len(dl)
@ -131,7 +175,7 @@ def compute_alignments(
num_cuts = 0
device = graph_compiler.device
ans = []
cuts = []
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
@ -140,11 +184,10 @@ def compute_alignments(
feature = feature.to(device)
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
cut_ids = []
for cut in supervisions["cut"]:
assert len(cut.supervisions) == 1
cut_ids.append(cut.id)
for cut in cut_list:
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
@ -156,10 +199,12 @@ def compute_alignments(
# In general, new2old is an identity map since lhotse sorts the returned
# cuts by duration in descending order
new2old = supervision_segments[:, 0].tolist()
cut_ids = [cut_ids[i] for i in new2old]
cut_list = [cut_list[i] for i in new2old]
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
decoding_graph.tokens = decoding_graph.aux_labels.clone()
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
@ -179,8 +224,18 @@ def compute_alignments(
)
ali_ids = get_alignments(best_path)
assert len(ali_ids) == len(cut_ids)
ans += list(zip(cut_ids, ali_ids))
assert len(ali_ids) == len(cut_list)
for cut, ali in zip(cut_list, ali_ids):
cut.token_alignment = writer.store_array(
key=cut.id,
value=np.asarray(ali, dtype=np.int32),
frame_shift=0.04, # frame shift is 0.01s, subsampling_factor is 4
temporal_dim=0,
start=0,
)
cuts += cut_list
num_cuts += len(ali_ids)
@ -191,7 +246,7 @@ def compute_alignments(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return ans
return CutSet.from_cuts(cuts)
@torch.no_grad()
@ -200,20 +255,32 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.return_cuts is True
assert args.concatenate_cuts is False
if args.full_libri is False:
print("Changing --full-libri to True")
args.full_libri = True
args.enable_spec_aug = False
args.enable_musan = False
args.return_cuts = True
args.concatenate_cuts = False
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/ali")
setup_logger(f"{params.exp_dir}/log-ali")
logging.info("Computing alignment - started")
logging.info(f"Computing alignments for {params.dataset} - started")
logging.info(params)
out_dir = Path(params.out_dir)
out_dir.mkdir(exist_ok=True)
out_ali_filename = out_dir / f"{params.dataset}.h5"
if out_ali_filename.exists():
logging.info(f"{out_ali_filename} exists - skipping")
return
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
if out_manifest_filename.exists():
logging.info(f"{out_manifest_filename} exists - skipping")
return
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
@ -221,6 +288,7 @@ def main():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
@ -240,9 +308,12 @@ def main():
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
load_checkpoint(
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
)
else:
start = params.epoch - params.avg + 1
filenames = []
@ -250,60 +321,52 @@ def main():
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
if params.dataset == "test-clean":
test_clean_cuts = librispeech.test_clean_cuts()
dl = librispeech.test_dataloaders(test_clean_cuts)
elif params.dataset == "test-other":
test_other_cuts = librispeech.test_other_cuts()
dl = librispeech.test_dataloaders(test_other_cuts)
elif params.dataset == "train-clean-100":
train_clean_100_cuts = librispeech.train_clean_100_cuts()
dl = librispeech.train_dataloaders(train_clean_100_cuts)
elif params.dataset == "train-clean-360":
train_clean_360_cuts = librispeech.train_clean_360_cuts()
dl = librispeech.train_dataloaders(train_clean_360_cuts)
elif params.dataset == "train-other-500":
train_other_500_cuts = librispeech.train_other_500_cuts()
dl = librispeech.train_dataloaders(train_other_500_cuts)
elif params.dataset == "dev-clean":
dev_clean_cuts = librispeech.dev_clean_cuts()
dl = librispeech.valid_dataloaders(dev_clean_cuts)
else:
assert params.dataset == "dev-other", f"{params.dataset}"
dev_other_cuts = librispeech.dev_other_cuts()
dl = librispeech.valid_dataloaders(dev_other_cuts)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
test_dl = librispeech.test_dataloaders() # a list
ali_dir = Path(params.ali_dir)
ali_dir.mkdir(exist_ok=True)
enabled_datasets = {
"test_clean": test_dl[0],
"test_other": test_dl[1],
"train-960": train_dl,
"valid": valid_dl,
}
# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
# compute the alignments if you use --max-duration=500
#
# There are 960 * 3 = 2880 hours data and it takes only
# 3 hours 40 minutes to get the alignment.
# The RTF is roughly: 3.67 / 2880 = 0.0012743
#
# At the end, you would see
# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
for name, dl in enabled_datasets.items():
logging.info(f"Processing {name}")
if name == "train-960":
logging.info(
f"It will take about 3 hours 40 minutes for {name}, "
"which contains 960 * 3 = 2880 hours of data"
)
alignments = compute_alignments(
logging.info(f"Processing {params.dataset}")
with NumpyHdf5Writer(out_ali_filename) as writer:
cut_set = compute_alignments(
model=model,
dl=dl,
writer=writer,
params=params,
graph_compiler=graph_compiler,
)
num_utt = len(alignments)
alignments = dict(alignments)
assert num_utt == len(alignments)
filename = ali_dir / f"{name}.pt"
save_alignments(
alignments=alignments,
subsampling_factor=params.subsampling_factor,
filename=filename,
)
cut_set.to_json(out_manifest_filename)
logging.info(
f"For dataset {name}, its alignments are saved to {filename}"
f"For dataset {params.dataset}, its alignments are "
f"saved to {out_ali_filename} and the cut manifest file "
f"is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
)

View File

@ -684,14 +684,17 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -614,8 +614,16 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,

View File

@ -34,11 +34,10 @@ from lhotse.dataset import (
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule):
class LibriSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
@ -56,9 +55,11 @@ class LibriSpeechAsrDataModule(DataModule):
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
@ -74,7 +75,7 @@ class LibriSpeechAsrDataModule(DataModule):
"Otherwise, use 100h subset.",
)
group.add_argument(
"--feature-dir",
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
@ -154,17 +155,48 @@ class LibriSpeechAsrDataModule(DataModule):
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.json.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
logging.info("About to create train dataset")
transforms = [
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
@ -179,15 +211,25 @@ class LibriSpeechAsrDataModule(DataModule):
)
] + transforms
input_transforms = [
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
@ -243,10 +285,7 @@ class LibriSpeechAsrDataModule(DataModule):
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@ -285,25 +324,16 @@ class LibriSpeechAsrDataModule(DataModule):
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
if not is_list:
cuts = [cuts]
for cuts_test in cuts:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
)
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration, shuffle=False
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
@ -312,48 +342,45 @@ class LibriSpeechAsrDataModule(DataModule):
sampler=sampler,
num_workers=self.args.num_workers,
)
test_loaders.append(test_dl)
if is_list:
return test_loaders
else:
return test_loaders[0]
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-clean-360.json.gz"
)
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-other-500.json.gz"
)
)
return cuts
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")

View File

@ -474,14 +474,17 @@ def main():
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -532,8 +532,16 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)

View File

@ -85,6 +85,7 @@ def load_checkpoint(
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
strict: bool = False,
) -> Dict[str, Any]:
"""
TODO: document it
@ -101,9 +102,9 @@ def load_checkpoint(
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=False)
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=False)
model.load_state_dict(checkpoint["model"], strict=strict)
checkpoint.pop("model")

View File

@ -306,7 +306,10 @@ def get_texts(
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
"""Extract the token IDs (from best_paths.labels) from the best-path FSAs.
"""Extract the token IDs (from best_paths.tokens) from the best-path FSAs.
Caution:
There are no repeats in `best_paths.tokens`.
Args:
best_paths:
@ -320,11 +323,11 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
after subsampling of the i-th utterance in the batch.
"""
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
label_shape = best_paths.arcs.shape().remove_axis(1)
# label_shape has axes [fsa][arc]
labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
labels = labels.remove_values_eq(-1)
return labels.tolist()
token_shape = best_paths.arcs.shape().remove_axis(1)
# token_shape has axes [fsa][arc]
tokens = k2.RaggedTensor(token_shape, best_paths.tokens)
tokens = tokens.remove_values_eq(-1)
return tokens.tolist()
def save_alignments(

View File

@ -44,180 +44,47 @@ from icefall.utils import get_texts
ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
cuts_json = egs_dir / "data/token_ali/cuts_test-clean.json.gz"
def data_exists():
return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
return cuts_json.exists() and lang_dir.exists()
def get_dataloader():
cuts_train = load_manifest(cut_json)
cuts_train = cuts_train.with_features_path_prefix(egs_dir)
train_sampler = SingleCutSampler(
cuts_train,
cuts = load_manifest(cuts_json)
cuts = cuts.with_features_path_prefix(egs_dir)
sampler = SingleCutSampler(
cuts,
max_duration=40,
shuffle=False,
)
train = K2SpeechRecognitionDataset(return_cuts=True)
dataset = K2SpeechRecognitionDataset(return_cuts=True)
train_dl = DataLoader(
train,
sampler=train_sampler,
dl = DataLoader(
dataset,
sampler=sampler,
batch_size=None,
num_workers=1,
persistent_workers=False,
)
return train_dl
def test_one_hot():
a = [1, 3, 2]
b = [1, 0, 4, 2]
c = [torch.tensor(a), torch.tensor(b)]
d = pad_sequence(c, batch_first=True, padding_value=0)
f = torch.nn.functional.one_hot(d, num_classes=5)
e = (1 - f) * -10.0
expected = torch.tensor(
[
[
[-10, 0, -10, -10, -10],
[-10, -10, -10, 0, -10],
[-10, -10, 0, -10, -10],
[0, -10, -10, -10, -10],
],
[
[-10, 0, -10, -10, -10],
[0, -10, -10, -10, -10],
[-10, -10, -10, -10, 0],
[-10, -10, 0, -10, -10],
],
]
).to(e.dtype)
assert torch.all(torch.eq(e, expected))
return dl
def test():
"""
The purpose of this test is to show that we can use pre-computed
alignments to construct a mask, adding it to a randomly generated
nnet_output, to decode the correct transcript from the resulting
nnet_output.
"""
if not data_exists():
return
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
dl = get_dataloader()
subsampling_factor, ali = load_alignments(ali_filename)
ali = convert_alignments_to_tensor(ali, device=device)
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
word_table = lexicon.word_table
HLG = k2.Fsa.from_dict(
torch.load(f"{lang_dir}/HLG.pt", map_location=device)
)
for batch in dl:
features = batch["inputs"]
supervisions = batch["supervisions"]
N = features.shape[0]
T = features.shape[1] // subsampling_factor
nnet_output = (
torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
.softmax(dim=-1)
.log()
)
cut_ids = [cut.id for cut in supervisions["cut"]]
mask = lookup_alignments(
cut_ids=cut_ids, alignments=ali, num_classes=num_classes
)
min_len = min(nnet_output.shape[1], mask.shape[1])
ali_model_scale = 0.8
nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // subsampling_factor,
supervisions["num_frames"] // subsampling_factor,
),
1,
).to(torch.int32)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=20,
output_beam=8,
min_active_states=30,
max_active_states=10000,
subsampling_factor=subsampling_factor,
)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
hyps = [" ".join(s) for s in hyps]
print(hyps)
print(supervisions["text"])
cuts = supervisions["cut"]
print(cuts)
break
def show_cut_ids():
# The purpose of this function is to check that
# for each utterance in the training set, there is
# a corresponding alignment.
#
# After generating a1.txt and b1.txt
# You can use
# wc -l a1.txt b1.txt
# which should show the same number of lines.
#
# cat a1.txt | sort | uniq > a11.txt
# cat b1.txt | sort | uniq > b11.txt
#
# md5sum a11.txt b11.txt
# which should show the identical hash
#
# diff a11.txt b11.txt
# should print nothing
subsampling_factor, ali = load_alignments(ali_filename)
with open("a1.txt", "w") as f:
for key in ali:
f.write(f"{key}\n")
# dl = get_dataloader()
cuts_train = (
load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
)
ans = []
for cut in cuts_train:
ans.append(cut.id)
with open("b1.txt", "w") as f:
for line in ans:
f.write(f"{line}\n")
if __name__ == "__main__":
test()