Associate a cut with token alignment (without repeats) (#125)

* WIP: Associate a cut with token alignment (without repeats)

* Save framewise alignments with/without repeats.

* Minor fixes.
This commit is contained in:
Fangjun Kuang 2021-11-29 18:50:54 +08:00 committed by GitHub
parent 243fb9723c
commit ec591698b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 374 additions and 357 deletions

View File

@ -15,15 +15,29 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Usage:
./conformer_ctc/ali.py \
--exp-dir ./conformer_ctc/exp \
--lang-dir ./data/lang_bpe_500 \
--epoch 20 \
--avg 10 \
--max-duration 300 \
--dataset train-clean-100 \
--out-dir data/ali
"""
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Tuple
import k2 import k2
import numpy as np
import torch import torch
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from conformer import Conformer
from lhotse import CutSet
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -34,7 +48,6 @@ from icefall.utils import (
AttributeDict, AttributeDict,
encode_supervisions, encode_supervisions,
get_alignments, get_alignments,
save_alignments,
setup_logger, setup_logger,
) )
@ -75,10 +88,42 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--ali-dir", "--out-dir",
type=str, type=str,
default="data/ali_500", required=True,
help="The experiment dir", help="""Output directory.
It contains 3 generated files:
- labels_xxx.h5
- aux_labels_xxx.h5
- cuts_xxx.json.gz
where xxx is the value of `--dataset`. For instance, if
`--dataset` is `train-clean-100`, it will contain 3 files:
- `labels_train-clean-100.h5`
- `aux_labels_train-clean-100.h5`
- `cuts_train-clean-100.json.gz`
Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
alignment. The difference is that labels_xxx.h5 contains repeats.
""",
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="""The name of the dataset to compute alignments for.
Possible values are:
- test-clean.
- test-other
- train-clean-100
- train-clean-360
- train-other-500
- dev-clean
- dev-other
""",
) )
return parser return parser
@ -91,7 +136,9 @@ def get_params() -> AttributeDict:
"nhead": 8, "nhead": 8,
"attention_dim": 512, "attention_dim": 512,
"subsampling_factor": 4, "subsampling_factor": 4,
"num_decoder_layers": 6, # Set it to 0 since attention decoder
# is not used for computing alignments
"num_decoder_layers": 0,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
"output_beam": 10, "output_beam": 10,
@ -105,9 +152,11 @@ def get_params() -> AttributeDict:
def compute_alignments( def compute_alignments(
model: torch.nn.Module, model: torch.nn.Module,
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
labels_writer: FeaturesWriter,
aux_labels_writer: FeaturesWriter,
params: AttributeDict, params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler, graph_compiler: BpeCtcTrainingGraphCompiler,
) -> List[Tuple[str, List[int]]]: ) -> CutSet:
"""Compute the framewise alignments of a dataset. """Compute the framewise alignments of a dataset.
Args: Args:
@ -120,9 +169,10 @@ def compute_alignments(
graph_compiler: graph_compiler:
It converts token IDs to decoding graphs. It converts token IDs to decoding graphs.
Returns: Returns:
Return a list of tuples. Each tuple contains two entries: Return a CutSet. Each cut has two custom fields: labels_alignment
- Utterance ID and aux_labels_alignment, containing framewise alignments information.
- Framewise alignments (token IDs) after subsampling Both are of type `lhotse.array.TemporalArray`. The difference between
the two alignments is that `labels_alignment` contain repeats.
""" """
try: try:
num_batches = len(dl) num_batches = len(dl)
@ -131,7 +181,7 @@ def compute_alignments(
num_cuts = 0 num_cuts = 0
device = graph_compiler.device device = graph_compiler.device
ans = [] cuts = []
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
feature = batch["inputs"] feature = batch["inputs"]
@ -140,11 +190,10 @@ def compute_alignments(
feature = feature.to(device) feature = feature.to(device)
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
cut_ids = [] for cut in cut_list:
for cut in supervisions["cut"]: assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
assert len(cut.supervisions) == 1
cut_ids.append(cut.id)
nnet_output, encoder_memory, memory_mask = model(feature, supervisions) nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C] # nnet_output is [N, T, C]
@ -156,7 +205,8 @@ def compute_alignments(
# In general, new2old is an identity map since lhotse sorts the returned # In general, new2old is an identity map since lhotse sorts the returned
# cuts by duration in descending order # cuts by duration in descending order
new2old = supervision_segments[:, 0].tolist() new2old = supervision_segments[:, 0].tolist()
cut_ids = [cut_ids[i] for i in new2old]
cut_list = [cut_list[i] for i in new2old]
token_ids = graph_compiler.texts_to_ids(texts) token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids) decoding_graph = graph_compiler.compile(token_ids)
@ -178,11 +228,32 @@ def compute_alignments(
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
) )
ali_ids = get_alignments(best_path) labels_ali = get_alignments(best_path, kind="labels")
assert len(ali_ids) == len(cut_ids) aux_labels_ali = get_alignments(best_path, kind="aux_labels")
ans += list(zip(cut_ids, ali_ids)) assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
for cut, labels, aux_labels in zip(
cut_list, labels_ali, aux_labels_ali
):
cut.labels_alignment = labels_writer.store_array(
key=cut.id,
value=np.asarray(labels, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
cut.aux_labels_alignment = aux_labels_writer.store_array(
key=cut.id,
value=np.asarray(aux_labels, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
num_cuts += len(ali_ids) cuts += cut_list
num_cuts += len(cut_list)
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
@ -191,7 +262,7 @@ def compute_alignments(
f"batch {batch_str}, cuts processed until now is {num_cuts}" f"batch {batch_str}, cuts processed until now is {num_cuts}"
) )
return ans return CutSet.from_cuts(cuts)
@torch.no_grad() @torch.no_grad()
@ -200,20 +271,35 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser) LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
assert args.return_cuts is True args.enable_spec_aug = False
assert args.concatenate_cuts is False args.enable_musan = False
if args.full_libri is False: args.return_cuts = True
print("Changing --full-libri to True") args.concatenate_cuts = False
args.full_libri = True
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/ali") setup_logger(f"{params.exp_dir}/log-ali")
logging.info("Computing alignment - started") logging.info(f"Computing alignments for {params.dataset} - started")
logging.info(params) logging.info(params)
out_dir = Path(params.out_dir)
out_dir.mkdir(exist_ok=True)
out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
for f in (
out_labels_ali_filename,
out_aux_labels_ali_filename,
out_manifest_filename,
):
if f.exists():
logging.info(f"{f} exists - skipping")
return
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens) max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank num_classes = max_token_id + 1 # +1 for the blank
@ -221,6 +307,7 @@ def main():
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler( graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir, params.lang_dir,
@ -240,9 +327,12 @@ def main():
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm, use_feat_batchnorm=params.use_feat_batchnorm,
) )
model.to(device)
if params.avg == 1: if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
)
else: else:
start = params.epoch - params.avg + 1 start = params.epoch - params.avg + 1
filenames = [] filenames = []
@ -250,60 +340,55 @@ def main():
if start >= 0: if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt") filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames)) model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device)
model.eval() model.eval()
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.dataset == "test-clean":
test_clean_cuts = librispeech.test_clean_cuts()
dl = librispeech.test_dataloaders(test_clean_cuts)
elif params.dataset == "test-other":
test_other_cuts = librispeech.test_other_cuts()
dl = librispeech.test_dataloaders(test_other_cuts)
elif params.dataset == "train-clean-100":
train_clean_100_cuts = librispeech.train_clean_100_cuts()
dl = librispeech.train_dataloaders(train_clean_100_cuts)
elif params.dataset == "train-clean-360":
train_clean_360_cuts = librispeech.train_clean_360_cuts()
dl = librispeech.train_dataloaders(train_clean_360_cuts)
elif params.dataset == "train-other-500":
train_other_500_cuts = librispeech.train_other_500_cuts()
dl = librispeech.train_dataloaders(train_other_500_cuts)
elif params.dataset == "dev-clean":
dev_clean_cuts = librispeech.dev_clean_cuts()
dl = librispeech.valid_dataloaders(dev_clean_cuts)
else:
assert params.dataset == "dev-other", f"{params.dataset}"
dev_other_cuts = librispeech.dev_other_cuts()
dl = librispeech.valid_dataloaders(dev_other_cuts)
train_dl = librispeech.train_dataloaders() logging.info(f"Processing {params.dataset}")
valid_dl = librispeech.valid_dataloaders() with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
test_dl = librispeech.test_dataloaders() # a list with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
cut_set = compute_alignments(
ali_dir = Path(params.ali_dir)
ali_dir.mkdir(exist_ok=True)
enabled_datasets = {
"test_clean": test_dl[0],
"test_other": test_dl[1],
"train-960": train_dl,
"valid": valid_dl,
}
# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
# compute the alignments if you use --max-duration=500
#
# There are 960 * 3 = 2880 hours data and it takes only
# 3 hours 40 minutes to get the alignment.
# The RTF is roughly: 3.67 / 2880 = 0.0012743
#
# At the end, you would see
# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
for name, dl in enabled_datasets.items():
logging.info(f"Processing {name}")
if name == "train-960":
logging.info(
f"It will take about 3 hours 40 minutes for {name}, "
"which contains 960 * 3 = 2880 hours of data"
)
alignments = compute_alignments(
model=model, model=model,
dl=dl, dl=dl,
labels_writer=labels_writer,
aux_labels_writer=aux_labels_writer,
params=params, params=params,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
) )
num_utt = len(alignments)
alignments = dict(alignments) cut_set.to_file(out_manifest_filename)
assert num_utt == len(alignments)
filename = ali_dir / f"{name}.pt"
save_alignments(
alignments=alignments,
subsampling_factor=params.subsampling_factor,
filename=filename,
)
logging.info( logging.info(
f"For dataset {name}, its alignments are saved to {filename}" f"For dataset {params.dataset}, its alignments with repeats are "
f"saved to {out_labels_ali_filename}, the alignments without repeats "
f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
) )

View File

@ -665,14 +665,17 @@ def main():
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip test_clean_cuts = librispeech.test_clean_cuts()
# it inside the for loop. That is, use test_other_cuts = librispeech.test_other_cuts()
#
# if test_set == 'test-clean': continue test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
# test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"] test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=test_dl,
params=params, params=params,

View File

@ -618,8 +618,16 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"]) optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom( scan_pessimistic_batches_for_oom(
model=model, model=model,

View File

@ -19,7 +19,6 @@ import argparse
import logging import logging
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import List, Union
from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( from lhotse.dataset import (
@ -34,11 +33,10 @@ from lhotse.dataset import (
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule): class LibriSpeechAsrDataModule:
""" """
DataModule for k2 ASR experiments. DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader, It assumes there is always one train and valid dataloader,
@ -56,9 +54,11 @@ class LibriSpeechAsrDataModule(DataModule):
This class should be derived for specific corpora used in ASR tasks. This class should be derived for specific corpora used in ASR tasks.
""" """
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod @classmethod
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description="These options are used for the preparation of " description="These options are used for the preparation of "
@ -74,7 +74,7 @@ class LibriSpeechAsrDataModule(DataModule):
"Otherwise, use 100h subset.", "Otherwise, use 100h subset.",
) )
group.add_argument( group.add_argument(
"--feature-dir", "--manifest-dir",
type=Path, type=Path,
default=Path("data/fbank"), default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.", help="Path to directory with train/valid/test cuts.",
@ -154,17 +154,48 @@ class LibriSpeechAsrDataModule(DataModule):
"collect the batches.", "collect the batches.",
) )
def train_dataloaders(self) -> DataLoader: group.add_argument(
logging.info("About to get train cuts") "--enable-spec-aug",
cuts_train = self.train_cuts() type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.json.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
logging.info("About to create train dataset")
transforms = [
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "
@ -179,15 +210,25 @@ class LibriSpeechAsrDataModule(DataModule):
) )
] + transforms ] + transforms
input_transforms = [ input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
input_transforms.append(
SpecAugment( SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2, num_frame_masks=2,
features_mask_size=27, features_mask_size=27,
num_feature_masks=2, num_feature_masks=2,
frames_mask_size=100, frames_mask_size=100,
) )
] )
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_transforms=input_transforms, input_transforms=input_transforms,
@ -243,10 +284,7 @@ class LibriSpeechAsrDataModule(DataModule):
return train_dl return train_dl
def valid_dataloaders(self) -> DataLoader: def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
transforms = [] transforms = []
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
transforms = [ transforms = [
@ -285,25 +323,16 @@ class LibriSpeechAsrDataModule(DataModule):
return valid_dl return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
if not is_list:
cuts = [cuts]
for cuts_test in cuts:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
Fbank(FbankConfig(num_mel_bins=80))
)
if self.args.on_the_fly_feats if self.args.on_the_fly_feats
else PrecomputedFeatures(), else PrecomputedFeatures(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = BucketingSampler( sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration, shuffle=False cuts, max_duration=self.args.max_duration, shuffle=False
) )
logging.debug("About to create test dataloader") logging.debug("About to create test dataloader")
test_dl = DataLoader( test_dl = DataLoader(
@ -312,48 +341,45 @@ class LibriSpeechAsrDataModule(DataModule):
sampler=sampler, sampler=sampler,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
) )
test_loaders.append(test_dl) return test_dl
if is_list:
return test_loaders
else:
return test_loaders[0]
@lru_cache() @lru_cache()
def train_cuts(self) -> CutSet: def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train cuts") logging.info("About to get train-clean-100 cuts")
cuts_train = load_manifest( return load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz" self.args.manifest_dir / "cuts_train-clean-100.json.gz"
) )
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get dev cuts") logging.info("About to get train-clean-360 cuts")
cuts_valid = load_manifest( return load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz" self.args.manifest_dir / "cuts_train-clean-360.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") )
return cuts_valid
@lru_cache() @lru_cache()
def test_cuts(self) -> List[CutSet]: def train_other_500_cuts(self) -> CutSet:
test_sets = ["test-clean", "test-other"] logging.info("About to get train-other-500 cuts")
cuts = [] return load_manifest(
for test_set in test_sets: self.args.manifest_dir / "cuts_train-other-500.json.gz"
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
) )
)
return cuts @lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")

View File

@ -474,14 +474,17 @@ def main():
model.eval() model.eval()
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip test_clean_cuts = librispeech.test_clean_cuts()
# it inside the for loop. That is, use test_other_cuts = librispeech.test_other_cuts()
#
# if test_set == 'test-clean': continue test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
# test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"] test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=test_dl,
params=params, params=params,

View File

@ -532,8 +532,16 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)

View File

@ -85,6 +85,7 @@ def load_checkpoint(
optimizer: Optional[Optimizer] = None, optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None, scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None, scaler: Optional[GradScaler] = None,
strict: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
TODO: document it TODO: document it
@ -101,9 +102,9 @@ def load_checkpoint(
src_key = "{}.{}".format("module", key) src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key) dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0 assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=False) model.load_state_dict(dst_state_dict, strict=strict)
else: else:
model.load_state_dict(checkpoint["model"], strict=False) model.load_state_dict(checkpoint["model"], strict=strict)
checkpoint.pop("model") checkpoint.pop("model")

View File

@ -224,8 +224,8 @@ def get_texts(
return aux_labels.tolist() return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]: def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
"""Extract the token IDs (from best_paths.labels) from the best-path FSAs. """Extract labels or aux_labels from the best-path FSAs.
Args: Args:
best_paths: best_paths:
@ -233,17 +233,34 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
containing multiple FSAs, which is expected to be the result containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't of k2.shortest_path (otherwise the returned values won't
be meaningful). be meaningful).
kind:
Possible values are: "labels" and "aux_labels". Caution: When it is
"labels", the resulting alignments contain repeats.
Returns: Returns:
Returns a list of lists of int, containing the token sequences we Returns a list of lists of int, containing the token sequences we
decoded. For `ans[i]`, its length equals to the number of frames decoded. For `ans[i]`, its length equals to the number of frames
after subsampling of the i-th utterance in the batch. after subsampling of the i-th utterance in the batch.
Example:
When `kind` is `labels`, one possible alignment example is (with
repeats)::
c c c blk a a blk blk t t t blk blk
If `kind` is `aux_labels`, the above example changes to::
c blk blk blk a blk blk blk t blk blk blk blk
""" """
assert kind in ("labels", "aux_labels")
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
label_shape = best_paths.arcs.shape().remove_axis(1) token_shape = best_paths.arcs.shape().remove_axis(1)
# label_shape has axes [fsa][arc] # token_shape has axes [fsa][arc]
labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous()) tokens = k2.RaggedTensor(
labels = labels.remove_values_eq(-1) token_shape, getattr(best_paths, kind).contiguous()
return labels.tolist() )
tokens = tokens.remove_values_eq(-1)
return tokens.tolist()
def save_alignments( def save_alignments(

View File

@ -25,199 +25,65 @@
from pathlib import Path from pathlib import Path
import k2 from lhotse import CutSet, load_manifest
import torch
from lhotse import load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from torch.nn.utils.rnn import pad_sequence from lhotse.dataset.collation import collate_custom_field
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import get_texts
ICEFALL_DIR = Path(__file__).resolve().parent.parent ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR" egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500" lang_dir = egs_dir / "data/lang_bpe_500"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz" cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
def data_exists(): def data_exists():
return ali_filename.exists() and cut_json.exists() and lang_dir.exists() return cuts_json.exists() and lang_dir.exists()
def get_dataloader(): def get_dataloader():
cuts_train = load_manifest(cut_json) cuts = load_manifest(cuts_json)
cuts_train = cuts_train.with_features_path_prefix(egs_dir) print(cuts[0])
train_sampler = SingleCutSampler( cuts = cuts.with_features_path_prefix(egs_dir)
cuts_train, sampler = SingleCutSampler(
max_duration=40, cuts,
max_duration=10,
shuffle=False, shuffle=False,
) )
train = K2SpeechRecognitionDataset(return_cuts=True) dataset = K2SpeechRecognitionDataset(return_cuts=True)
train_dl = DataLoader( dl = DataLoader(
train, dataset,
sampler=train_sampler, sampler=sampler,
batch_size=None, batch_size=None,
num_workers=1, num_workers=1,
persistent_workers=False, persistent_workers=False,
) )
return train_dl return dl
def test_one_hot():
a = [1, 3, 2]
b = [1, 0, 4, 2]
c = [torch.tensor(a), torch.tensor(b)]
d = pad_sequence(c, batch_first=True, padding_value=0)
f = torch.nn.functional.one_hot(d, num_classes=5)
e = (1 - f) * -10.0
expected = torch.tensor(
[
[
[-10, 0, -10, -10, -10],
[-10, -10, -10, 0, -10],
[-10, -10, 0, -10, -10],
[0, -10, -10, -10, -10],
],
[
[-10, 0, -10, -10, -10],
[0, -10, -10, -10, -10],
[-10, -10, -10, -10, 0],
[-10, -10, 0, -10, -10],
],
]
).to(e.dtype)
assert torch.all(torch.eq(e, expected))
def test(): def test():
"""
The purpose of this test is to show that we can use pre-computed
alignments to construct a mask, adding it to a randomly generated
nnet_output, to decode the correct transcript from the resulting
nnet_output.
"""
if not data_exists(): if not data_exists():
return return
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
dl = get_dataloader() dl = get_dataloader()
subsampling_factor, ali = load_alignments(ali_filename)
ali = convert_alignments_to_tensor(ali, device=device)
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
word_table = lexicon.word_table
HLG = k2.Fsa.from_dict(
torch.load(f"{lang_dir}/HLG.pt", map_location=device)
)
for batch in dl: for batch in dl:
features = batch["inputs"]
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
N = features.shape[0] cuts = supervisions["cut"]
T = features.shape[1] // subsampling_factor labels_alignment, labels_alignment_length = collate_custom_field(
nnet_output = ( CutSet.from_cuts(cuts), "labels_alignment"
torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
.softmax(dim=-1)
.log()
) )
cut_ids = [cut.id for cut in supervisions["cut"]]
mask = lookup_alignments(
cut_ids=cut_ids, alignments=ali, num_classes=num_classes
)
min_len = min(nnet_output.shape[1], mask.shape[1])
ali_model_scale = 0.8
nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
( (
supervisions["sequence_idx"], aux_labels_alignment,
supervisions["start_frame"] // subsampling_factor, aux_labels_alignment_length,
supervisions["num_frames"] // subsampling_factor, ) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
),
1,
).to(torch.int32)
lattice = get_lattice( print(labels_alignment)
nnet_output=nnet_output, print(aux_labels_alignment)
decoding_graph=HLG, print(labels_alignment_length)
supervision_segments=supervision_segments, print(aux_labels_alignment_length)
search_beam=20,
output_beam=8,
min_active_states=30,
max_active_states=10000,
subsampling_factor=subsampling_factor,
)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
hyps = [" ".join(s) for s in hyps]
print(hyps)
print(supervisions["text"])
break break
def show_cut_ids():
# The purpose of this function is to check that
# for each utterance in the training set, there is
# a corresponding alignment.
#
# After generating a1.txt and b1.txt
# You can use
# wc -l a1.txt b1.txt
# which should show the same number of lines.
#
# cat a1.txt | sort | uniq > a11.txt
# cat b1.txt | sort | uniq > b11.txt
#
# md5sum a11.txt b11.txt
# which should show the identical hash
#
# diff a11.txt b11.txt
# should print nothing
subsampling_factor, ali = load_alignments(ali_filename)
with open("a1.txt", "w") as f:
for key in ali:
f.write(f"{key}\n")
# dl = get_dataloader()
cuts_train = (
load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
)
ans = []
for cut in cuts_train:
ans.append(cut.id)
with open("b1.txt", "w") as f:
for line in ans:
f.write(f"{line}\n")
if __name__ == "__main__": if __name__ == "__main__":
test() test()