Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Use pre-computed alignments in LF-MMI training.

parent 9e6bd0f07c
commit 94daaee6ba
@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
         cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

         logging.info("About to create train dataset")
-        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+        transforms = [
+            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+        ]
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
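Note on `preserve_id=True`: CutMix otherwise assigns new IDs to the mixed cuts. Keeping the original cut IDs is what allows `compute_loss` (changed later in this diff) to look up each cut's pre-computed alignment by its ID.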
@@ -21,7 +21,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Dict, Optional

 import k2
 import torch
@@ -36,6 +36,11 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

+from icefall.ali import (
+    convert_alignments_to_tensor,
+    load_alignments,
+    lookup_alignments,
+)
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
@@ -93,6 +98,17 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--ali-dir",
+        type=str,
+        default="data/ali_500",
+        help="""This folder is expected to contain
+        two files, train-960.pt and valid.pt, which
+        contain framewise alignment information for
+        the training set and validation set.
+        """,
+    )
+
     return parser


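For illustration, the two files that `--ali-dir` expects can be produced with the `save_alignments` helper added in `icefall/ali.py` later in this commit. A minimal sketch (the utterance IDs, alignment values, and subsampling factor below are made-up toy values, not real alignments):

    from icefall.ali import save_alignments

    # Keys are cut/utterance IDs; values are framewise token alignments
    # after subsampling (toy data). Assumes data/ali_500 already exists.
    train_ali = {"utt-0001": [0, 0, 25, 3], "utt-0002": [0, 7, 2]}
    save_alignments(
        alignments=train_ali,
        subsampling_factor=4,  # must match params.subsampling_factor
        filename="data/ali_500/train-960.pt",
    )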
@@ -284,6 +300,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: MmiTrainingGraphCompiler,
     is_training: bool,
+    ali: Optional[Dict[str, torch.Tensor]],
 ):
     """
     Compute LF-MMI loss given the model and its inputs.
@@ -304,6 +321,8 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
+      ali:
+        Precomputed alignments.
     """
     device = graph_compiler.device
     feature = batch["inputs"]
@@ -323,6 +342,30 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )

+    if ali is not None and params.batch_idx_train < 4000:
+        cut_ids = [cut.id for cut in supervisions["cut"]]
+
+        # As encode_supervisions reorders cuts, we also need
+        # to reorder the cut IDs here
+        new2old = supervision_segments[:, 0].tolist()
+        cut_ids = [cut_ids[i] for i in new2old]
+
+        # Check that new2old is just a permutation,
+        # i.e., each cut contains only one utterance
+        new2old.sort()
+        assert new2old == torch.arange(len(new2old)).tolist()
+        mask = lookup_alignments(
+            cut_ids=cut_ids,
+            alignments=ali,
+            num_classes=nnet_output.shape[2],
+        ).to(nnet_output)
+
+        min_len = min(nnet_output.shape[1], mask.shape[1])
+        ali_scale = 500.0 / (params.batch_idx_train + 500)
+
+        nnet_output = nnet_output.clone()
+        nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
+
     loss_fn = LFMMILoss(
         graph_compiler=graph_compiler,
         use_pruned_intersect=params.use_pruned_intersect,
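To see how strongly the alignments influence training over time: `ali_scale` decays smoothly from 1.0, and the mask is dropped entirely once `batch_idx_train` reaches 4000. A quick sketch evaluating the formula above:

    # Illustration only: the ali_scale schedule from compute_loss.
    for batch_idx_train in (0, 500, 2000, 3999):
        print(batch_idx_train, 500.0 / (batch_idx_train + 500))
    # 0 -> 1.0, 500 -> 0.5, 2000 -> 0.2, 3999 -> ~0.111
    # From batch 4000 on, the `batch_idx_train < 4000` guard skips the mask.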
@@ -377,6 +420,7 @@ def compute_validation_loss(
     graph_compiler: MmiTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
+    ali: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
@@ -394,6 +438,7 @@ def compute_validation_loss(
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=False,
+            ali=ali,
         )
         assert loss.requires_grad is False
         assert mmi_loss.requires_grad is False
@@ -435,6 +480,8 @@ def train_one_epoch(
     graph_compiler: MmiTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    train_ali: Optional[Dict[str, torch.Tensor]],
+    valid_ali: Optional[Dict[str, torch.Tensor]],
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
 ) -> None:
@@ -457,6 +504,10 @@ def train_one_epoch(
       Dataloader for the training dataset.
     valid_dl:
       Dataloader for the validation dataset.
+    train_ali:
+      Precomputed alignments for the training set.
+    valid_ali:
+      Precomputed alignments for the validation set.
     tb_writer:
       Writer to write log messages to tensorboard.
     world_size:
@@ -481,6 +532,7 @@ def train_one_epoch(
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=True,
+           ali=train_ali,
        )

        # NOTE: We use reduction==sum and loss is computed over utterances
@@ -565,6 +617,7 @@ def train_one_epoch(
                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
                world_size=world_size,
+                ali=valid_ali,
            )
            model.train()
            logging.info(
@@ -673,12 +726,34 @@ def run(rank, world_size, args):
     if checkpoints:
         optimizer.load_state_dict(checkpoints["optimizer"])

+    train_960_ali_filename = Path(params.ali_dir) / "train-960.pt"
+    if params.batch_idx_train < 4000 and train_960_ali_filename.is_file():
+        logging.info("Use pre-computed alignments")
+        subsampling_factor, train_ali = load_alignments(train_960_ali_filename)
+        assert subsampling_factor == params.subsampling_factor
+        assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723"
+
+        valid_ali_filename = Path(params.ali_dir) / "valid.pt"
+        subsampling_factor, valid_ali = load_alignments(valid_ali_filename)
+        assert subsampling_factor == params.subsampling_factor
+
+        train_ali = convert_alignments_to_tensor(train_ali, device=device)
+        valid_ali = convert_alignments_to_tensor(valid_ali, device=device)
+    else:
+        logging.info("Not using alignments")
+        train_ali = None
+        valid_ali = None
+
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()

     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
+        if params.batch_idx_train > 4000 and train_ali is not None:
+            # Delete the alignments to save memory
+            train_ali = None
+            valid_ali = None

         cur_lr = optimizer._rate
         if tb_writer is not None:
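Side note: the hard-coded 843723 appears to be 3 × 281241, consistent with the cut count of the full 960-hour LibriSpeech training set after the recipe's three-fold speed perturbation; the assert guards against loading alignments computed from a different cut set.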
@@ -699,6 +774,8 @@ def run(rank, world_size, args):
            graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
+           train_ali=train_ali,
+           valid_ali=valid_ali,
            tb_writer=tb_writer,
            world_size=world_size,
        )
icefall/ali.py (new file, 142 lines)

@@ -0,0 +1,142 @@
+# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Tuple
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def save_alignments(
+    alignments: Dict[str, List[int]],
+    subsampling_factor: int,
+    filename: str,
+) -> None:
+    """Save alignments to a file.
+
+    Args:
+      alignments:
+        A dict containing alignments. Keys of the dict are utterances and
+        values are the corresponding framewise alignments after subsampling.
+      subsampling_factor:
+        The subsampling factor of the model.
+      filename:
+        Path to save the alignments.
+    Returns:
+      Return None.
+    """
+    ali_dict = {
+        "subsampling_factor": subsampling_factor,
+        "alignments": alignments,
+    }
+    torch.save(ali_dict, filename)
+
+
+def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
+    """Load alignments from a file.
+
+    Args:
+      filename:
+        Path to the file containing alignment information.
+        The file should be saved by :func:`save_alignments`.
+    Returns:
+      Return a tuple containing:
+        - subsampling_factor: The subsampling_factor used to compute
+          the alignments.
+        - alignments: A dict containing utterances and their corresponding
+          framewise alignment, after subsampling.
+    """
+    ali_dict = torch.load(filename)
+    subsampling_factor = ali_dict["subsampling_factor"]
+    alignments = ali_dict["alignments"]
+    return subsampling_factor, alignments
+
+
+def convert_alignments_to_tensor(
+    alignments: Dict[str, List[int]], device: torch.device
+) -> Dict[str, torch.Tensor]:
+    """Convert alignments from list of int to a 1-D torch.Tensor.
+
+    Args:
+      alignments:
+        A dict containing alignments. Keys are utterance IDs and
+        values are their corresponding frame-wise alignments.
+      device:
+        The device to move the alignments to.
+    Returns:
+      Return a dict using 1-D torch.Tensor to store the alignments.
+      The dtype of the tensors is `torch.int64`; we choose `torch.int64`
+      because `torch.nn.functional.one_hot` requires that.
+    """
+    ans = {}
+    for utt_id, ali in alignments.items():
+        ali = torch.tensor(ali, dtype=torch.int64, device=device)
+        ans[utt_id] = ali
+    return ans
+
+
+def lookup_alignments(
+    cut_ids: List[str],
+    alignments: Dict[str, torch.Tensor],
+    num_classes: int,
+    log_score: float = -10,
+) -> torch.Tensor:
+    """Return a mask constructed from alignments by a list of cut IDs.
+
+    The returned mask is a 3-D tensor of shape (N, T, C). For each frame,
+    i.e., each row, of the returned mask, positions not corresponding to
+    the alignments are filled with `log_score`, while the position
+    specified by the alignment is filled with 0. For instance, if the
+    alignments of two utterances are:
+
+        [ [1, 3, 2], [1, 0, 4, 2] ]
+
+    num_classes is 5 and log_score is -10, then the returned mask is
+
+        [
+            [[-10, 0, -10, -10, -10],
+             [-10, -10, -10, 0, -10],
+             [-10, -10, 0, -10, -10],
+             [0, -10, -10, -10, -10]],
+            [[-10, 0, -10, -10, -10],
+             [0, -10, -10, -10, -10],
+             [-10, -10, -10, -10, 0],
+             [-10, -10, 0, -10, -10]],
+        ]
+
+    Note: We pad the alignment of the first utterance with 0.
+
+    Args:
+      cut_ids:
+        A list of utterance IDs.
+      alignments:
+        A dict containing alignments. The keys are utterance IDs and the
+        values are framewise alignments.
+      num_classes:
+        The max token ID + 1 that appears in the alignments.
+      log_score:
+        Positions in the returned tensor not corresponding to the
+        alignments are filled with this value.
+    Returns:
+      Return a 3-D torch.float32 tensor of shape (N, T, C).
+    """
+    # We assume all utterances have their alignments.
+    ali = [alignments[cut_id] for cut_id in cut_ids]
+    padded_ali = pad_sequence(ali, batch_first=True, padding_value=0)
+    padded_one_hot = torch.nn.functional.one_hot(
+        padded_ali,
+        num_classes=num_classes,
+    )
+    mask = (1 - padded_one_hot) * float(log_score)
+    return mask
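To make the API above concrete, here is a minimal round-trip sketch of the four helpers (the file name `toy_ali.pt` and the toy alignments are invented for illustration; it assumes `icefall` is importable):

    import torch
    from icefall.ali import (
        convert_alignments_to_tensor,
        load_alignments,
        lookup_alignments,
        save_alignments,
    )

    # Two toy utterances; values are framewise token IDs (num_classes = 5).
    alignments = {"utt1": [1, 3, 2], "utt2": [1, 0, 4, 2]}
    save_alignments(alignments, subsampling_factor=4, filename="toy_ali.pt")

    subsampling_factor, loaded = load_alignments("toy_ali.pt")
    assert subsampling_factor == 4 and loaded == alignments

    ali = convert_alignments_to_tensor(loaded, device=torch.device("cpu"))
    mask = lookup_alignments(
        cut_ids=["utt1", "utt2"], alignments=ali, num_classes=5
    )
    print(mask.shape)  # torch.Size([2, 4, 5]); "utt1" is padded to length 4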
test/test_ali.py (new executable file, 223 lines)

@@ -0,0 +1,223 @@
+#!/usr/bin/env python3
+# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run this file in one of the following two ways:
+#  (1) python3 ./test/test_ali.py
+#  (2) pytest ./test/test_ali.py
+
+# The purpose of this file is to show that if we build a mask
+# from alignments and add it to a randomly generated nnet_output,
+# we can decode the correct transcript.
+
+from pathlib import Path
+
+import k2
+import torch
+from lhotse import load_manifest
+from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+
+from icefall.ali import (
+    convert_alignments_to_tensor,
+    load_alignments,
+    lookup_alignments,
+)
+from icefall.decode import get_lattice, one_best_decoding
+from icefall.lexicon import Lexicon
+from icefall.utils import get_texts
+
+ICEFALL_DIR = Path(__file__).resolve().parent.parent
+egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
+lang_dir = egs_dir / "data/lang_bpe_500"
+# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
+cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
+# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
+ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
+
+# cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
+# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
+
+
+def data_exists():
+    return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
+
+
+def get_dataloader():
+    cuts_train = load_manifest(cut_json)
+    cuts_train = cuts_train.with_features_path_prefix(egs_dir)
+    train_sampler = SingleCutSampler(
+        cuts_train,
+        max_duration=200,
+        shuffle=False,
+    )
+
+    train = K2SpeechRecognitionDataset(return_cuts=True)
+
+    train_dl = DataLoader(
+        train,
+        sampler=train_sampler,
+        batch_size=None,
+        num_workers=1,
+        persistent_workers=False,
+    )
+    return train_dl
+
+
+def test_one_hot():
+    a = [1, 3, 2]
+    b = [1, 0, 4, 2]
+    c = [torch.tensor(a), torch.tensor(b)]
+    d = pad_sequence(c, batch_first=True, padding_value=0)
+    f = torch.nn.functional.one_hot(d, num_classes=5)
+    e = (1 - f) * -10.0
+    expected = torch.tensor(
+        [
+            [
+                [-10, 0, -10, -10, -10],
+                [-10, -10, -10, 0, -10],
+                [-10, -10, 0, -10, -10],
+                [0, -10, -10, -10, -10],
+            ],
+            [
+                [-10, 0, -10, -10, -10],
+                [0, -10, -10, -10, -10],
+                [-10, -10, -10, -10, 0],
+                [-10, -10, 0, -10, -10],
+            ],
+        ]
+    ).to(e.dtype)
+    assert torch.all(torch.eq(e, expected))
+
+
+def test():
+    """
+    The purpose of this test is to show that we can use pre-computed
+    alignments to construct a mask, add it to a randomly generated
+    nnet_output, and decode the correct transcript from the resulting
+    nnet_output.
+    """
+    if not data_exists():
+        return
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    dl = get_dataloader()
+
+    subsampling_factor, ali = load_alignments(ali_filename)
+    ali = convert_alignments_to_tensor(ali, device=device)
+
+    lexicon = Lexicon(lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+    word_table = lexicon.word_table
+
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{lang_dir}/HLG.pt", map_location=device)
+    )
+
+    for batch in dl:
+        features = batch["inputs"]
+        supervisions = batch["supervisions"]
+        N = features.shape[0]
+        T = features.shape[1] // subsampling_factor
+        nnet_output = (
+            torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
+            .softmax(dim=-1)
+            .log()
+        )
+        cut_ids = [cut.id for cut in supervisions["cut"]]
+        mask = lookup_alignments(
+            cut_ids=cut_ids, alignments=ali, num_classes=num_classes
+        )
+        min_len = min(nnet_output.shape[1], mask.shape[1])
+        ali_model_scale = 0.8
+
+        nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
+
+        supervisions = batch["supervisions"]
+
+        supervision_segments = torch.stack(
+            (
+                supervisions["sequence_idx"],
+                supervisions["start_frame"] // subsampling_factor,
+                supervisions["num_frames"] // subsampling_factor,
+            ),
+            1,
+        ).to(torch.int32)
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            HLG=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=20,
+            output_beam=8,
+            min_active_states=30,
+            max_active_states=10000,
+            subsampling_factor=subsampling_factor,
+        )
+
+        best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        hyps = [" ".join(s) for s in hyps]
+        print(hyps)
+        print(supervisions["text"])
+        break
+
+
+def show_cut_ids():
+    # The purpose of this function is to check that
+    # for each utterance in the training set, there is
+    # a corresponding alignment.
+    #
+    # After generating a1.txt and b1.txt, you can use
+    #   wc -l a1.txt b1.txt
+    # which should show the same number of lines.
+    #
+    #   cat a1.txt | sort | uniq > a11.txt
+    #   cat b1.txt | sort | uniq > b11.txt
+    #
+    #   md5sum a11.txt b11.txt
+    # should show identical hashes, and
+    #
+    #   diff a11.txt b11.txt
+    # should print nothing.
+
+    subsampling_factor, ali = load_alignments(ali_filename)
+    with open("a1.txt", "w") as f:
+        for key in ali:
+            f.write(f"{key}\n")
+
+    # dl = get_dataloader()
+    cuts_train = (
+        load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
+        + load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
+        + load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
+    )
+
+    ans = []
+    for cut in cuts_train:
+        ans.append(cut.id)
+    with open("b1.txt", "w") as f:
+        for line in ans:
+            f.write(f"{line}\n")
+
+
+if __name__ == "__main__":
+    test()