mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Compute framewise alignment information of the LibriSpeech dataset.
This commit is contained in:
parent
4580ff10df
commit
27a6d5e9cb
@ -18,6 +18,7 @@
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
@ -32,6 +33,7 @@ from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
get_alignments,
|
||||
save_alignments,
|
||||
setup_logger,
|
||||
)
|
||||
|
||||
@ -56,14 +58,33 @@ def get_parser():
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="conformer_ctc/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ali-dir",
|
||||
type=str,
|
||||
default="data/ali",
|
||||
help="The experiment dir",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("conformer_ctc/exp"),
|
||||
"lang_dir": Path("data/lang_bpe"),
|
||||
"lm_dir": Path("data/lm"),
|
||||
"feature_dim": 80,
|
||||
"nhead": 8,
|
||||
@ -71,8 +92,6 @@ def get_params() -> AttributeDict:
|
||||
"subsampling_factor": 4,
|
||||
"num_decoder_layers": 6,
|
||||
"vgg_frontend": False,
|
||||
"is_espnet_structure": True,
|
||||
"mmi_loss": False,
|
||||
"use_feat_batchnorm": True,
|
||||
"output_beam": 10,
|
||||
"use_double_scores": True,
|
||||
@ -86,9 +105,31 @@ def compute_alignments(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||
token_table: k2.SymbolTable,
|
||||
):
|
||||
) -> List[Tuple[str, List[int]]]:
|
||||
"""Compute the framewise alignments of a dataset.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The neural network model.
|
||||
dl:
|
||||
Dataloader containing the dataset.
|
||||
params:
|
||||
Parameters for computing alignments.
|
||||
graph_compiler:
|
||||
It converts token IDs to decoding graphs.
|
||||
Returns:
|
||||
Return a list of tuples. Each tuple contains two entries:
|
||||
- Utterance ID
|
||||
- Framewise alignments (token IDs) after subsampling
|
||||
"""
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
num_cuts = 0
|
||||
|
||||
device = graph_compiler.device
|
||||
ans = []
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
feature = batch["inputs"]
|
||||
|
||||
@ -97,11 +138,23 @@ def compute_alignments(
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
cut_ids = []
|
||||
for cut in supervisions["cut"]:
|
||||
assert len(cut.supervisions) == 1
|
||||
cut_ids.append(cut.id)
|
||||
|
||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
||||
# nnet_output is [N, T, C]
|
||||
supervision_segments, texts = encode_supervisions(
|
||||
supervisions, subsampling_factor=params.subsampling_factor
|
||||
)
|
||||
# we need also to sort cut_ids as encode_supervisions()
|
||||
# reorders "texts".
|
||||
# In general, new2old is an identity map since lhotse sorts the returned
|
||||
# cuts by duration in descending order
|
||||
new2old = supervision_segments[:, 0].tolist()
|
||||
cut_ids = [cut_ids[i] for i in new2old]
|
||||
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
decoding_graph = graph_compiler.compile(token_ids)
|
||||
@ -113,22 +166,30 @@ def compute_alignments(
|
||||
)
|
||||
|
||||
lattice = k2.intersect_dense(
|
||||
decoding_graph, dense_fsa_vec, params.output_beam
|
||||
decoding_graph,
|
||||
dense_fsa_vec,
|
||||
params.output_beam,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
lattice=lattice,
|
||||
use_double_scores=params.use_double_scores,
|
||||
)
|
||||
|
||||
ali_ids = get_alignments(best_path)
|
||||
ali_tokens = [[token_table[i] for i in ids] for ids in ali_ids]
|
||||
assert len(ali_ids) == len(cut_ids)
|
||||
ans += list(zip(cut_ids, ali_ids))
|
||||
|
||||
frame_shift = 0.01 # 10ms, i.e., 0.01 seconds
|
||||
for i, ali in enumerate(ali_tokens[0]):
|
||||
print(i * params.subsampling_factor * frame_shift, ali)
|
||||
import sys
|
||||
num_cuts += len(ali_ids)
|
||||
|
||||
sys.exit(0)
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -138,6 +199,7 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.return_cuts is True
|
||||
assert args.concatenate_cuts is False
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
@ -169,9 +231,7 @@ def main():
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=False,
|
||||
is_espnet_structure=params.is_espnet_structure,
|
||||
mmi_loss=params.mmi_loss,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
@ -190,20 +250,40 @@ def main():
|
||||
model.eval()
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
test_dl = librispeech.test_dataloaders() # a list
|
||||
|
||||
ali_dir = Path(params.ali_dir)
|
||||
ali_dir.mkdir(exist_ok=True)
|
||||
|
||||
enabled_datasets = {
|
||||
"test_clean": test_dl[0],
|
||||
"test_other": test_dl[1],
|
||||
"train-960": train_dl,
|
||||
"valid": valid_dl,
|
||||
}
|
||||
|
||||
compute_alignments(
|
||||
model=model,
|
||||
dl=enabled_datasets["test_clean"],
|
||||
params=params,
|
||||
graph_compiler=graph_compiler,
|
||||
token_table=lexicon.token_table,
|
||||
)
|
||||
for name, dl in enabled_datasets.items():
|
||||
logging.info(f"Processing {name}")
|
||||
alignments = compute_alignments(
|
||||
model=model,
|
||||
dl=dl,
|
||||
params=params,
|
||||
graph_compiler=graph_compiler,
|
||||
)
|
||||
num_utt = len(alignments)
|
||||
alignments = dict(alignments)
|
||||
assert num_utt == len(alignments)
|
||||
filename = ali_dir / f"{name}.pt"
|
||||
save_alignments(
|
||||
alignments=alignments,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
filename=filename,
|
||||
)
|
||||
logging.info(
|
||||
f"For dataset {name}, its alignments are saved to {filename}"
|
||||
)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
@ -325,6 +325,51 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
|
||||
return labels.tolist()
|
||||
|
||||
|
||||
def save_alignments(
|
||||
alignments: Dict[str, List[int]],
|
||||
subsampling_factor: int,
|
||||
filename: str,
|
||||
) -> None:
|
||||
"""Save alignments to a file.
|
||||
|
||||
Args:
|
||||
alignments:
|
||||
A dict containing alignments. Keys of the dict are utterances and
|
||||
values are the corresponding framewise alignments after subsampling.
|
||||
subsampling_factor:
|
||||
The subsampling factor of the model.
|
||||
filename:
|
||||
Path to save the alignments.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
ali_dict = {
|
||||
"subsampling_factor": subsampling_factor,
|
||||
"alignments": alignments,
|
||||
}
|
||||
torch.save(ali_dict, filename)
|
||||
|
||||
|
||||
def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
|
||||
"""Load alignments from a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Path to the file containing alignment information.
|
||||
The file should be saved by :func:`save_alignments`.
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- subsampling_factor: The subsampling_factor used to compute
|
||||
the alignments.
|
||||
- alignments: A dict containing utterances and their corresponding
|
||||
framewise alignment, after subsampling.
|
||||
"""
|
||||
ali_dict = torch.load(filename)
|
||||
subsampling_factor = ali_dict["subsampling_factor"]
|
||||
alignments = ali_dict["alignments"]
|
||||
return subsampling_factor, alignments
|
||||
|
||||
|
||||
def store_transcripts(
|
||||
filename: Pathlike, texts: Iterable[Tuple[str, str]]
|
||||
) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user