Use pre-computed alignments in LF-MMI training.

Author: Fangjun Kuang
Date:   2021-09-28 15:37:47 +08:00
Parent: 9e6bd0f07c
Commit: 94daaee6ba
4 changed files with 446 additions and 2 deletions


@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
         cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

         logging.info("About to create train dataset")
-        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+        transforms = [
+            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+        ]
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "


@@ -21,7 +21,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Dict, Optional

 import k2
 import torch
@@ -36,6 +36,11 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

+from icefall.ali import (
+    convert_alignments_to_tensor,
+    load_alignments,
+    lookup_alignments,
+)
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
@@ -93,6 +98,17 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--ali-dir",
+        type=str,
+        default="data/ali_500",
+        help="""This folder is expected to contain
+        two files, train-960.pt and valid.pt, which
+        contain framewise alignment information for
+        the training set and validation set.
+        """,
+    )
+
     return parser
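Note: the commit assumes data/ali_500/train-960.pt and valid.pt already exist. As a purely hypothetical sketch of the expected file format (the utterance ID, token IDs, and subsampling factor below are invented), such a file could be written with save_alignments from icefall/ali.py, which this same commit introduces:

# Hypothetical example: write a file in the format that --ali-dir expects.
# Real alignments come from a trained model; these values are made up.
from icefall.ali import save_alignments

dummy_ali = {"utt-0001": [2, 2, 7, 7, 3, 0]}  # framewise token IDs
save_alignments(dummy_ali, subsampling_factor=4, filename="valid.pt")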
@@ -284,6 +300,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: MmiTrainingGraphCompiler,
     is_training: bool,
+    ali: Optional[Dict[str, torch.Tensor]],
 ):
     """
     Compute LF-MMI loss given the model and its inputs.
@@ -304,6 +321,8 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
+      ali:
+        Precomputed alignments.
     """
     device = graph_compiler.device
     feature = batch["inputs"]
@@ -323,6 +342,30 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )

+    if ali is not None and params.batch_idx_train < 4000:
+        cut_ids = [cut.id for cut in supervisions["cut"]]
+        # As encode_supervisions reorders cuts, we also need
+        # to reorder cut IDs here
+        new2old = supervision_segments[:, 0].tolist()
+        cut_ids = [cut_ids[i] for i in new2old]
+
+        # Check that new2old is just a permutation,
+        # i.e., each cut contains only one utterance
+        new2old.sort()
+        assert new2old == torch.arange(len(new2old)).tolist()
+
+        mask = lookup_alignments(
+            cut_ids=cut_ids,
+            alignments=ali,
+            num_classes=nnet_output.shape[2],
+        ).to(nnet_output)
+
+        min_len = min(nnet_output.shape[1], mask.shape[1])
+        ali_scale = 500.0 / (params.batch_idx_train + 500)
+
+        nnet_output = nnet_output.clone()
+        nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
+
     loss_fn = LFMMILoss(
         graph_compiler=graph_compiler,
         use_pruned_intersect=params.use_pruned_intersect,
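As an aside, the hunk above decays the weight of the alignment mask as training progresses. A standalone sketch (illustrative only, not part of the commit) of how ali_scale evolves with params.batch_idx_train:

# Illustrative only: ali_scale starts at 1.0 and decays towards 0;
# the mask stops being applied once batch_idx_train reaches 4000.
for batch_idx_train in (0, 500, 1500, 3999, 4000):
    ali_scale = 500.0 / (batch_idx_train + 500)
    print(batch_idx_train, round(ali_scale, 3), batch_idx_train < 4000)
# 0 1.0 True
# 500 0.5 True
# 1500 0.25 True
# 3999 0.111 True
# 4000 0.111 False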
@@ -377,6 +420,7 @@ def compute_validation_loss(
     graph_compiler: MmiTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
+    ali: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
@@ -394,6 +438,7 @@ def compute_validation_loss(
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=False,
+            ali=ali,
         )
         assert loss.requires_grad is False
         assert mmi_loss.requires_grad is False
@@ -435,6 +480,8 @@ def train_one_epoch(
     graph_compiler: MmiTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    train_ali: Optional[Dict[str, torch.Tensor]],
+    valid_ali: Optional[Dict[str, torch.Tensor]],
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
 ) -> None:
@@ -457,6 +504,10 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      train_ali:
+        Precomputed alignments for the training set.
+      valid_ali:
+        Precomputed alignments for the validation set.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -481,6 +532,7 @@ def train_one_epoch(
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
+            ali=train_ali,
         )

         # NOTE: We use reduction==sum and loss is computed over utterances
@@ -565,6 +617,7 @@ def train_one_epoch(
                 graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
+                ali=valid_ali,
             )
             model.train()
             logging.info(
@@ -673,12 +726,34 @@ def run(rank, world_size, args):
     if checkpoints:
         optimizer.load_state_dict(checkpoints["optimizer"])

+    train_960_ali_filename = Path(params.ali_dir) / "train-960.pt"
+    if params.batch_idx_train < 4000 and train_960_ali_filename.is_file():
+        logging.info("Use pre-computed alignments")
+        subsampling_factor, train_ali = load_alignments(train_960_ali_filename)
+        assert subsampling_factor == params.subsampling_factor
+        assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723"
+
+        valid_ali_filename = Path(params.ali_dir) / "valid.pt"
+        subsampling_factor, valid_ali = load_alignments(valid_ali_filename)
+        assert subsampling_factor == params.subsampling_factor
+
+        train_ali = convert_alignments_to_tensor(train_ali, device=device)
+        valid_ali = convert_alignments_to_tensor(valid_ali, device=device)
+    else:
+        logging.info("Not using alignments")
+        train_ali = None
+        valid_ali = None
+
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()

     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)

+        if params.batch_idx_train > 4000 and train_ali is not None:
+            # Delete the alignments to save memory
+            train_ali = None
+            valid_ali = None
+
         cur_lr = optimizer._rate
         if tb_writer is not None:
@@ -699,6 +774,8 @@ def run(rank, world_size, args):
             graph_compiler=graph_compiler,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            train_ali=train_ali,
+            valid_ali=valid_ali,
             tb_writer=tb_writer,
             world_size=world_size,
         )

icefall/ali.py (new file, 142 lines)

@@ -0,0 +1,142 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Save alignments to a file.

    Args:
      alignments:
        A dict containing alignments. Keys of the dict are utterance IDs and
        values are the corresponding framewise alignments after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      filename:
        Path to save the alignments.
    Returns:
      Return None.
    """
    ali_dict = {
        "subsampling_factor": subsampling_factor,
        "alignments": alignments,
    }
    torch.save(ali_dict, filename)


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Load alignments from a file.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling factor used to compute
          the alignments.
        - alignments: A dict containing utterances and their corresponding
          framewise alignments, after subsampling.
    """
    ali_dict = torch.load(filename)
    subsampling_factor = ali_dict["subsampling_factor"]
    alignments = ali_dict["alignments"]
    return subsampling_factor, alignments


def convert_alignments_to_tensor(
    alignments: Dict[str, List[int]], device: torch.device
) -> Dict[str, torch.Tensor]:
    """Convert alignments from lists of int to 1-D torch.Tensors.

    Args:
      alignments:
        A dict containing alignments. Keys are utterance IDs and
        values are their corresponding framewise alignments.
      device:
        The device to move the alignments to.
    Returns:
      Return a dict that stores the alignments as 1-D torch.Tensors.
      The dtype of the tensors is `torch.int64`, because
      `torch.nn.functional.one_hot` requires that.
    """
    ans = {}
    for utt_id, ali in alignments.items():
        ali = torch.tensor(ali, dtype=torch.int64, device=device)
        ans[utt_id] = ali
    return ans


def lookup_alignments(
    cut_ids: List[str],
    alignments: Dict[str, torch.Tensor],
    num_classes: int,
    log_score: float = -10,
) -> torch.Tensor:
    """Return a mask constructed from the alignments of a list of cut IDs.

    The returned mask is a 3-D tensor of shape (N, T, C). In each frame,
    i.e., each row, of the returned mask, positions not corresponding to
    the alignment are filled with `log_score`, while the position
    specified by the alignment is filled with 0. For instance, if the
    alignments of two utterances are:

        [ [1, 3, 2], [1, 0, 4, 2] ]

    num_classes is 5 and log_score is -10, then the returned mask is

        [
            [[-10, 0, -10, -10, -10],
             [-10, -10, -10, 0, -10],
             [-10, -10, 0, -10, -10],
             [0, -10, -10, -10, -10]],
            [[-10, 0, -10, -10, -10],
             [0, -10, -10, -10, -10],
             [-10, -10, -10, -10, 0],
             [-10, -10, 0, -10, -10]]
        ]

    Note: The alignment of the first utterance is shorter than that of the
    second, so it is padded with 0, which is why its last row has a 0 in
    column 0.

    Args:
      cut_ids:
        A list of utterance IDs.
      alignments:
        A dict containing alignments. The keys are utterance IDs and the
        values are framewise alignments.
      num_classes:
        The maximum token ID in the alignments plus one.
      log_score:
        Positions in the returned tensor not corresponding to the alignments
        are filled with this value.
    Returns:
      Return a 3-D torch.float32 tensor of shape (N, T, C).
    """
    # We assume all utterances have their alignments.
    ali = [alignments[cut_id] for cut_id in cut_ids]
    padded_ali = pad_sequence(ali, batch_first=True, padding_value=0)
    padded_one_hot = torch.nn.functional.one_hot(
        padded_ali,
        num_classes=num_classes,
    )
    mask = (1 - padded_one_hot) * float(log_score)
    return mask
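A minimal usage sketch of the four helpers above (not part of the commit; the utterance IDs and alignments are invented):

import torch

from icefall.ali import (
    convert_alignments_to_tensor,
    load_alignments,
    lookup_alignments,
    save_alignments,
)

ali = {"utt1": [1, 3, 2], "utt2": [1, 0, 4, 2]}
save_alignments(ali, subsampling_factor=4, filename="ali.pt")

subsampling_factor, loaded = load_alignments("ali.pt")
assert subsampling_factor == 4

loaded = convert_alignments_to_tensor(loaded, device=torch.device("cpu"))

# num_classes must be greater than the largest token ID (4 here).
mask = lookup_alignments(cut_ids=["utt1", "utt2"], alignments=loaded, num_classes=5)
assert mask.shape == (2, 4, 5)  # (N, T, C); T is the longest alignment length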

test/test_ali.py (new executable file, 223 lines)

@@ -0,0 +1,223 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Run this file in one of the following two ways:
#   (1) python3 ./test/test_ali.py
#   (2) pytest ./test/test_ali.py
#
# The purpose of this file is to show that if we build a mask
# from alignments and add it to a randomly generated nnet_output,
# we can decode the correct transcript.

from pathlib import Path

import k2
import torch
from lhotse import load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from icefall.ali import (
    convert_alignments_to_tensor,
    load_alignments,
    lookup_alignments,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import get_texts

ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500"

# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"

# cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"


def data_exists():
    return ali_filename.exists() and cut_json.exists() and lang_dir.exists()


def get_dataloader():
    cuts_train = load_manifest(cut_json)
    cuts_train = cuts_train.with_features_path_prefix(egs_dir)
    train_sampler = SingleCutSampler(
        cuts_train,
        max_duration=200,
        shuffle=False,
    )
    train = K2SpeechRecognitionDataset(return_cuts=True)
    train_dl = DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=1,
        persistent_workers=False,
    )
    return train_dl


def test_one_hot():
    a = [1, 3, 2]
    b = [1, 0, 4, 2]
    c = [torch.tensor(a), torch.tensor(b)]
    d = pad_sequence(c, batch_first=True, padding_value=0)
    f = torch.nn.functional.one_hot(d, num_classes=5)
    e = (1 - f) * -10.0
    expected = torch.tensor(
        [
            [
                [-10, 0, -10, -10, -10],
                [-10, -10, -10, 0, -10],
                [-10, -10, 0, -10, -10],
                [0, -10, -10, -10, -10],
            ],
            [
                [-10, 0, -10, -10, -10],
                [0, -10, -10, -10, -10],
                [-10, -10, -10, -10, 0],
                [-10, -10, 0, -10, -10],
            ],
        ]
    ).to(e.dtype)
    assert torch.all(torch.eq(e, expected))


def test():
    """
    The purpose of this test is to show that we can use pre-computed
    alignments to construct a mask, add it to a randomly generated
    nnet_output, and decode the correct transcript from the resulting
    nnet_output.
    """
    if not data_exists():
        return
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    dl = get_dataloader()
    subsampling_factor, ali = load_alignments(ali_filename)
    ali = convert_alignments_to_tensor(ali, device=device)
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    word_table = lexicon.word_table
    HLG = k2.Fsa.from_dict(
        torch.load(f"{lang_dir}/HLG.pt", map_location=device)
    )

    for batch in dl:
        features = batch["inputs"]
        supervisions = batch["supervisions"]
        N = features.shape[0]
        T = features.shape[1] // subsampling_factor
        nnet_output = (
            torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
            .softmax(dim=-1)
            .log()
        )
        cut_ids = [cut.id for cut in supervisions["cut"]]
        mask = lookup_alignments(
            cut_ids=cut_ids, alignments=ali, num_classes=num_classes
        )
        min_len = min(nnet_output.shape[1], mask.shape[1])
        ali_model_scale = 0.8

        nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]

        supervisions = batch["supervisions"]
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                supervisions["start_frame"] // subsampling_factor,
                supervisions["num_frames"] // subsampling_factor,
            ),
            1,
        ).to(torch.int32)

        lattice = get_lattice(
            nnet_output=nnet_output,
            HLG=HLG,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=8,
            min_active_states=30,
            max_active_states=10000,
            subsampling_factor=subsampling_factor,
        )
        best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
        hyps = [" ".join(s) for s in hyps]
        print(hyps)
        print(supervisions["text"])
        break


def show_cut_ids():
    # The purpose of this function is to check that
    # for each utterance in the training set, there is
    # a corresponding alignment.
    #
    # After generating a1.txt and b1.txt, you can use
    #
    #   wc -l a1.txt b1.txt
    #
    # which should show the same number of lines.
    #
    #   cat a1.txt | sort | uniq > a11.txt
    #   cat b1.txt | sort | uniq > b11.txt
    #
    #   md5sum a11.txt b11.txt
    #
    # which should show identical hashes, and
    #
    #   diff a11.txt b11.txt
    #
    # should print nothing.
    subsampling_factor, ali = load_alignments(ali_filename)
    with open("a1.txt", "w") as f:
        for key in ali:
            f.write(f"{key}\n")

    # dl = get_dataloader()
    cuts_train = (
        load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
        + load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
        + load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
    )
    ans = []
    for cut in cuts_train:
        ans.append(cut.id)
    with open("b1.txt", "w") as f:
        for line in ans:
            f.write(f"{line}\n")


if __name__ == "__main__":
    test()