mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Associate a cut with token alignment (without repeats) (#125)
* WIP: Associate a cut with token alignment (without repeats) * Save framewise alignments with/without repeats. * Minor fixes.
This commit is contained in:
parent
243fb9723c
commit
ec591698b0
@ -15,15 +15,29 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
./conformer_ctc/ali.py \
|
||||||
|
--exp-dir ./conformer_ctc/exp \
|
||||||
|
--lang-dir ./data/lang_bpe_500 \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--max-duration 300 \
|
||||||
|
--dataset train-clean-100 \
|
||||||
|
--out-dir data/ali
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
from lhotse import CutSet
|
||||||
|
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
@ -34,7 +48,6 @@ from icefall.utils import (
|
|||||||
AttributeDict,
|
AttributeDict,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
get_alignments,
|
get_alignments,
|
||||||
save_alignments,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,10 +88,42 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ali-dir",
|
"--out-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/ali_500",
|
required=True,
|
||||||
help="The experiment dir",
|
help="""Output directory.
|
||||||
|
It contains 3 generated files:
|
||||||
|
|
||||||
|
- labels_xxx.h5
|
||||||
|
- aux_labels_xxx.h5
|
||||||
|
- cuts_xxx.json.gz
|
||||||
|
|
||||||
|
where xxx is the value of `--dataset`. For instance, if
|
||||||
|
`--dataset` is `train-clean-100`, it will contain 3 files:
|
||||||
|
|
||||||
|
- `labels_train-clean-100.h5`
|
||||||
|
- `aux_labels_train-clean-100.h5`
|
||||||
|
- `cuts_train-clean-100.json.gz`
|
||||||
|
|
||||||
|
Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
|
||||||
|
alignment. The difference is that labels_xxx.h5 contains repeats.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="""The name of the dataset to compute alignments for.
|
||||||
|
Possible values are:
|
||||||
|
- test-clean.
|
||||||
|
- test-other
|
||||||
|
- train-clean-100
|
||||||
|
- train-clean-360
|
||||||
|
- train-other-500
|
||||||
|
- dev-clean
|
||||||
|
- dev-other
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -91,7 +136,9 @@ def get_params() -> AttributeDict:
|
|||||||
"nhead": 8,
|
"nhead": 8,
|
||||||
"attention_dim": 512,
|
"attention_dim": 512,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"num_decoder_layers": 6,
|
# Set it to 0 since attention decoder
|
||||||
|
# is not used for computing alignments
|
||||||
|
"num_decoder_layers": 0,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"use_feat_batchnorm": True,
|
"use_feat_batchnorm": True,
|
||||||
"output_beam": 10,
|
"output_beam": 10,
|
||||||
@ -105,9 +152,11 @@ def get_params() -> AttributeDict:
|
|||||||
def compute_alignments(
|
def compute_alignments(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
|
labels_writer: FeaturesWriter,
|
||||||
|
aux_labels_writer: FeaturesWriter,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||||
) -> List[Tuple[str, List[int]]]:
|
) -> CutSet:
|
||||||
"""Compute the framewise alignments of a dataset.
|
"""Compute the framewise alignments of a dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -120,9 +169,10 @@ def compute_alignments(
|
|||||||
graph_compiler:
|
graph_compiler:
|
||||||
It converts token IDs to decoding graphs.
|
It converts token IDs to decoding graphs.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list of tuples. Each tuple contains two entries:
|
Return a CutSet. Each cut has two custom fields: labels_alignment
|
||||||
- Utterance ID
|
and aux_labels_alignment, containing framewise alignments information.
|
||||||
- Framewise alignments (token IDs) after subsampling
|
Both are of type `lhotse.array.TemporalArray`. The difference between
|
||||||
|
the two alignments is that `labels_alignment` contain repeats.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
num_batches = len(dl)
|
num_batches = len(dl)
|
||||||
@ -131,7 +181,7 @@ def compute_alignments(
|
|||||||
num_cuts = 0
|
num_cuts = 0
|
||||||
|
|
||||||
device = graph_compiler.device
|
device = graph_compiler.device
|
||||||
ans = []
|
cuts = []
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
|
|
||||||
@ -140,11 +190,10 @@ def compute_alignments(
|
|||||||
feature = feature.to(device)
|
feature = feature.to(device)
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
|
cut_list = supervisions["cut"]
|
||||||
|
|
||||||
cut_ids = []
|
for cut in cut_list:
|
||||||
for cut in supervisions["cut"]:
|
assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
|
||||||
assert len(cut.supervisions) == 1
|
|
||||||
cut_ids.append(cut.id)
|
|
||||||
|
|
||||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
||||||
# nnet_output is [N, T, C]
|
# nnet_output is [N, T, C]
|
||||||
@ -156,7 +205,8 @@ def compute_alignments(
|
|||||||
# In general, new2old is an identity map since lhotse sorts the returned
|
# In general, new2old is an identity map since lhotse sorts the returned
|
||||||
# cuts by duration in descending order
|
# cuts by duration in descending order
|
||||||
new2old = supervision_segments[:, 0].tolist()
|
new2old = supervision_segments[:, 0].tolist()
|
||||||
cut_ids = [cut_ids[i] for i in new2old]
|
|
||||||
|
cut_list = [cut_list[i] for i in new2old]
|
||||||
|
|
||||||
token_ids = graph_compiler.texts_to_ids(texts)
|
token_ids = graph_compiler.texts_to_ids(texts)
|
||||||
decoding_graph = graph_compiler.compile(token_ids)
|
decoding_graph = graph_compiler.compile(token_ids)
|
||||||
@ -178,11 +228,32 @@ def compute_alignments(
|
|||||||
use_double_scores=params.use_double_scores,
|
use_double_scores=params.use_double_scores,
|
||||||
)
|
)
|
||||||
|
|
||||||
ali_ids = get_alignments(best_path)
|
labels_ali = get_alignments(best_path, kind="labels")
|
||||||
assert len(ali_ids) == len(cut_ids)
|
aux_labels_ali = get_alignments(best_path, kind="aux_labels")
|
||||||
ans += list(zip(cut_ids, ali_ids))
|
assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
|
||||||
|
for cut, labels, aux_labels in zip(
|
||||||
|
cut_list, labels_ali, aux_labels_ali
|
||||||
|
):
|
||||||
|
cut.labels_alignment = labels_writer.store_array(
|
||||||
|
key=cut.id,
|
||||||
|
value=np.asarray(labels, dtype=np.int32),
|
||||||
|
# frame shift is 0.01s, subsampling_factor is 4
|
||||||
|
frame_shift=0.04,
|
||||||
|
temporal_dim=0,
|
||||||
|
start=0,
|
||||||
|
)
|
||||||
|
cut.aux_labels_alignment = aux_labels_writer.store_array(
|
||||||
|
key=cut.id,
|
||||||
|
value=np.asarray(aux_labels, dtype=np.int32),
|
||||||
|
# frame shift is 0.01s, subsampling_factor is 4
|
||||||
|
frame_shift=0.04,
|
||||||
|
temporal_dim=0,
|
||||||
|
start=0,
|
||||||
|
)
|
||||||
|
|
||||||
num_cuts += len(ali_ids)
|
cuts += cut_list
|
||||||
|
|
||||||
|
num_cuts += len(cut_list)
|
||||||
|
|
||||||
if batch_idx % 100 == 0:
|
if batch_idx % 100 == 0:
|
||||||
batch_str = f"{batch_idx}/{num_batches}"
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
@ -191,7 +262,7 @@ def compute_alignments(
|
|||||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return ans
|
return CutSet.from_cuts(cuts)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -200,20 +271,35 @@ def main():
|
|||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert args.return_cuts is True
|
args.enable_spec_aug = False
|
||||||
assert args.concatenate_cuts is False
|
args.enable_musan = False
|
||||||
if args.full_libri is False:
|
args.return_cuts = True
|
||||||
print("Changing --full-libri to True")
|
args.concatenate_cuts = False
|
||||||
args.full_libri = True
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
setup_logger(f"{params.exp_dir}/log/ali")
|
setup_logger(f"{params.exp_dir}/log-ali")
|
||||||
|
|
||||||
logging.info("Computing alignment - started")
|
logging.info(f"Computing alignments for {params.dataset} - started")
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
out_dir = Path(params.out_dir)
|
||||||
|
out_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
|
||||||
|
out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
|
||||||
|
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
|
||||||
|
|
||||||
|
for f in (
|
||||||
|
out_labels_ali_filename,
|
||||||
|
out_aux_labels_ali_filename,
|
||||||
|
out_manifest_filename,
|
||||||
|
):
|
||||||
|
if f.exists():
|
||||||
|
logging.info(f"{f} exists - skipping")
|
||||||
|
return
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
max_token_id = max(lexicon.tokens)
|
max_token_id = max(lexicon.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
@ -221,6 +307,7 @@ def main():
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||||
params.lang_dir,
|
params.lang_dir,
|
||||||
@ -240,9 +327,12 @@ def main():
|
|||||||
vgg_frontend=params.vgg_frontend,
|
vgg_frontend=params.vgg_frontend,
|
||||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||||
)
|
)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
if params.avg == 1:
|
if params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(
|
||||||
|
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
start = params.epoch - params.avg + 1
|
||||||
filenames = []
|
filenames = []
|
||||||
@ -250,60 +340,55 @@ def main():
|
|||||||
if start >= 0:
|
if start >= 0:
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.load_state_dict(average_checkpoints(filenames))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
if params.dataset == "test-clean":
|
||||||
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
|
elif params.dataset == "test-other":
|
||||||
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
|
dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
elif params.dataset == "train-clean-100":
|
||||||
|
train_clean_100_cuts = librispeech.train_clean_100_cuts()
|
||||||
|
dl = librispeech.train_dataloaders(train_clean_100_cuts)
|
||||||
|
elif params.dataset == "train-clean-360":
|
||||||
|
train_clean_360_cuts = librispeech.train_clean_360_cuts()
|
||||||
|
dl = librispeech.train_dataloaders(train_clean_360_cuts)
|
||||||
|
elif params.dataset == "train-other-500":
|
||||||
|
train_other_500_cuts = librispeech.train_other_500_cuts()
|
||||||
|
dl = librispeech.train_dataloaders(train_other_500_cuts)
|
||||||
|
elif params.dataset == "dev-clean":
|
||||||
|
dev_clean_cuts = librispeech.dev_clean_cuts()
|
||||||
|
dl = librispeech.valid_dataloaders(dev_clean_cuts)
|
||||||
|
else:
|
||||||
|
assert params.dataset == "dev-other", f"{params.dataset}"
|
||||||
|
dev_other_cuts = librispeech.dev_other_cuts()
|
||||||
|
dl = librispeech.valid_dataloaders(dev_other_cuts)
|
||||||
|
|
||||||
train_dl = librispeech.train_dataloaders()
|
logging.info(f"Processing {params.dataset}")
|
||||||
valid_dl = librispeech.valid_dataloaders()
|
with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
|
||||||
test_dl = librispeech.test_dataloaders() # a list
|
with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
|
||||||
|
cut_set = compute_alignments(
|
||||||
ali_dir = Path(params.ali_dir)
|
|
||||||
ali_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
enabled_datasets = {
|
|
||||||
"test_clean": test_dl[0],
|
|
||||||
"test_other": test_dl[1],
|
|
||||||
"train-960": train_dl,
|
|
||||||
"valid": valid_dl,
|
|
||||||
}
|
|
||||||
# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
|
|
||||||
# compute the alignments if you use --max-duration=500
|
|
||||||
#
|
|
||||||
# There are 960 * 3 = 2880 hours data and it takes only
|
|
||||||
# 3 hours 40 minutes to get the alignment.
|
|
||||||
# The RTF is roughly: 3.67 / 2880 = 0.0012743
|
|
||||||
#
|
|
||||||
# At the end, you would see
|
|
||||||
# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
|
|
||||||
# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
|
|
||||||
for name, dl in enabled_datasets.items():
|
|
||||||
logging.info(f"Processing {name}")
|
|
||||||
if name == "train-960":
|
|
||||||
logging.info(
|
|
||||||
f"It will take about 3 hours 40 minutes for {name}, "
|
|
||||||
"which contains 960 * 3 = 2880 hours of data"
|
|
||||||
)
|
|
||||||
alignments = compute_alignments(
|
|
||||||
model=model,
|
model=model,
|
||||||
dl=dl,
|
dl=dl,
|
||||||
|
labels_writer=labels_writer,
|
||||||
|
aux_labels_writer=aux_labels_writer,
|
||||||
params=params,
|
params=params,
|
||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
)
|
)
|
||||||
num_utt = len(alignments)
|
|
||||||
alignments = dict(alignments)
|
cut_set.to_file(out_manifest_filename)
|
||||||
assert num_utt == len(alignments)
|
|
||||||
filename = ali_dir / f"{name}.pt"
|
|
||||||
save_alignments(
|
|
||||||
alignments=alignments,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
filename=filename,
|
|
||||||
)
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"For dataset {name}, its alignments are saved to {filename}"
|
f"For dataset {params.dataset}, its alignments with repeats are "
|
||||||
|
f"saved to {out_labels_ali_filename}, the alignments without repeats "
|
||||||
|
f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
|
||||||
|
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -665,14 +665,17 @@ def main():
|
|||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
# CAUTION: `test_sets` is for displaying only.
|
|
||||||
# If you want to skip test-clean, you have to skip
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
# it inside the for loop. That is, use
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
#
|
|
||||||
# if test_set == 'test-clean': continue
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
#
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
|
||||||
test_sets = ["test-clean", "test-other"]
|
test_sets = ["test-clean", "test-other"]
|
||||||
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -618,8 +618,16 @@ def run(rank, world_size, args):
|
|||||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
train_dl = librispeech.train_dataloaders()
|
|
||||||
valid_dl = librispeech.valid_dataloaders()
|
train_cuts = librispeech.train_clean_100_cuts()
|
||||||
|
if params.full_libri:
|
||||||
|
train_cuts += librispeech.train_clean_360_cuts()
|
||||||
|
train_cuts += librispeech.train_other_500_cuts()
|
||||||
|
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||||
|
|
||||||
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -19,7 +19,6 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||||
from lhotse.dataset import (
|
from lhotse.dataset import (
|
||||||
@ -34,11 +33,10 @@ from lhotse.dataset import (
|
|||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from icefall.dataset.datamodule import DataModule
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
class LibriSpeechAsrDataModule(DataModule):
|
class LibriSpeechAsrDataModule:
|
||||||
"""
|
"""
|
||||||
DataModule for k2 ASR experiments.
|
DataModule for k2 ASR experiments.
|
||||||
It assumes there is always one train and valid dataloader,
|
It assumes there is always one train and valid dataloader,
|
||||||
@ -56,9 +54,11 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
This class should be derived for specific corpora used in ASR tasks.
|
This class should be derived for specific corpora used in ASR tasks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
super().add_arguments(parser)
|
|
||||||
group = parser.add_argument_group(
|
group = parser.add_argument_group(
|
||||||
title="ASR data related options",
|
title="ASR data related options",
|
||||||
description="These options are used for the preparation of "
|
description="These options are used for the preparation of "
|
||||||
@ -74,7 +74,7 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
"Otherwise, use 100h subset.",
|
"Otherwise, use 100h subset.",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--feature-dir",
|
"--manifest-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path("data/fbank"),
|
default=Path("data/fbank"),
|
||||||
help="Path to directory with train/valid/test cuts.",
|
help="Path to directory with train/valid/test cuts.",
|
||||||
@ -154,17 +154,48 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
"collect the batches.",
|
"collect the batches.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def train_dataloaders(self) -> DataLoader:
|
group.add_argument(
|
||||||
logging.info("About to get train cuts")
|
"--enable-spec-aug",
|
||||||
cuts_train = self.train_cuts()
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, use SpecAugment for training dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--spec-aug-time-warp-factor",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="Used only when --enable-spec-aug is True. "
|
||||||
|
"It specifies the factor for time warping in SpecAugment. "
|
||||||
|
"Larger values mean more warping. "
|
||||||
|
"A value less than 1 means to disable time warp.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--enable-musan",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, select noise from MUSAN and mix it"
|
||||||
|
"with training dataset. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
|
||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
|
cuts_musan = load_manifest(
|
||||||
|
self.args.manifest_dir / "cuts_musan.json.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
transforms = []
|
||||||
|
if self.args.enable_musan:
|
||||||
|
logging.info("Enable MUSAN")
|
||||||
|
transforms.append(
|
||||||
|
CutMix(
|
||||||
|
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Disable MUSAN")
|
||||||
|
|
||||||
logging.info("About to create train dataset")
|
|
||||||
transforms = [
|
|
||||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
|
||||||
]
|
|
||||||
if self.args.concatenate_cuts:
|
if self.args.concatenate_cuts:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Using cut concatenation with duration factor "
|
f"Using cut concatenation with duration factor "
|
||||||
@ -179,15 +210,25 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
)
|
)
|
||||||
] + transforms
|
] + transforms
|
||||||
|
|
||||||
input_transforms = [
|
input_transforms = []
|
||||||
|
if self.args.enable_spec_aug:
|
||||||
|
logging.info("Enable SpecAugment")
|
||||||
|
logging.info(
|
||||||
|
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||||
|
)
|
||||||
|
input_transforms.append(
|
||||||
SpecAugment(
|
SpecAugment(
|
||||||
|
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||||
num_frame_masks=2,
|
num_frame_masks=2,
|
||||||
features_mask_size=27,
|
features_mask_size=27,
|
||||||
num_feature_masks=2,
|
num_feature_masks=2,
|
||||||
frames_mask_size=100,
|
frames_mask_size=100,
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Disable SpecAugment")
|
||||||
|
|
||||||
|
logging.info("About to create train dataset")
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
@ -243,10 +284,7 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
def valid_dataloaders(self) -> DataLoader:
|
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
logging.info("About to get dev cuts")
|
|
||||||
cuts_valid = self.valid_cuts()
|
|
||||||
|
|
||||||
transforms = []
|
transforms = []
|
||||||
if self.args.concatenate_cuts:
|
if self.args.concatenate_cuts:
|
||||||
transforms = [
|
transforms = [
|
||||||
@ -285,25 +323,16 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
|
|
||||||
return valid_dl
|
return valid_dl
|
||||||
|
|
||||||
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
cuts = self.test_cuts()
|
|
||||||
is_list = isinstance(cuts, list)
|
|
||||||
test_loaders = []
|
|
||||||
if not is_list:
|
|
||||||
cuts = [cuts]
|
|
||||||
|
|
||||||
for cuts_test in cuts:
|
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
Fbank(FbankConfig(num_mel_bins=80))
|
|
||||||
)
|
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
else PrecomputedFeatures(),
|
else PrecomputedFeatures(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = BucketingSampler(
|
sampler = BucketingSampler(
|
||||||
cuts_test, max_duration=self.args.max_duration, shuffle=False
|
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||||
)
|
)
|
||||||
logging.debug("About to create test dataloader")
|
logging.debug("About to create test dataloader")
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
@ -312,48 +341,45 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
test_loaders.append(test_dl)
|
return test_dl
|
||||||
|
|
||||||
if is_list:
|
|
||||||
return test_loaders
|
|
||||||
else:
|
|
||||||
return test_loaders[0]
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_clean_100_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info("About to get train-clean-100 cuts")
|
||||||
cuts_train = load_manifest(
|
return load_manifest(
|
||||||
self.args.feature_dir / "cuts_train-clean-100.json.gz"
|
self.args.manifest_dir / "cuts_train-clean-100.json.gz"
|
||||||
)
|
)
|
||||||
if self.args.full_libri:
|
|
||||||
cuts_train = (
|
|
||||||
cuts_train
|
|
||||||
+ load_manifest(
|
|
||||||
self.args.feature_dir / "cuts_train-clean-360.json.gz"
|
|
||||||
)
|
|
||||||
+ load_manifest(
|
|
||||||
self.args.feature_dir / "cuts_train-other-500.json.gz"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return cuts_train
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def train_clean_360_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get train-clean-360 cuts")
|
||||||
cuts_valid = load_manifest(
|
return load_manifest(
|
||||||
self.args.feature_dir / "cuts_dev-clean.json.gz"
|
self.args.manifest_dir / "cuts_train-clean-360.json.gz"
|
||||||
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
|
)
|
||||||
return cuts_valid
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_cuts(self) -> List[CutSet]:
|
def train_other_500_cuts(self) -> CutSet:
|
||||||
test_sets = ["test-clean", "test-other"]
|
logging.info("About to get train-other-500 cuts")
|
||||||
cuts = []
|
return load_manifest(
|
||||||
for test_set in test_sets:
|
self.args.manifest_dir / "cuts_train-other-500.json.gz"
|
||||||
logging.debug("About to get test cuts")
|
|
||||||
cuts.append(
|
|
||||||
load_manifest(
|
|
||||||
self.args.feature_dir / f"cuts_{test_set}.json.gz"
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
return cuts
|
@lru_cache()
|
||||||
|
def dev_clean_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dev-clean cuts")
|
||||||
|
return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_other_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dev-other cuts")
|
||||||
|
return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_clean_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test-clean cuts")
|
||||||
|
return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_other_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test-other cuts")
|
||||||
|
return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")
|
||||||
|
@ -474,14 +474,17 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
# CAUTION: `test_sets` is for displaying only.
|
|
||||||
# If you want to skip test-clean, you have to skip
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
# it inside the for loop. That is, use
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
#
|
|
||||||
# if test_set == 'test-clean': continue
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
#
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
|
||||||
test_sets = ["test-clean", "test-other"]
|
test_sets = ["test-clean", "test-other"]
|
||||||
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -532,8 +532,16 @@ def run(rank, world_size, args):
|
|||||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
train_dl = librispeech.train_dataloaders()
|
|
||||||
valid_dl = librispeech.valid_dataloaders()
|
train_cuts = librispeech.train_clean_100_cuts()
|
||||||
|
if params.full_libri:
|
||||||
|
train_cuts += librispeech.train_clean_360_cuts()
|
||||||
|
train_cuts += librispeech.train_other_500_cuts()
|
||||||
|
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||||
|
|
||||||
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
@ -85,6 +85,7 @@ def load_checkpoint(
|
|||||||
optimizer: Optional[Optimizer] = None,
|
optimizer: Optional[Optimizer] = None,
|
||||||
scheduler: Optional[_LRScheduler] = None,
|
scheduler: Optional[_LRScheduler] = None,
|
||||||
scaler: Optional[GradScaler] = None,
|
scaler: Optional[GradScaler] = None,
|
||||||
|
strict: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
TODO: document it
|
TODO: document it
|
||||||
@ -101,9 +102,9 @@ def load_checkpoint(
|
|||||||
src_key = "{}.{}".format("module", key)
|
src_key = "{}.{}".format("module", key)
|
||||||
dst_state_dict[key] = src_state_dict.pop(src_key)
|
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||||
assert len(src_state_dict) == 0
|
assert len(src_state_dict) == 0
|
||||||
model.load_state_dict(dst_state_dict, strict=False)
|
model.load_state_dict(dst_state_dict, strict=strict)
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint["model"], strict=False)
|
model.load_state_dict(checkpoint["model"], strict=strict)
|
||||||
|
|
||||||
checkpoint.pop("model")
|
checkpoint.pop("model")
|
||||||
|
|
||||||
|
@ -224,8 +224,8 @@ def get_texts(
|
|||||||
return aux_labels.tolist()
|
return aux_labels.tolist()
|
||||||
|
|
||||||
|
|
||||||
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
|
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
|
||||||
"""Extract the token IDs (from best_paths.labels) from the best-path FSAs.
|
"""Extract labels or aux_labels from the best-path FSAs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
best_paths:
|
best_paths:
|
||||||
@ -233,17 +233,34 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
|
|||||||
containing multiple FSAs, which is expected to be the result
|
containing multiple FSAs, which is expected to be the result
|
||||||
of k2.shortest_path (otherwise the returned values won't
|
of k2.shortest_path (otherwise the returned values won't
|
||||||
be meaningful).
|
be meaningful).
|
||||||
|
kind:
|
||||||
|
Possible values are: "labels" and "aux_labels". Caution: When it is
|
||||||
|
"labels", the resulting alignments contain repeats.
|
||||||
Returns:
|
Returns:
|
||||||
Returns a list of lists of int, containing the token sequences we
|
Returns a list of lists of int, containing the token sequences we
|
||||||
decoded. For `ans[i]`, its length equals to the number of frames
|
decoded. For `ans[i]`, its length equals to the number of frames
|
||||||
after subsampling of the i-th utterance in the batch.
|
after subsampling of the i-th utterance in the batch.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
When `kind` is `labels`, one possible alignment example is (with
|
||||||
|
repeats)::
|
||||||
|
|
||||||
|
c c c blk a a blk blk t t t blk blk
|
||||||
|
|
||||||
|
If `kind` is `aux_labels`, the above example changes to::
|
||||||
|
|
||||||
|
c blk blk blk a blk blk blk t blk blk blk blk
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
assert kind in ("labels", "aux_labels")
|
||||||
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
|
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
|
||||||
label_shape = best_paths.arcs.shape().remove_axis(1)
|
token_shape = best_paths.arcs.shape().remove_axis(1)
|
||||||
# label_shape has axes [fsa][arc]
|
# token_shape has axes [fsa][arc]
|
||||||
labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
|
tokens = k2.RaggedTensor(
|
||||||
labels = labels.remove_values_eq(-1)
|
token_shape, getattr(best_paths, kind).contiguous()
|
||||||
return labels.tolist()
|
)
|
||||||
|
tokens = tokens.remove_values_eq(-1)
|
||||||
|
return tokens.tolist()
|
||||||
|
|
||||||
|
|
||||||
def save_alignments(
|
def save_alignments(
|
||||||
|
184
test/test_ali.py
184
test/test_ali.py
@ -25,199 +25,65 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import k2
|
from lhotse import CutSet, load_manifest
|
||||||
import torch
|
|
||||||
from lhotse import load_manifest
|
|
||||||
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
|
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from lhotse.dataset.collation import collate_custom_field
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from icefall.ali import (
|
|
||||||
convert_alignments_to_tensor,
|
|
||||||
load_alignments,
|
|
||||||
lookup_alignments,
|
|
||||||
)
|
|
||||||
from icefall.decode import get_lattice, one_best_decoding
|
|
||||||
from icefall.lexicon import Lexicon
|
|
||||||
from icefall.utils import get_texts
|
|
||||||
|
|
||||||
ICEFALL_DIR = Path(__file__).resolve().parent.parent
|
ICEFALL_DIR = Path(__file__).resolve().parent.parent
|
||||||
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
|
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
|
||||||
lang_dir = egs_dir / "data/lang_bpe_500"
|
lang_dir = egs_dir / "data/lang_bpe_500"
|
||||||
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
|
cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"
|
||||||
# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
|
|
||||||
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
|
|
||||||
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
|
|
||||||
|
|
||||||
cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
|
|
||||||
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
|
|
||||||
|
|
||||||
|
|
||||||
def data_exists():
|
def data_exists():
|
||||||
return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
|
return cuts_json.exists() and lang_dir.exists()
|
||||||
|
|
||||||
|
|
||||||
def get_dataloader():
|
def get_dataloader():
|
||||||
cuts_train = load_manifest(cut_json)
|
cuts = load_manifest(cuts_json)
|
||||||
cuts_train = cuts_train.with_features_path_prefix(egs_dir)
|
print(cuts[0])
|
||||||
train_sampler = SingleCutSampler(
|
cuts = cuts.with_features_path_prefix(egs_dir)
|
||||||
cuts_train,
|
sampler = SingleCutSampler(
|
||||||
max_duration=40,
|
cuts,
|
||||||
|
max_duration=10,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
train = K2SpeechRecognitionDataset(return_cuts=True)
|
dataset = K2SpeechRecognitionDataset(return_cuts=True)
|
||||||
|
|
||||||
train_dl = DataLoader(
|
dl = DataLoader(
|
||||||
train,
|
dataset,
|
||||||
sampler=train_sampler,
|
sampler=sampler,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
persistent_workers=False,
|
persistent_workers=False,
|
||||||
)
|
)
|
||||||
return train_dl
|
return dl
|
||||||
|
|
||||||
|
|
||||||
def test_one_hot():
|
|
||||||
a = [1, 3, 2]
|
|
||||||
b = [1, 0, 4, 2]
|
|
||||||
c = [torch.tensor(a), torch.tensor(b)]
|
|
||||||
d = pad_sequence(c, batch_first=True, padding_value=0)
|
|
||||||
f = torch.nn.functional.one_hot(d, num_classes=5)
|
|
||||||
e = (1 - f) * -10.0
|
|
||||||
expected = torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[-10, 0, -10, -10, -10],
|
|
||||||
[-10, -10, -10, 0, -10],
|
|
||||||
[-10, -10, 0, -10, -10],
|
|
||||||
[0, -10, -10, -10, -10],
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-10, 0, -10, -10, -10],
|
|
||||||
[0, -10, -10, -10, -10],
|
|
||||||
[-10, -10, -10, -10, 0],
|
|
||||||
[-10, -10, 0, -10, -10],
|
|
||||||
],
|
|
||||||
]
|
|
||||||
).to(e.dtype)
|
|
||||||
assert torch.all(torch.eq(e, expected))
|
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
"""
|
|
||||||
The purpose of this test is to show that we can use pre-computed
|
|
||||||
alignments to construct a mask, adding it to a randomly generated
|
|
||||||
nnet_output, to decode the correct transcript from the resulting
|
|
||||||
nnet_output.
|
|
||||||
"""
|
|
||||||
if not data_exists():
|
if not data_exists():
|
||||||
return
|
return
|
||||||
device = torch.device("cpu")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda", 0)
|
|
||||||
dl = get_dataloader()
|
dl = get_dataloader()
|
||||||
|
|
||||||
subsampling_factor, ali = load_alignments(ali_filename)
|
|
||||||
ali = convert_alignments_to_tensor(ali, device=device)
|
|
||||||
|
|
||||||
lexicon = Lexicon(lang_dir)
|
|
||||||
max_token_id = max(lexicon.tokens)
|
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
word_table = lexicon.word_table
|
|
||||||
|
|
||||||
HLG = k2.Fsa.from_dict(
|
|
||||||
torch.load(f"{lang_dir}/HLG.pt", map_location=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
for batch in dl:
|
for batch in dl:
|
||||||
features = batch["inputs"]
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
N = features.shape[0]
|
cuts = supervisions["cut"]
|
||||||
T = features.shape[1] // subsampling_factor
|
labels_alignment, labels_alignment_length = collate_custom_field(
|
||||||
nnet_output = (
|
CutSet.from_cuts(cuts), "labels_alignment"
|
||||||
torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
|
|
||||||
.softmax(dim=-1)
|
|
||||||
.log()
|
|
||||||
)
|
)
|
||||||
cut_ids = [cut.id for cut in supervisions["cut"]]
|
|
||||||
mask = lookup_alignments(
|
|
||||||
cut_ids=cut_ids, alignments=ali, num_classes=num_classes
|
|
||||||
)
|
|
||||||
min_len = min(nnet_output.shape[1], mask.shape[1])
|
|
||||||
ali_model_scale = 0.8
|
|
||||||
|
|
||||||
nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
|
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
|
||||||
|
|
||||||
supervision_segments = torch.stack(
|
|
||||||
(
|
(
|
||||||
supervisions["sequence_idx"],
|
aux_labels_alignment,
|
||||||
supervisions["start_frame"] // subsampling_factor,
|
aux_labels_alignment_length,
|
||||||
supervisions["num_frames"] // subsampling_factor,
|
) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
|
||||||
),
|
|
||||||
1,
|
|
||||||
).to(torch.int32)
|
|
||||||
|
|
||||||
lattice = get_lattice(
|
print(labels_alignment)
|
||||||
nnet_output=nnet_output,
|
print(aux_labels_alignment)
|
||||||
decoding_graph=HLG,
|
print(labels_alignment_length)
|
||||||
supervision_segments=supervision_segments,
|
print(aux_labels_alignment_length)
|
||||||
search_beam=20,
|
|
||||||
output_beam=8,
|
|
||||||
min_active_states=30,
|
|
||||||
max_active_states=10000,
|
|
||||||
subsampling_factor=subsampling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
|
||||||
hyps = [" ".join(s) for s in hyps]
|
|
||||||
print(hyps)
|
|
||||||
print(supervisions["text"])
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def show_cut_ids():
|
|
||||||
# The purpose of this function is to check that
|
|
||||||
# for each utterance in the training set, there is
|
|
||||||
# a corresponding alignment.
|
|
||||||
#
|
|
||||||
# After generating a1.txt and b1.txt
|
|
||||||
# You can use
|
|
||||||
# wc -l a1.txt b1.txt
|
|
||||||
# which should show the same number of lines.
|
|
||||||
#
|
|
||||||
# cat a1.txt | sort | uniq > a11.txt
|
|
||||||
# cat b1.txt | sort | uniq > b11.txt
|
|
||||||
#
|
|
||||||
# md5sum a11.txt b11.txt
|
|
||||||
# which should show the identical hash
|
|
||||||
#
|
|
||||||
# diff a11.txt b11.txt
|
|
||||||
# should print nothing
|
|
||||||
|
|
||||||
subsampling_factor, ali = load_alignments(ali_filename)
|
|
||||||
with open("a1.txt", "w") as f:
|
|
||||||
for key in ali:
|
|
||||||
f.write(f"{key}\n")
|
|
||||||
|
|
||||||
# dl = get_dataloader()
|
|
||||||
cuts_train = (
|
|
||||||
load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
|
|
||||||
+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
|
|
||||||
+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
|
|
||||||
)
|
|
||||||
|
|
||||||
ans = []
|
|
||||||
for cut in cuts_train:
|
|
||||||
ans.append(cut.id)
|
|
||||||
with open("b1.txt", "w") as f:
|
|
||||||
for line in ans:
|
|
||||||
f.write(f"{line}\n")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test()
|
test()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user