WIP: Associate a cut with token alignment (without repeats)

Fangjun Kuang 2021-11-18 17:35:08 +08:00
parent 30c43b7f69
commit 62ada37d4e
9 changed files with 315 additions and 332 deletions

View File

@@ -15,15 +15,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./conformer_ctc/ali.py \
--exp-dir ./conformer_ctc/exp \
--lang-dir ./data/lang_bpe_500 \
--epoch 20 \
--avg 10 \
--max-duration 300 \
--dataset train-clean-100 \
--out-dir data/token-ali
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import numpy as np
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse import CutSet
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -75,10 +90,37 @@ def get_parser():
)
parser.add_argument(
-"--ali-dir",
"--out-dir",
type=str,
-default="data/ali_500",
-help="The experiment dir",
required=True,
help="""Output directory.
It contains the following generated files:
- xxx.h5
- cuts_xxx.json.gz
where xxx is the value of `--dataset`. For instance, if
`--dataset` is `train-clean-100`, it will contain two files:
- `train-clean-100.h5`
- `cuts_train-clean-100.json.gz`
""",
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="""The name of the dataset to compute alignments for.
Possible values are:
- test-clean
- test-other
- train-clean-100
- train-clean-360
- train-other-500
- dev-clean
- dev-other
""",
)
return parser
@@ -91,7 +133,9 @@ def get_params() -> AttributeDict:
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
-"num_decoder_layers": 6,
# Set it to 0 since attention decoder
# is not used for computing alignments
"num_decoder_layers": 0,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
"output_beam": 10, "output_beam": 10,
@@ -105,9 +149,10 @@ def get_params() -> AttributeDict:
def compute_alignments(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
writer: FeaturesWriter,
params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler,
-) -> List[Tuple[str, List[int]]]:
) -> CutSet:
"""Compute the framewise alignments of a dataset.
Args:
@@ -120,9 +165,8 @@ def compute_alignments(
graph_compiler:
It converts token IDs to decoding graphs.
Returns:
-Return a list of tuples. Each tuple contains two entries:
-  - Utterance ID
-  - Framewise alignments (token IDs) after subsampling
Return a CutSet. Each cut has a custom field `token_alignment`
of type `lhotse.array.TemporalArray`.
"""
try:
num_batches = len(dl)
@@ -131,7 +175,7 @@ def compute_alignments(
num_cuts = 0
device = graph_compiler.device
-ans = []
cuts = []
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
@@ -140,11 +184,10 @@
feature = feature.to(device)
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
-cut_ids = []
-for cut in supervisions["cut"]:
-assert len(cut.supervisions) == 1
-cut_ids.append(cut.id)
for cut in cut_list:
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
@@ -156,10 +199,12 @@
# In general, new2old is an identity map since lhotse sorts the returned
# cuts by duration in descending order
new2old = supervision_segments[:, 0].tolist()
-cut_ids = [cut_ids[i] for i in new2old]
cut_list = [cut_list[i] for i in new2old]
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
decoding_graph.tokens = decoding_graph.aux_labels.clone()
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
@@ -179,8 +224,18 @@
)
ali_ids = get_alignments(best_path)
-assert len(ali_ids) == len(cut_ids)
-ans += list(zip(cut_ids, ali_ids))
assert len(ali_ids) == len(cut_list)
for cut, ali in zip(cut_list, ali_ids):
cut.token_alignment = writer.store_array(
key=cut.id,
value=np.asarray(ali, dtype=np.int32),
frame_shift=0.04, # frame shift is 0.01s, subsampling_factor is 4
temporal_dim=0,
start=0,
)
cuts += cut_list
num_cuts += len(ali_ids)
@@ -191,7 +246,7 @@ def compute_alignments(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
-return ans
return CutSet.from_cuts(cuts)
@torch.no_grad()
@@ -200,20 +255,32 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
-assert args.return_cuts is True
-assert args.concatenate_cuts is False
-if args.full_libri is False:
-print("Changing --full-libri to True")
-args.full_libri = True
args.enable_spec_aug = False
args.enable_musan = False
args.return_cuts = True
args.concatenate_cuts = False
params = get_params()
params.update(vars(args))
-setup_logger(f"{params.exp_dir}/log/ali")
-logging.info("Computing alignment - started")
setup_logger(f"{params.exp_dir}/log-ali")
logging.info(f"Computing alignments for {params.dataset} - started")
logging.info(params)
out_dir = Path(params.out_dir)
out_dir.mkdir(exist_ok=True)
out_ali_filename = out_dir / f"{params.dataset}.h5"
if out_ali_filename.exists():
logging.info(f"{out_ali_filename} exists - skipping")
return
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
if out_manifest_filename.exists():
logging.info(f"{out_manifest_filename} exists - skipping")
return
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
@@ -221,6 +288,7 @@ def main():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
@@ -240,9 +308,12 @@ def main():
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model.to(device)
if params.avg == 1:
-load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
load_checkpoint(
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
)
else:
start = params.epoch - params.avg + 1
filenames = []
@@ -250,61 +321,53 @@ def main():
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
-model.load_state_dict(average_checkpoints(filenames))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
-model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
if params.dataset == "test-clean":
test_clean_cuts = librispeech.test_clean_cuts()
dl = librispeech.test_dataloaders(test_clean_cuts)
elif params.dataset == "test-other":
test_other_cuts = librispeech.test_other_cuts()
dl = librispeech.test_dataloaders(test_other_cuts)
elif params.dataset == "train-clean-100":
train_clean_100_cuts = librispeech.train_clean_100_cuts()
dl = librispeech.train_dataloaders(train_clean_100_cuts)
elif params.dataset == "train-clean-360":
train_clean_360_cuts = librispeech.train_clean_360_cuts()
dl = librispeech.train_dataloaders(train_clean_360_cuts)
elif params.dataset == "train-other-500":
train_other_500_cuts = librispeech.train_other_500_cuts()
dl = librispeech.train_dataloaders(train_other_500_cuts)
elif params.dataset == "dev-clean":
dev_clean_cuts = librispeech.dev_clean_cuts()
dl = librispeech.valid_dataloaders(dev_clean_cuts)
else:
assert params.dataset == "dev-other", f"{params.dataset}"
dev_other_cuts = librispeech.dev_other_cuts()
dl = librispeech.valid_dataloaders(dev_other_cuts)
-train_dl = librispeech.train_dataloaders()
-valid_dl = librispeech.valid_dataloaders()
-test_dl = librispeech.test_dataloaders()  # a list
-ali_dir = Path(params.ali_dir)
-ali_dir.mkdir(exist_ok=True)
-enabled_datasets = {
-"test_clean": test_dl[0],
-"test_other": test_dl[1],
-"train-960": train_dl,
-"valid": valid_dl,
-}
-# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
-# compute the alignments if you use --max-duration=500
-#
-# There are 960 * 3 = 2880 hours data and it takes only
-# 3 hours 40 minutes to get the alignment.
-# The RTF is roughly: 3.67 / 2880 = 0.0012743
-#
-# At the end, you would see
-# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
-# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
-for name, dl in enabled_datasets.items():
-logging.info(f"Processing {name}")
-if name == "train-960":
-logging.info(
-f"It will take about 3 hours 40 minutes for {name}, "
-"which contains 960 * 3 = 2880 hours of data"
-)
-alignments = compute_alignments(
logging.info(f"Processing {params.dataset}")
with NumpyHdf5Writer(out_ali_filename) as writer:
cut_set = compute_alignments(
model=model,
dl=dl,
writer=writer,
params=params,
graph_compiler=graph_compiler,
)
-num_utt = len(alignments)
-alignments = dict(alignments)
-assert num_utt == len(alignments)
-filename = ali_dir / f"{name}.pt"
-save_alignments(
-alignments=alignments,
-subsampling_factor=params.subsampling_factor,
-filename=filename,
-)
-logging.info(
-f"For dataset {name}, its alignments are saved to {filename}"
-)
cut_set.to_json(out_manifest_filename)
logging.info(
f"For dataset {params.dataset}, its alignments are "
f"saved to {out_ali_filename} and the cut manifest file "
f"is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
)
torch.set_num_threads(1)
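A minimal sketch of how the files written by ali.py could be consumed later, assuming lhotse's load_manifest and the TemporalArray.load() API, and using the --out-dir value from the usage example at the top of this file (the snippet itself is not part of this commit):

# Read back the alignments written by conformer_ctc/ali.py.
from lhotse import load_manifest

cuts = load_manifest("data/token-ali/cuts_train-clean-100.json.gz")
for cut in cuts:
    # `token_alignment` points into train-clean-100.h5; load() returns a
    # 1-D int32 array with one token ID per subsampled frame (0.04 s shift).
    ali = cut.token_alignment.load()
    print(cut.id, ali[:10])
    break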

View File

@@ -684,14 +684,17 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
-# CAUTION: `test_sets` is for displaying only.
-# If you want to skip test-clean, you have to skip
-# it inside the for loop. That is, use
-#
-# if test_set == 'test-clean': continue
-#
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
-for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@@ -614,8 +614,16 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
-train_dl = librispeech.train_dataloaders()
-valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,

View File

@@ -34,11 +34,10 @@ from lhotse.dataset import (
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
-from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
-class LibriSpeechAsrDataModule(DataModule):
class LibriSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
@@ -56,9 +55,11 @@ class LibriSpeechAsrDataModule(DataModule):
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
-super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
@@ -74,7 +75,7 @@ class LibriSpeechAsrDataModule(DataModule):
"Otherwise, use 100h subset.",
)
group.add_argument(
-"--feature-dir",
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
@@ -154,17 +155,48 @@ class LibriSpeechAsrDataModule(DataModule):
"collect the batches.",
)
-def train_dataloaders(self) -> DataLoader:
-logging.info("About to get train cuts")
-cuts_train = self.train_cuts()
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.json.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
logging.info("About to create train dataset")
transforms = [
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "
@@ -179,15 +211,25 @@ class LibriSpeechAsrDataModule(DataModule):
)
] + transforms
-input_transforms = [
-SpecAugment(
-num_frame_masks=2,
-features_mask_size=27,
-num_feature_masks=2,
-frames_mask_size=100,
-)
-]
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
@@ -243,10 +285,7 @@ class LibriSpeechAsrDataModule(DataModule):
return train_dl
-def valid_dataloaders(self) -> DataLoader:
-logging.info("About to get dev cuts")
-cuts_valid = self.valid_cuts()
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@@ -285,75 +324,63 @@ class LibriSpeechAsrDataModule(DataModule):
return valid_dl
-def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
-cuts = self.test_cuts()
-is_list = isinstance(cuts, list)
-test_loaders = []
-if not is_list:
-cuts = [cuts]
-for cuts_test in cuts:
-logging.debug("About to create test dataset")
-test = K2SpeechRecognitionDataset(
-input_strategy=OnTheFlyFeatures(
-Fbank(FbankConfig(num_mel_bins=80))
-)
-if self.args.on_the_fly_feats
-else PrecomputedFeatures(),
-return_cuts=self.args.return_cuts,
-)
-sampler = BucketingSampler(
-cuts_test, max_duration=self.args.max_duration, shuffle=False
-)
-logging.debug("About to create test dataloader")
-test_dl = DataLoader(
-test,
-batch_size=None,
-sampler=sampler,
-num_workers=self.args.num_workers,
-)
-test_loaders.append(test_dl)
-if is_list:
-return test_loaders
-else:
-return test_loaders[0]
-@lru_cache()
-def train_cuts(self) -> CutSet:
-logging.info("About to get train cuts")
-cuts_train = load_manifest(
-self.args.feature_dir / "cuts_train-clean-100.json.gz"
-)
-if self.args.full_libri:
-cuts_train = (
-cuts_train
-+ load_manifest(
-self.args.feature_dir / "cuts_train-clean-360.json.gz"
-)
-+ load_manifest(
-self.args.feature_dir / "cuts_train-other-500.json.gz"
-)
-)
-return cuts_train
-@lru_cache()
-def valid_cuts(self) -> CutSet:
-logging.info("About to get dev cuts")
-cuts_valid = load_manifest(
-self.args.feature_dir / "cuts_dev-clean.json.gz"
-) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
-return cuts_valid
-@lru_cache()
-def test_cuts(self) -> List[CutSet]:
-test_sets = ["test-clean", "test-other"]
-cuts = []
-for test_set in test_sets:
-logging.debug("About to get test cuts")
-cuts.append(
-load_manifest(
-self.args.feature_dir / f"cuts_{test_set}.json.gz"
-)
-)
-return cuts
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-clean-100.json.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-clean-360.json.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest(
self.args.manifest_dir / "cuts_train-other-500.json.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")
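The data module no longer decides which LibriSpeech subsets to use: callers fetch the cut sets they need and pass them to the dataloader builders, while MUSAN mixing and SpecAugment are controlled by the new flags. A rough usage sketch that mirrors how train.py and ali.py in this commit drive the module (flag values are illustrative and assume the fbank manifests already exist under data/fbank):

import argparse
from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
# Disable the training-time augmentations, as ali.py does programmatically.
args = parser.parse_args(
    ["--manifest-dir", "data/fbank", "--enable-spec-aug", "False", "--enable-musan", "False"]
)

librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)

valid_cuts = librispeech.dev_clean_cuts() + librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)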

View File

@@ -474,14 +474,17 @@ def main():
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
-# CAUTION: `test_sets` is for displaying only.
-# If you want to skip test-clean, you have to skip
-# it inside the for loop. That is, use
-#
-# if test_set == 'test-clean': continue
-#
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
-for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@@ -532,8 +532,16 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
librispeech = LibriSpeechAsrDataModule(args)
-train_dl = librispeech.train_dataloaders()
-valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)

View File

@@ -85,6 +85,7 @@ def load_checkpoint(
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
strict: bool = False,
) -> Dict[str, Any]:
"""
TODO: document it
@@ -101,9 +102,9 @@ def load_checkpoint(
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
-model.load_state_dict(dst_state_dict, strict=False)
model.load_state_dict(dst_state_dict, strict=strict)
else:
-model.load_state_dict(checkpoint["model"], strict=False)
model.load_state_dict(checkpoint["model"], strict=strict)
checkpoint.pop("model")

View File

@@ -306,7 +306,10 @@ def get_texts(
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
-"""Extract the token IDs (from best_paths.labels) from the best-path FSAs.
"""Extract the token IDs (from best_paths.tokens) from the best-path FSAs.
Caution:
There are no repeats in `best_paths.tokens`.
Args:
best_paths:
@@ -320,11 +323,11 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
after subsampling of the i-th utterance in the batch.
"""
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
-label_shape = best_paths.arcs.shape().remove_axis(1)
-# label_shape has axes [fsa][arc]
-labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
-labels = labels.remove_values_eq(-1)
-return labels.tolist()
token_shape = best_paths.arcs.shape().remove_axis(1)
# token_shape has axes [fsa][arc]
tokens = k2.RaggedTensor(token_shape, best_paths.tokens)
tokens = tokens.remove_values_eq(-1)
return tokens.tolist()
def save_alignments(
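Since tokens is a clone of the decoding graph's aux_labels, each token ID shows up only on the frame where that token starts; blank frames and the remaining frames of a token are 0, which is presumably the "(without repeats)" in the commit title. A made-up illustration of such a framewise alignment and of how token onset times could be recovered from it (frame shift 0.04 s as in ali.py):

frame_shift = 0.04  # 0.01 s features * subsampling factor 4

ali = [0, 0, 25, 0, 0, 0, 7, 0]  # framewise token alignment of one cut
token_onsets = [
    (round(i * frame_shift, 2), tok) for i, tok in enumerate(ali) if tok != 0
]
print(token_onsets)  # [(0.08, 25), (0.24, 7)]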

View File

@@ -44,180 +44,47 @@ from icefall.utils import get_texts
ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500"
-# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
-# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
-# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
-# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
-cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
-ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
cuts_json = egs_dir / "data/token_ali/cuts_test-clean.json.gz"
def data_exists():
-return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
return cuts_json.exists() and lang_dir.exists()
def get_dataloader():
-cuts_train = load_manifest(cut_json)
-cuts_train = cuts_train.with_features_path_prefix(egs_dir)
-train_sampler = SingleCutSampler(
-cuts_train,
cuts = load_manifest(cuts_json)
cuts = cuts.with_features_path_prefix(egs_dir)
sampler = SingleCutSampler(
cuts,
max_duration=40,
shuffle=False,
)
-train = K2SpeechRecognitionDataset(return_cuts=True)
-train_dl = DataLoader(
-train,
-sampler=train_sampler,
dataset = K2SpeechRecognitionDataset(return_cuts=True)
dl = DataLoader(
dataset,
sampler=sampler,
batch_size=None,
num_workers=1,
persistent_workers=False,
)
-return train_dl
return dl
-def test_one_hot():
-a = [1, 3, 2]
-b = [1, 0, 4, 2]
-c = [torch.tensor(a), torch.tensor(b)]
-d = pad_sequence(c, batch_first=True, padding_value=0)
-f = torch.nn.functional.one_hot(d, num_classes=5)
-e = (1 - f) * -10.0
-expected = torch.tensor(
-[
-[
-[-10, 0, -10, -10, -10],
-[-10, -10, -10, 0, -10],
-[-10, -10, 0, -10, -10],
-[0, -10, -10, -10, -10],
-],
-[
-[-10, 0, -10, -10, -10],
-[0, -10, -10, -10, -10],
-[-10, -10, -10, -10, 0],
-[-10, -10, 0, -10, -10],
-],
-]
-).to(e.dtype)
-assert torch.all(torch.eq(e, expected))
def test():
-"""
-The purpose of this test is to show that we can use pre-computed
-alignments to construct a mask, adding it to a randomly generated
-nnet_output, to decode the correct transcript from the resulting
-nnet_output.
-"""
if not data_exists():
return
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
dl = get_dataloader()
-subsampling_factor, ali = load_alignments(ali_filename)
-ali = convert_alignments_to_tensor(ali, device=device)
-lexicon = Lexicon(lang_dir)
-max_token_id = max(lexicon.tokens)
-num_classes = max_token_id + 1 # +1 for the blank
-word_table = lexicon.word_table
-HLG = k2.Fsa.from_dict(
-torch.load(f"{lang_dir}/HLG.pt", map_location=device)
-)
for batch in dl:
-features = batch["inputs"]
supervisions = batch["supervisions"]
-N = features.shape[0]
-T = features.shape[1] // subsampling_factor
-nnet_output = (
-torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
-.softmax(dim=-1)
-.log()
-)
-cut_ids = [cut.id for cut in supervisions["cut"]]
-mask = lookup_alignments(
-cut_ids=cut_ids, alignments=ali, num_classes=num_classes
-)
-min_len = min(nnet_output.shape[1], mask.shape[1])
-ali_model_scale = 0.8
-nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
-supervisions = batch["supervisions"]
-supervision_segments = torch.stack(
-(
-supervisions["sequence_idx"],
-supervisions["start_frame"] // subsampling_factor,
-supervisions["num_frames"] // subsampling_factor,
-),
-1,
-).to(torch.int32)
-lattice = get_lattice(
-nnet_output=nnet_output,
-decoding_graph=HLG,
-supervision_segments=supervision_segments,
-search_beam=20,
-output_beam=8,
-min_active_states=30,
-max_active_states=10000,
-subsampling_factor=subsampling_factor,
-)
-best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
-hyps = get_texts(best_path)
-hyps = [[word_table[i] for i in ids] for ids in hyps]
-hyps = [" ".join(s) for s in hyps]
-print(hyps)
-print(supervisions["text"])
cuts = supervisions["cut"]
print(cuts)
break
-def show_cut_ids():
-# The purpose of this function is to check that
-# for each utterance in the training set, there is
-# a corresponding alignment.
-#
-# After generating a1.txt and b1.txt
-# You can use
-# wc -l a1.txt b1.txt
-# which should show the same number of lines.
-#
-# cat a1.txt | sort | uniq > a11.txt
-# cat b1.txt | sort | uniq > b11.txt
-#
-# md5sum a11.txt b11.txt
-# which should show the identical hash
-#
-# diff a11.txt b11.txt
-# should print nothing
-subsampling_factor, ali = load_alignments(ali_filename)
-with open("a1.txt", "w") as f:
-for key in ali:
-f.write(f"{key}\n")
-# dl = get_dataloader()
-cuts_train = (
-load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
-+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
-+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
-)
-ans = []
-for cut in cuts_train:
-ans.append(cut.id)
-with open("b1.txt", "w") as f:
-for line in ans:
-f.write(f"{line}\n")
if __name__ == "__main__":
test()
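One possible follow-up for the loop above (not part of this commit): load the stored alignment of every cut returned by the dataloader, using the same token_alignment.load() access sketched earlier, and compare its length with the number of subsampled feature frames; the tolerance allows for the conformer's convolutional subsampling:

def check_alignment_lengths(dl, subsampling_factor: int = 4) -> None:
    for batch in dl:
        for cut in batch["supervisions"]["cut"]:
            ali = cut.token_alignment.load()
            num_frames = cut.num_frames // subsampling_factor
            assert abs(len(ali) - num_frames) < 5, (cut.id, len(ali), num_frames)
        break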