Save framewise alignments with/without repeats.

Fangjun Kuang 2021-11-19 15:32:52 +08:00
parent 62ada37d4e
commit c17527433d
4 changed files with 89 additions and 65 deletions

View File

@@ -24,13 +24,12 @@ Usage:
--avg 10 \
--max-duration 300 \
--dataset train-clean-100 \
--out-dir data/token-ali
--out-dir data/ali
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import numpy as np
@@ -49,7 +48,6 @@ from icefall.utils import (
encode_supervisions,
get_alignments,
get_env_info,
save_alignments,
setup_logger,
)
@@ -94,16 +92,21 @@ def get_parser():
type=str,
required=True,
help="""Output directory.
It contains the following generated files:
It contains 3 generated files:
- xxx.h5
- labels_xxx.h5
- aux_labels_xxx.h5
- cuts_xxx.json.gz
where xxx is the value of `--dataset`. For instance, if
`--dataset` is `train-clean-100`, it will contain two files:
`--dataset` is `train-clean-100`, it will contain 3 files:
- `train-clean-100.h5`
- `labels_train-clean-100.h5`
- `aux_labels_train-clean-100.h5`
- `cuts_train-clean-100.json.gz`
Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
alignments. The difference is that labels_xxx.h5 contains repeats.
""",
)
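The `--out-dir` help above lists the three files written per dataset. Below is a minimal sketch of reading them back with lhotse (assuming lhotse's `load_manifest` and `Cut.load_custom` accessors for custom array fields; the path is only an example, substitute your `--out-dir` and `--dataset`):

from lhotse import load_manifest

# Example path only: <out-dir>/cuts_<dataset>.json.gz
cuts = load_manifest("data/ali/cuts_train-clean-100.json.gz")

cut = cuts[0]
# Both custom fields are TemporalArrays backed by the two HDF5 files.
labels_ali = cut.load_custom("labels_alignment")          # framewise tokens, with repeats
aux_labels_ali = cut.load_custom("aux_labels_alignment")  # framewise tokens, without repeats
print(labels_ali.shape, aux_labels_ali.shape)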
@@ -149,7 +152,8 @@ def get_params() -> AttributeDict:
def compute_alignments(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
writer: FeaturesWriter,
labels_writer: FeaturesWriter,
aux_labels_writer: FeaturesWriter,
params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler,
) -> CutSet:
@@ -165,8 +169,10 @@ def compute_alignments(
graph_compiler:
It converts token IDs to decoding graphs.
Returns:
Return a CutSet. Each cut has a custom field `token_alignment`
of type `lhotse.array.TemporalArray`.
Return a CutSet. Each cut has two custom fields: labels_alignment
and aux_labels_alignment, containing framewise alignment information.
Both are of type `lhotse.array.TemporalArray`. The difference between
the two alignments is that `labels_alignment` contains repeats.
"""
try:
num_batches = len(dl)
@@ -204,7 +210,6 @@ compute_alignments(
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
decoding_graph.tokens = decoding_graph.aux_labels.clone()
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
@@ -223,21 +228,32 @@
use_double_scores=params.use_double_scores,
)
ali_ids = get_alignments(best_path)
assert len(ali_ids) == len(cut_list)
for cut, ali in zip(cut_list, ali_ids):
cut.token_alignment = writer.store_array(
labels_ali = get_alignments(best_path, kind="labels")
aux_labels_ali = get_alignments(best_path, kind="aux_labels")
assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
for cut, labels, aux_labels in zip(
cut_list, labels_ali, aux_labels_ali
):
cut.labels_alignment = labels_writer.store_array(
key=cut.id,
value=np.asarray(ali, dtype=np.int32),
frame_shift=0.04, # frame shift is 0.01s, subsampling_factor is 4
value=np.asarray(labels, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
cut.aux_labels_alignment = aux_labels_writer.store_array(
key=cut.id,
value=np.asarray(aux_labels, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
cuts += cut_list
num_cuts += len(ali_ids)
num_cuts += len(cut_list)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
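The alignments stored above use a frame shift of 0.04 s (10 ms acoustic frames times a subsampling factor of 4), so converting an alignment index to a time stamp is a single multiplication. A tiny illustration with made-up token IDs, not taken from an actual run:

frame_shift = 0.04  # seconds per alignment frame after 4x subsampling
labels = [0, 0, 23, 23, 7, 0]  # hypothetical framewise token IDs (with repeats)

# Start time, in seconds, of each alignment frame.
times = [i * frame_shift for i in range(len(labels))]
print(list(zip(times, labels)))  # [(0.0, 0), (0.04, 0), (0.08, 23), ...]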
@@ -271,15 +287,18 @@ def main():
out_dir = Path(params.out_dir)
out_dir.mkdir(exist_ok=True)
out_ali_filename = out_dir / f"{params.dataset}.h5"
if out_ali_filename.exists():
logging.info(f"{out_ali_filename} exists - skipping")
return
out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
if out_manifest_filename.exists():
logging.info(f"{out_manifest_filename} exists - skipping")
return
for f in (
out_labels_ali_filename,
out_aux_labels_ali_filename,
out_manifest_filename,
):
if f.exists():
logging.info(f"{f} exists - skipping")
return
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
@@ -352,21 +371,24 @@ def main():
dl = librispeech.valid_dataloaders(dev_other_cuts)
logging.info(f"Processing {params.dataset}")
with NumpyHdf5Writer(out_ali_filename) as writer:
cut_set = compute_alignments(
model=model,
dl=dl,
writer=writer,
params=params,
graph_compiler=graph_compiler,
)
with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
cut_set = compute_alignments(
model=model,
dl=dl,
labels_writer=labels_writer,
aux_labels_writer=aux_labels_writer,
params=params,
graph_compiler=graph_compiler,
)
cut_set.to_json(out_manifest_filename)
cut_set.to_file(out_manifest_filename)
logging.info(
f"For dataset {params.dataset}, its alignments are "
f"saved to {out_ali_filename} and the cut manifest file "
f"is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
f"For dataset {params.dataset}, its alignments with repeats are "
f"saved to {out_labels_ali_filename}, the alignments without repeats "
f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
)
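A note on the nested `with` blocks that open the two HDF5 writers above: a flatter, equivalent form (just a sketch of an alternative, not what this commit does) opens both writers in a single `with` statement:

from lhotse.features.io import NumpyHdf5Writer

with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer, NumpyHdf5Writer(
    out_aux_labels_ali_filename
) as aux_labels_writer:
    # Same call as above; both writers stay open for the whole computation.
    cut_set = compute_alignments(
        model=model,
        dl=dl,
        labels_writer=labels_writer,
        aux_labels_writer=aux_labels_writer,
        params=params,
        graph_compiler=graph_compiler,
    )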

View File

@@ -19,7 +19,6 @@ import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (

View File

@@ -305,11 +305,8 @@ def get_texts(
return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
"""Extract the token IDs (from best_paths.tokens) from the best-path FSAs.
Caution:
There are no repeats in `best_paths.tokens`.
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
"""Extract labels or aux_labels from the best-path FSAs.
Args:
best_paths:
@@ -317,15 +314,21 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
kind:
Possible values are: "labels" and "aux_labels". Caution: When it is
"labels", the resulting alignments contain repeats.
Returns:
Returns a list of lists of int, containing the token sequences we
decoded. The length of `ans[i]` equals the number of frames
after subsampling of the i-th utterance in the batch.
"""
assert kind in ("labels", "aux_labels")
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
token_shape = best_paths.arcs.shape().remove_axis(1)
# token_shape has axes [fsa][arc]
tokens = k2.RaggedTensor(token_shape, best_paths.tokens)
tokens = k2.RaggedTensor(
token_shape, getattr(best_paths, kind).contiguous()
)
tokens = tokens.remove_values_eq(-1)
return tokens.tolist()
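To make the "labels" vs "aux_labels" distinction concrete: with the standard CTC topology, the labels along the best path repeat a token for every frame it spans (0 typically being the blank), while the aux_labels carry each token only on the frame where it starts. A hypothetical 8-frame example, not produced by any actual run:

# Hypothetical framewise alignments for the same 8 frames:
labels_ali = [0, 0, 23, 23, 23, 7, 7, 0]       # kind="labels": repeats kept
aux_labels_ali = [0, 0, 23, 0, 0, 7, 0, 0]     # kind="aux_labels": each token appears once
assert len(labels_ali) == len(aux_labels_ali)  # both are framewise, same length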

View File

@@ -25,26 +25,15 @@
from pathlib import Path
import k2
import torch
from lhotse import load_manifest
from lhotse import CutSet, load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from torch.nn.utils.rnn import pad_sequence
from lhotse.dataset.collation import collate_custom_field
from torch.utils.data import DataLoader
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import get_texts
ICEFALL_DIR = Path(__file__).resolve().parent.parent
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
lang_dir = egs_dir / "data/lang_bpe_500"
cuts_json = egs_dir / "data/token_ali/cuts_test-clean.json.gz"
cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"
def data_exists():
@@ -53,10 +42,11 @@ def data_exists():
def get_dataloader():
cuts = load_manifest(cuts_json)
print(cuts[0])
cuts = cuts.with_features_path_prefix(egs_dir)
sampler = SingleCutSampler(
cuts,
max_duration=40,
max_duration=10,
shuffle=False,
)
@@ -75,14 +65,24 @@ def get_dataloader():
def test():
if not data_exists():
return
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
dl = get_dataloader()
for batch in dl:
supervisions = batch["supervisions"]
cuts = supervisions["cut"]
print(cuts)
labels_alignment, labels_alignment_length = collate_custom_field(
CutSet.from_cuts(cuts), "labels_alignment"
)
(
aux_labels_alignment,
aux_labels_alignment_length,
) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
print(labels_alignment)
print(aux_labels_alignment)
print(labels_alignment_length)
print(aux_labels_alignment_length)
# print(cuts)
break
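collate_custom_field pads the per-cut alignments to a common length and also returns their true lengths, so each row can be trimmed back before use. A minimal sketch reusing the names from the test above:

# Trim each padded row back to the cut's true number of alignment frames.
for row, length in zip(labels_alignment, labels_alignment_length):
    ali = row[: int(length)].tolist()  # framewise token IDs for one cut, repeats kept
    print(len(ali), ali[:10])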