Use pre-computed alignments in LF-MMI training.

This commit is contained in:
Fangjun Kuang 2021-09-28 15:37:47 +08:00
parent 9e6bd0f07c
commit 94daaee6ba
4 changed files with 446 additions and 2 deletions

View File

@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
logging.info("About to create train dataset") logging.info("About to create train dataset")
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] transforms = [
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " f"Using cut concatenation with duration factor "

View File

@ -21,7 +21,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional from typing import Dict, Optional
import k2 import k2
import torch import torch
@ -36,6 +36,11 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from transformer import Noam
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
@ -93,6 +98,17 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--ali-dir",
type=str,
default="data/ali_500",
help="""This folder is expected to contain
two files, train-960.pt and valid.pt, which
contain framewise alignment information for
the training set and validation set.
""",
)
return parser return parser
@ -284,6 +300,7 @@ def compute_loss(
batch: dict, batch: dict,
graph_compiler: MmiTrainingGraphCompiler, graph_compiler: MmiTrainingGraphCompiler,
is_training: bool, is_training: bool,
ali: Optional[Dict[str, torch.Tensor]],
): ):
""" """
Compute LF-MMI loss given the model and its inputs. Compute LF-MMI loss given the model and its inputs.
@ -304,6 +321,8 @@ def compute_loss(
True for training. False for validation. When it is True, this True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it function enables autograd during computation; when it is False, it
disables autograd. disables autograd.
ali:
Precomputed alignments.
""" """
device = graph_compiler.device device = graph_compiler.device
feature = batch["inputs"] feature = batch["inputs"]
@ -323,6 +342,30 @@ def compute_loss(
supervisions, subsampling_factor=params.subsampling_factor supervisions, subsampling_factor=params.subsampling_factor
) )
if ali is not None and params.batch_idx_train < 4000:
cut_ids = [cut.id for cut in supervisions["cut"]]
# As encode_supervisions reorders cuts, we need
# also to reorder cut IDs here
new2old = supervision_segments[:, 0].tolist()
cut_ids = [cut_ids[i] for i in new2old]
# Check that new2old is just a permutation,
# i.e., each cut contains only one utterance
new2old.sort()
assert new2old == torch.arange(len(new2old)).tolist()
mask = lookup_alignments(
cut_ids=cut_ids,
alignments=ali,
num_classes=nnet_output.shape[2],
).to(nnet_output)
min_len = min(nnet_output.shape[1], mask.shape[1])
ali_scale = 500.0 / (params.batch_idx_train + 500)
nnet_output = nnet_output.clone()
nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
loss_fn = LFMMILoss( loss_fn = LFMMILoss(
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
use_pruned_intersect=params.use_pruned_intersect, use_pruned_intersect=params.use_pruned_intersect,
@ -377,6 +420,7 @@ def compute_validation_loss(
graph_compiler: MmiTrainingGraphCompiler, graph_compiler: MmiTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
ali: Optional[Dict[str, torch.Tensor]] = None,
) -> None: ) -> None:
"""Run the validation process. The validation loss """Run the validation process. The validation loss
is saved in `params.valid_loss`. is saved in `params.valid_loss`.
@ -394,6 +438,7 @@ def compute_validation_loss(
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=False, is_training=False,
ali=ali,
) )
assert loss.requires_grad is False assert loss.requires_grad is False
assert mmi_loss.requires_grad is False assert mmi_loss.requires_grad is False
@ -435,6 +480,8 @@ def train_one_epoch(
graph_compiler: MmiTrainingGraphCompiler, graph_compiler: MmiTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
train_ali: Optional[Dict[str, torch.Tensor]],
valid_ali: Optional[Dict[str, torch.Tensor]],
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> None:
@ -457,6 +504,10 @@ def train_one_epoch(
Dataloader for the training dataset. Dataloader for the training dataset.
valid_dl: valid_dl:
Dataloader for the validation dataset. Dataloader for the validation dataset.
train_ali:
Precomputed alignments for the training set.
valid_ali:
Precomputed alignments for the validation set.
tb_writer: tb_writer:
Writer to write log messages to tensorboard. Writer to write log messages to tensorboard.
world_size: world_size:
@ -481,6 +532,7 @@ def train_one_epoch(
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=True, is_training=True,
ali=train_ali,
) )
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
@ -565,6 +617,7 @@ def train_one_epoch(
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
valid_dl=valid_dl, valid_dl=valid_dl,
world_size=world_size, world_size=world_size,
ali=valid_ali,
) )
model.train() model.train()
logging.info( logging.info(
@ -673,12 +726,34 @@ def run(rank, world_size, args):
if checkpoints: if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"]) optimizer.load_state_dict(checkpoints["optimizer"])
train_960_ali_filename = Path(params.ali_dir) / "train-960.pt"
if params.batch_idx_train < 4000 and train_960_ali_filename.is_file():
logging.info("Use pre-computed alignments")
subsampling_factor, train_ali = load_alignments(train_960_ali_filename)
assert subsampling_factor == params.subsampling_factor
assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723"
valid_ali_filename = Path(params.ali_dir) / "valid.pt"
subsampling_factor, valid_ali = load_alignments(valid_ali_filename)
assert subsampling_factor == params.subsampling_factor
train_ali = convert_alignments_to_tensor(train_ali, device=device)
valid_ali = convert_alignments_to_tensor(valid_ali, device=device)
else:
logging.info("Not using alignments")
train_ali = None
valid_ali = None
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders() train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders() valid_dl = librispeech.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
if params.batch_idx_train > 4000 and train_ali is not None:
# Delete the alignments to save memory
train_ali = None
valid_ali = None
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
@ -699,6 +774,8 @@ def run(rank, world_size, args):
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
train_dl=train_dl, train_dl=train_dl,
valid_dl=valid_dl, valid_dl=valid_dl,
train_ali=train_ali,
valid_ali=valid_ali,
tb_writer=tb_writer, tb_writer=tb_writer,
world_size=world_size, world_size=world_size,
) )

