mirror of https://github.com/k2-fsa/icefall.git
Use pre-computed alignments in LF-MMI training.
This commit is contained in:
parent 9e6bd0f07c · commit 94daaee6ba
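In brief: for the first 4000 training batches, a mask derived from framewise alignments is added to the network output before the LF-MMI loss is computed, with a weight that decays as training proceeds. A minimal sketch of the idea using toy shapes (only the decay constant 500, the 4000-batch cutoff, and the -10/0 mask convention come from the diff below; everything else is illustrative):

import torch

# Toy stand-in for the network output: (N, T, C) log-probs.
nnet_output = torch.randn(2, 10, 5).log_softmax(dim=-1)

# Toy mask: 0 at the aligned position of each frame, -10 elsewhere
# (built by lookup_alignments in icefall/ali.py below).
mask = torch.full((2, 10, 5), -10.0)
mask[:, :, 1] = 0.0  # pretend every frame is aligned to token 1

batch_idx_train = 100  # hypothetical training step
if batch_idx_train < 4000:
    ali_scale = 500.0 / (batch_idx_train + 500)  # decays from 1.0 toward 0
    min_len = min(nnet_output.shape[1], mask.shape[1])
    nnet_output = nnet_output.clone()
    nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]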
@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
         cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

         logging.info("About to create train dataset")
-        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+        transforms = [
+            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+        ]
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
@@ -21,7 +21,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Dict, Optional

 import k2
 import torch
@@ -36,6 +36,11 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

+from icefall.ali import (
+    convert_alignments_to_tensor,
+    load_alignments,
+    lookup_alignments,
+)
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
@@ -93,6 +98,17 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--ali-dir",
+        type=str,
+        default="data/ali_500",
+        help="""This folder is expected to contain
+        two files, train-960.pt and valid.pt, which
+        contain framewise alignment information for
+        the training set and validation set.
+        """,
+    )
+
     return parser

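The two files under --ali-dir are ordinary torch checkpoints with the layout written by save_alignments in icefall/ali.py (added below): a dict with the keys "subsampling_factor" and "alignments". A quick, hypothetical way to inspect one (not part of the commit; the path is the flag's default):

import torch

ali_dict = torch.load("data/ali_500/valid.pt")
print(ali_dict["subsampling_factor"])  # e.g., 4
alignments = ali_dict["alignments"]  # Dict[str, List[int]]
utt_id, ali = next(iter(alignments.items()))
print(utt_id, len(ali))  # an utterance ID and its number of (subsampled) frames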
@@ -284,6 +300,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: MmiTrainingGraphCompiler,
     is_training: bool,
+    ali: Optional[Dict[str, torch.Tensor]],
 ):
     """
     Compute LF-MMI loss given the model and its inputs.
@@ -304,6 +321,8 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
+      ali:
+        Precomputed alignments.
     """
     device = graph_compiler.device
     feature = batch["inputs"]
@@ -323,6 +342,30 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )

+    if ali is not None and params.batch_idx_train < 4000:
+        cut_ids = [cut.id for cut in supervisions["cut"]]
+
+        # As encode_supervisions reorders cuts, we need
+        # to reorder the cut IDs here as well.
+        new2old = supervision_segments[:, 0].tolist()
+        cut_ids = [cut_ids[i] for i in new2old]
+
+        # Check that new2old is just a permutation,
+        # i.e., each cut contains only one utterance.
+        new2old.sort()
+        assert new2old == torch.arange(len(new2old)).tolist()
+        mask = lookup_alignments(
+            cut_ids=cut_ids,
+            alignments=ali,
+            num_classes=nnet_output.shape[2],
+        ).to(nnet_output)
+
+        min_len = min(nnet_output.shape[1], mask.shape[1])
+        ali_scale = 500.0 / (params.batch_idx_train + 500)
+
+        nnet_output = nnet_output.clone()
+        nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
+
     loss_fn = LFMMILoss(
         graph_compiler=graph_compiler,
         use_pruned_intersect=params.use_pruned_intersect,
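The scale 500.0 / (params.batch_idx_train + 500) starts at 1.0 and decays smoothly, so the alignment hint dominates early and fades out; together with the < 4000 guard above, the mask stops being applied entirely after 4000 batches. A few sample values of the schedule (plain arithmetic, shown for orientation):

for batch_idx in [0, 500, 1000, 2000, 3999]:
    print(batch_idx, round(500.0 / (batch_idx + 500), 4))
# 0 1.0
# 500 0.5
# 1000 0.3333
# 2000 0.2
# 3999 0.1111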
@@ -377,6 +420,7 @@ def compute_validation_loss(
     graph_compiler: MmiTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
+    ali: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
@@ -394,6 +438,7 @@ def compute_validation_loss(
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=False,
+            ali=ali,
         )
         assert loss.requires_grad is False
         assert mmi_loss.requires_grad is False
@@ -435,6 +480,8 @@ def train_one_epoch(
     graph_compiler: MmiTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    train_ali: Optional[Dict[str, torch.Tensor]],
+    valid_ali: Optional[Dict[str, torch.Tensor]],
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
 ) -> None:
@@ -457,6 +504,10 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      train_ali:
+        Precomputed alignments for the training set.
+      valid_ali:
+        Precomputed alignments for the validation set.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -481,6 +532,7 @@ def train_one_epoch(
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
+            ali=train_ali,
         )

         # NOTE: We use reduction==sum and loss is computed over utterances
@@ -565,6 +617,7 @@ def train_one_epoch(
                 graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
+                ali=valid_ali,
             )
             model.train()
             logging.info(
@@ -673,12 +726,34 @@ def run(rank, world_size, args):
     if checkpoints:
         optimizer.load_state_dict(checkpoints["optimizer"])

+    train_960_ali_filename = Path(params.ali_dir) / "train-960.pt"
+    if params.batch_idx_train < 4000 and train_960_ali_filename.is_file():
+        logging.info("Use pre-computed alignments")
+        subsampling_factor, train_ali = load_alignments(train_960_ali_filename)
+        assert subsampling_factor == params.subsampling_factor
+        assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723"
+
+        valid_ali_filename = Path(params.ali_dir) / "valid.pt"
+        subsampling_factor, valid_ali = load_alignments(valid_ali_filename)
+        assert subsampling_factor == params.subsampling_factor
+
+        train_ali = convert_alignments_to_tensor(train_ali, device=device)
+        valid_ali = convert_alignments_to_tensor(valid_ali, device=device)
+    else:
+        logging.info("Not using alignments")
+        train_ali = None
+        valid_ali = None
+
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()

     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
+        if params.batch_idx_train > 4000 and train_ali is not None:
+            # Delete the alignments to save memory
+            train_ali = None
+            valid_ali = None
+
         cur_lr = optimizer._rate
         if tb_writer is not None:
@@ -699,6 +774,8 @@ def run(rank, world_size, args):
         graph_compiler=graph_compiler,
         train_dl=train_dl,
         valid_dl=valid_dl,
+        train_ali=train_ali,
+        valid_ali=valid_ali,
         tb_writer=tb_writer,
         world_size=world_size,
     )

icefall/ali.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Save alignments to a file.

    Args:
      alignments:
        A dict containing alignments. Keys of the dict are utterances and
        values are the corresponding framewise alignments after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      filename:
        Path to save the alignments.
    Returns:
      Return None.
    """
    ali_dict = {
        "subsampling_factor": subsampling_factor,
        "alignments": alignments,
    }
    torch.save(ali_dict, filename)

def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Load alignments from a file.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling factor used to compute
          the alignments.
        - alignments: A dict containing utterances and their corresponding
          framewise alignments, after subsampling.
    """
    ali_dict = torch.load(filename)
    subsampling_factor = ali_dict["subsampling_factor"]
    alignments = ali_dict["alignments"]
    return subsampling_factor, alignments

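A minimal round trip through the helpers in this file (save_alignments and load_alignments above, convert_alignments_to_tensor just below); the utterance IDs, alignments, and file name are toy values chosen for illustration:

import torch

from icefall.ali import (
    convert_alignments_to_tensor,
    load_alignments,
    save_alignments,
)

toy_ali = {"utt-1": [1, 3, 2], "utt-2": [1, 0, 4, 2]}
save_alignments(toy_ali, subsampling_factor=4, filename="toy_ali.pt")

subsampling_factor, loaded = load_alignments("toy_ali.pt")
assert subsampling_factor == 4
assert loaded == toy_ali

# Lists become 1-D int64 tensors, ready for torch.nn.functional.one_hot.
tensors = convert_alignments_to_tensor(loaded, device=torch.device("cpu"))
assert tensors["utt-1"].dtype == torch.int64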
def convert_alignments_to_tensor(
    alignments: Dict[str, List[int]], device: torch.device
) -> Dict[str, torch.Tensor]:
    """Convert alignments from lists of int to 1-D torch.Tensors.

    Args:
      alignments:
        A dict containing alignments. Keys are utterance IDs and
        values are their corresponding framewise alignments.
      device:
        The device to move the alignments to.
    Returns:
      Return a dict using 1-D torch.Tensors to store the alignments.
      The dtype of the tensors is `torch.int64`. We choose `torch.int64`
      because `torch.nn.functional.one_hot` requires that.
    """
    ans = {}
    for utt_id, ali in alignments.items():
        ali = torch.tensor(ali, dtype=torch.int64, device=device)
        ans[utt_id] = ali
    return ans

def lookup_alignments(
    cut_ids: List[str],
    alignments: Dict[str, torch.Tensor],
    num_classes: int,
    log_score: float = -10,
) -> torch.Tensor:
    """Return a mask constructed from the alignments of a list of cuts.

    The returned mask is a 3-D tensor of shape (N, T, C). In each frame,
    i.e., each row, of the returned mask, the position specified by the
    alignment is filled with 0, while all other positions are filled with
    `log_score`. For instance, if the alignments of two utterances are:

        [ [1, 3, 2], [1, 0, 4, 2] ]

    num_classes is 5, and log_score is -10, then the returned mask is

        [
            [[-10, 0, -10, -10, -10],
             [-10, -10, -10, 0, -10],
             [-10, -10, 0, -10, -10],
             [0, -10, -10, -10, -10]],
            [[-10, 0, -10, -10, -10],
             [0, -10, -10, -10, -10],
             [-10, -10, -10, -10, 0],
             [-10, -10, 0, -10, -10]]
        ]

    Note: The alignment of the first utterance is padded with 0 to the
    length of the longest alignment.

    Args:
      cut_ids:
        A list of utterance IDs.
      alignments:
        A dict containing alignments. The keys are utterance IDs and the
        values are framewise alignments.
      num_classes:
        The max token ID + 1 that appears in the alignments.
      log_score:
        Positions in the returned tensor not corresponding to the
        alignments are filled with this value.
    Returns:
      Return a 3-D torch.float32 tensor of shape (N, T, C).
    """
    # We assume all utterances have their alignments.
    ali = [alignments[cut_id] for cut_id in cut_ids]
    padded_ali = pad_sequence(ali, batch_first=True, padding_value=0)
    padded_one_hot = torch.nn.functional.one_hot(
        padded_ali,
        num_classes=num_classes,
    )
    mask = (1 - padded_one_hot) * float(log_score)
    return mask
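A short sanity check that lookup_alignments reproduces the matrix from its docstring; the cut IDs are toy values, and the alignments are the same as in the example above:

import torch

from icefall.ali import convert_alignments_to_tensor, lookup_alignments

ali = convert_alignments_to_tensor(
    {"a": [1, 3, 2], "b": [1, 0, 4, 2]}, device=torch.device("cpu")
)
mask = lookup_alignments(cut_ids=["a", "b"], alignments=ali, num_classes=5)
assert mask.shape == (2, 4, 5)  # (N, T, C); T is the longest alignment
assert mask[0, 0, 1] == 0  # the aligned position is kept at 0
assert mask[0, 0, 0] == -10  # every other class gets log_score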
test/test_ali.py (new executable file, 223 lines)
@@ -0,0 +1,223 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Run this file in one of the following two ways:
#  (1) python3 ./test/test_ali.py
#  (2) pytest ./test/test_ali.py

# The purpose of this file is to show that if we build a mask
# from alignments and add it to a randomly generated nnet_output,
# we can decode the correct transcript.

from pathlib import Path

import k2
import torch
from lhotse import load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from icefall.ali import (
    convert_alignments_to_tensor,
    load_alignments,
    lookup_alignments,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import get_texts

ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500"
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"

# cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"

def data_exists():
    return ali_filename.exists() and cut_json.exists() and lang_dir.exists()


def get_dataloader():
    cuts_train = load_manifest(cut_json)
    cuts_train = cuts_train.with_features_path_prefix(egs_dir)
    train_sampler = SingleCutSampler(
        cuts_train,
        max_duration=200,
        shuffle=False,
    )

    train = K2SpeechRecognitionDataset(return_cuts=True)

    train_dl = DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=1,
        persistent_workers=False,
    )
    return train_dl

def test_one_hot():
    a = [1, 3, 2]
    b = [1, 0, 4, 2]
    c = [torch.tensor(a), torch.tensor(b)]
    d = pad_sequence(c, batch_first=True, padding_value=0)
    f = torch.nn.functional.one_hot(d, num_classes=5)
    e = (1 - f) * -10.0
    expected = torch.tensor(
        [
            [
                [-10, 0, -10, -10, -10],
                [-10, -10, -10, 0, -10],
                [-10, -10, 0, -10, -10],
                [0, -10, -10, -10, -10],
            ],
            [
                [-10, 0, -10, -10, -10],
                [0, -10, -10, -10, -10],
                [-10, -10, -10, -10, 0],
                [-10, -10, 0, -10, -10],
            ],
        ]
    ).to(e.dtype)
    assert torch.all(torch.eq(e, expected))

def test():
    """
    The purpose of this test is to show that we can use pre-computed
    alignments to construct a mask, add it to a randomly generated
    nnet_output, and decode the correct transcript from the resulting
    nnet_output.
    """
    if not data_exists():
        return
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    dl = get_dataloader()

    subsampling_factor, ali = load_alignments(ali_filename)
    ali = convert_alignments_to_tensor(ali, device=device)

    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank
    word_table = lexicon.word_table

    HLG = k2.Fsa.from_dict(
        torch.load(f"{lang_dir}/HLG.pt", map_location=device)
    )

    for batch in dl:
        features = batch["inputs"]
        supervisions = batch["supervisions"]
        N = features.shape[0]
        T = features.shape[1] // subsampling_factor
        nnet_output = (
            torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
            .softmax(dim=-1)
            .log()
        )
        cut_ids = [cut.id for cut in supervisions["cut"]]
        mask = lookup_alignments(
            cut_ids=cut_ids, alignments=ali, num_classes=num_classes
        )
        min_len = min(nnet_output.shape[1], mask.shape[1])
        ali_model_scale = 0.8

        nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]

        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                supervisions["start_frame"] // subsampling_factor,
                supervisions["num_frames"] // subsampling_factor,
            ),
            1,
        ).to(torch.int32)

        lattice = get_lattice(
            nnet_output=nnet_output,
            HLG=HLG,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=8,
            min_active_states=30,
            max_active_states=10000,
            subsampling_factor=subsampling_factor,
        )

        best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
        hyps = [" ".join(s) for s in hyps]
        print(hyps)
        print(supervisions["text"])
        break

def show_cut_ids():
    # The purpose of this function is to check that
    # for each utterance in the training set there is
    # a corresponding alignment.
    #
    # After generating a1.txt and b1.txt, running
    #   wc -l a1.txt b1.txt
    # should show the same number of lines. Then
    #
    #   cat a1.txt | sort | uniq > a11.txt
    #   cat b1.txt | sort | uniq > b11.txt
    #
    #   md5sum a11.txt b11.txt
    # should show identical hashes, and
    #
    #   diff a11.txt b11.txt
    # should print nothing.

    subsampling_factor, ali = load_alignments(ali_filename)
    with open("a1.txt", "w") as f:
        for key in ali:
            f.write(f"{key}\n")

    # dl = get_dataloader()
    cuts_train = (
        load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
        + load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
        + load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
    )

    ans = []
    for cut in cuts_train:
        ans.append(cut.id)
    with open("b1.txt", "w") as f:
        for line in ans:
            f.write(f"{line}\n")

if __name__ == "__main__":
    test()