142
icefall/ali.py Normal file
View File

@ -0,0 +1,142 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
def save_alignments(
alignments: Dict[str, List[int]],
subsampling_factor: int,
filename: str,
) -> None:
"""Save alignments to a file.
Args:
alignments:
A dict containing alignments. Keys of the dict are utterances and
values are the corresponding framewise alignments after subsampling.
subsampling_factor:
The subsampling factor of the model.
filename:
Path to save the alignments.
Returns:
Return None.
"""
ali_dict = {
"subsampling_factor": subsampling_factor,
"alignments": alignments,
}
torch.save(ali_dict, filename)
def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
"""Load alignments from a file.
Args:
filename:
Path to the file containing alignment information.
The file should be saved by :func:`save_alignments`.
Returns:
Return a tuple containing:
- subsampling_factor: The subsampling_factor used to compute
the alignments.
- alignments: A dict containing utterances and their corresponding
framewise alignment, after subsampling.
"""
ali_dict = torch.load(filename)
subsampling_factor = ali_dict["subsampling_factor"]
alignments = ali_dict["alignments"]
return subsampling_factor, alignments
def convert_alignments_to_tensor(
alignments: Dict[str, List[int]], device: torch.device
) -> Dict[str, torch.Tensor]:
"""Convert alignments from list of int to a 1-D torch.Tensor.
Args:
alignments:
A dict containing alignments. Keys are utterance IDs and
values are their corresponding frame-wise alignments.
device:
The device to move the alignments to.
Returns:
Return a dict using 1-D torch.Tensor to store the alignments.
The dtype of the tensor are `torch.int64`. We choose `torch.int64`
because `torch.nn.functional.one_hot` requires that.
"""
ans = {}
for utt_id, ali in alignments.items():
ali = torch.tensor(ali, dtype=torch.int64, device=device)
ans[utt_id] = ali
return ans
def lookup_alignments(
cut_ids: List[str],
alignments: Dict[str, torch.Tensor],
num_classes: int,
log_score: float = -10,
) -> torch.Tensor:
"""Return a mask constructed from alignments by a list of cut IDs.
The returned mask is a 3-D tensor of shape (N, T, C). For each frame,
i.e., each row, of the returned mask, positions not corresponding to
the alignments are filled with `log_score`, while the position
specified by the alignment is filled with 0. For instance, if the alignments
of two utterances are:
[ [1, 3, 2], [1, 0, 4, 2] ]
num_classes is 5 and log_score is -10, then the returned mask is
[
[[-10, 0, -10, -10, -10],
[-10, -10, -10, 0, -10],
[-10, -10, 0, -10, -10],
[0, -10, -10, -10, -10]],
[[-10, 0, -10, -10, -10],
[0, -10, -10, -10, -10],
[-10, -10, -10, -10, 0],
[-10, -10, 0, -10, -10]]
]
Note: We pad the alignment of the first utterance with 0.
Args:
cut_ids:
A list of utterance IDs.
alignments:
A dict containing alignments. The keys are utterance IDs and the values
are framewise alignments.
num_classes:
The max token ID + 1 that appears in the alignments.
log_score:
Positions in the returned tensor not corresponding to the alignments
are filled with this value.
Returns:
Return a 3-D torch.float32 tensor of shape (N, T, C).
"""
# We assume all utterances have their alignments.
ali = [alignments[cut_id] for cut_id in cut_ids]
padded_ali = pad_sequence(ali, batch_first=True, padding_value=0)
padded_one_hot = torch.nn.functional.one_hot(
padded_ali,
num_classes=num_classes,
)
mask = (1 - padded_one_hot) * float(log_score)
return mask

223
test/test_ali.py Executable file
View File

@ -0,0 +1,223 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Runt his file using one of the following two ways:
# (1) python3 ./test/test_ali.py
# (2) pytest ./test/test_ali.py
# The purpose of this file is to show that if we build a mask
# from alignments and add it to a randomly generated nnet_output,
# we can decode the correct transcript.
from pathlib import Path
import k2
import torch
from lhotse import load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import get_texts
ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
# cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
def data_exists():
return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
def get_dataloader():
cuts_train = load_manifest(cut_json)
cuts_train = cuts_train.with_features_path_prefix(egs_dir)
train_sampler = SingleCutSampler(
cuts_train,
max_duration=200,
shuffle=False,
)
train = K2SpeechRecognitionDataset(return_cuts=True)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=1,
persistent_workers=False,
)
return train_dl
def test_one_hot():
a = [1, 3, 2]
b = [1, 0, 4, 2]
c = [torch.tensor(a), torch.tensor(b)]
d = pad_sequence(c, batch_first=True, padding_value=0)
f = torch.nn.functional.one_hot(d, num_classes=5)
e = (1 - f) * -10.0
expected = torch.tensor(
[
[
[-10, 0, -10, -10, -10],
[-10, -10, -10, 0, -10],
[-10, -10, 0, -10, -10],
[0, -10, -10, -10, -10],
],
[
[-10, 0, -10, -10, -10],
[0, -10, -10, -10, -10],
[-10, -10, -10, -10, 0],
[-10, -10, 0, -10, -10],
],
]
).to(e.dtype)
assert torch.all(torch.eq(e, expected))
def test():
"""
The purpose of this test is to show that we can use pre-computed
alignments to construct a mask, adding it to a randomly generated
nnet_output, to decode the correct transcript from the resulting
nnet_output.
"""
if not data_exists():
return
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
dl = get_dataloader()
subsampling_factor, ali = load_alignments(ali_filename)
ali = convert_alignments_to_tensor(ali, device=device)
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
word_table = lexicon.word_table
HLG = k2.Fsa.from_dict(
torch.load(f"{lang_dir}/HLG.pt", map_location=device)
)
for batch in dl:
features = batch["inputs"]
supervisions = batch["supervisions"]
N = features.shape[0]
T = features.shape[1] // subsampling_factor
nnet_output = (
torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
.softmax(dim=-1)
.log()
)
cut_ids = [cut.id for cut in supervisions["cut"]]
mask = lookup_alignments(
cut_ids=cut_ids, alignments=ali, num_classes=num_classes
)
min_len = min(nnet_output.shape[1], mask.shape[1])
ali_model_scale = 0.8
nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // subsampling_factor,
supervisions["num_frames"] // subsampling_factor,
),
1,
).to(torch.int32)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=20,
output_beam=8,
min_active_states=30,
max_active_states=10000,
subsampling_factor=subsampling_factor,
)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
hyps = [" ".join(s) for s in hyps]
print(hyps)
print(supervisions["text"])
break
def show_cut_ids():
# The purpose of this function is to check that
# for each utterance in the training set, there is
# a corresponding alignment.
#
# After generating a1.txt and b1.txt
# You can use
# wc -l a1.txt b1.txt
# which should show the same number of lines.
#
# cat a1.txt | sort | uniq > a11.txt
# cat b1.txt | sort | uniq > b11.txt
#
# md5sum a11.txt b11.txt
# which should show the identical hash
#
# diff a11.txt b11.txt
# should print nothing
subsampling_factor, ali = load_alignments(ali_filename)
with open("a1.txt", "w") as f:
for key in ali:
f.write(f"{key}\n")
# dl = get_dataloader()
cuts_train = (
load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
)
ans = []
for cut in cuts_train:
ans.append(cut.id)
with open("b1.txt", "w") as f:
for line in ans:
f.write(f"{line}\n")
if __name__ == "__main__":
test()