Save framewise alignments with/without repeats.

Fangjun Kuang 2021-11-19 15:32:52 +08:00
parent 62ada37d4e
commit c17527433d
4 changed files with 89 additions and 65 deletions

View File

@@ -24,13 +24,12 @@ Usage:
     --avg 10 \
     --max-duration 300 \
     --dataset train-clean-100 \
-    --out-dir data/token-ali
+    --out-dir data/ali
 """

 import argparse
 import logging
 from pathlib import Path
-from typing import List, Tuple

 import k2
 import numpy as np
@@ -49,7 +48,6 @@ from icefall.utils import (
     encode_supervisions,
     get_alignments,
     get_env_info,
-    save_alignments,
     setup_logger,
 )
@@ -94,16 +92,21 @@ def get_parser():
         type=str,
         required=True,
         help="""Output directory.
-        It contains the following generated files:
+        It contains 3 generated files:

-        - xxx.h5
+        - labels_xxx.h5
+        - aux_labels_xxx.h5
         - cuts_xxx.json.gz

         where xxx is the value of `--dataset`. For instance, if
-        `--dataset` is `train-clean-100`, it will contain two files:
+        `--dataset` is `train-clean-100`, it will contain 3 files:

-        - `train-clean-100.h5`
+        - `labels_train-clean-100.h5`
+        - `aux_labels_train-clean-100.h5`
         - `cuts_train-clean-100.json.gz`
+
+        Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
+        alignment. The difference is that labels_xxx.h5 contains repeats.
         """,
     )
@@ -149,7 +152,8 @@ def get_params() -> AttributeDict:
 def compute_alignments(
     model: torch.nn.Module,
     dl: torch.utils.data.DataLoader,
-    writer: FeaturesWriter,
+    labels_writer: FeaturesWriter,
+    aux_labels_writer: FeaturesWriter,
     params: AttributeDict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
 ) -> CutSet:
@@ -165,8 +169,10 @@ def compute_alignments(
       graph_compiler:
         It converts token IDs to decoding graphs.
     Returns:
-      Return a CutSet. Each cut has a custom field `token_alignment`
-      of type `lhotse.array.TemporalArray`.
+      Return a CutSet. Each cut has two custom fields: labels_alignment
+      and aux_labels_alignment, containing framewise alignments information.
+      Both are of type `lhotse.array.TemporalArray`. The difference between
+      the two alignments is that `labels_alignment` contain repeats.
     """
     try:
         num_batches = len(dl)
@@ -204,7 +210,6 @@ def compute_alignments(
         token_ids = graph_compiler.texts_to_ids(texts)
         decoding_graph = graph_compiler.compile(token_ids)
-        decoding_graph.tokens = decoding_graph.aux_labels.clone()

         dense_fsa_vec = k2.DenseFsaVec(
             nnet_output,
@@ -223,21 +228,32 @@ def compute_alignments(
             use_double_scores=params.use_double_scores,
         )

-        ali_ids = get_alignments(best_path)
-        assert len(ali_ids) == len(cut_list)
-        for cut, ali in zip(cut_list, ali_ids):
-            cut.token_alignment = writer.store_array(
+        labels_ali = get_alignments(best_path, kind="labels")
+        aux_labels_ali = get_alignments(best_path, kind="aux_labels")
+        assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
+        for cut, labels, aux_labels in zip(
+            cut_list, labels_ali, aux_labels_ali
+        ):
+            cut.labels_alignment = labels_writer.store_array(
                 key=cut.id,
-                value=np.asarray(ali, dtype=np.int32),
-                frame_shift=0.04,  # frame shift is 0.01s, subsampling_factor is 4
+                value=np.asarray(labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+            cut.aux_labels_alignment = aux_labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(aux_labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
                 temporal_dim=0,
                 start=0,
             )

         cuts += cut_list
-        num_cuts += len(ali_ids)
+        num_cuts += len(cut_list)

         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
@@ -271,14 +287,17 @@ def main():
     out_dir = Path(params.out_dir)
     out_dir.mkdir(exist_ok=True)

-    out_ali_filename = out_dir / f"{params.dataset}.h5"
-    if out_ali_filename.exists():
-        logging.info(f"{out_ali_filename} exists - skipping")
-        return
-
+    out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
+    out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
     out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
-    if out_manifest_filename.exists():
-        logging.info(f"{out_manifest_filename} exists - skipping")
+
+    for f in (
+        out_labels_ali_filename,
+        out_aux_labels_ali_filename,
+        out_manifest_filename,
+    ):
+        if f.exists():
+            logging.info(f"{f} exists - skipping")
             return

     lexicon = Lexicon(params.lang_dir)
@@ -352,21 +371,24 @@ def main():
         dl = librispeech.valid_dataloaders(dev_other_cuts)

     logging.info(f"Processing {params.dataset}")
-    with NumpyHdf5Writer(out_ali_filename) as writer:
+    with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
+        with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
             cut_set = compute_alignments(
                 model=model,
                 dl=dl,
-                writer=writer,
+                labels_writer=labels_writer,
+                aux_labels_writer=aux_labels_writer,
                 params=params,
                 graph_compiler=graph_compiler,
             )

-    cut_set.to_json(out_manifest_filename)
+    cut_set.to_file(out_manifest_filename)

     logging.info(
-        f"For dataset {params.dataset}, its alignments are "
-        f"saved to {out_ali_filename} and the cut manifest file "
-        f"is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
+        f"For dataset {params.dataset}, its alignments with repeats are "
+        f"saved to {out_labels_ali_filename}, the alignments without repeats "
+        f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
+        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
     )
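
The new test file at the end of this commit shows how the stored alignments are meant to be read back. As a quick reference, here is a minimal sketch of that usage pattern, assuming the script above was run with --dataset train-clean-100 and --out-dir data/ali; the manifest path below simply follows the naming scheme described in the help text and is otherwise hypothetical.

    from lhotse import load_manifest
    from lhotse.dataset.collation import collate_custom_field

    # Path follows the {out_dir}/cuts_{dataset}.json.gz convention above (hypothetical).
    cuts = load_manifest("data/ali/cuts_train-clean-100.json.gz")
    cuts = cuts.subset(first=2)

    # labels_alignment: framewise alignment with repeats (stored in labels_xxx.h5).
    labels, labels_lens = collate_custom_field(cuts, "labels_alignment")

    # aux_labels_alignment: framewise alignment without repeats (aux_labels_xxx.h5).
    aux_labels, aux_labels_lens = collate_custom_field(cuts, "aux_labels_alignment")

    print(labels.shape, labels_lens)  # (2, max_num_frames), number of frames per cut
    print(aux_labels.shape, aux_labels_lens)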

View File

@@ -19,7 +19,6 @@ import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import List, Union

 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (

View File

@@ -305,11 +305,8 @@ def get_texts(
     return aux_labels.tolist()


-def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
-    """Extract the token IDs (from best_paths.tokens) from the best-path FSAs.
-
-    Caution:
-      There are no repeats in `best_paths.tokens`.
+def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
+    """Extract labels or aux_labels from the best-path FSAs.

     Args:
       best_paths:
@@ -317,15 +314,21 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
         containing multiple FSAs, which is expected to be the result
         of k2.shortest_path (otherwise the returned values won't
         be meaningful).
+      kind:
+        Possible values are: "labels" and "aux_labels". Caution: When it is
+        "labels", the resulting alignments contain repeats.
     Returns:
       Returns a list of lists of int, containing the token sequences we
       decoded. For `ans[i]`, its length equals to the number of frames
       after subsampling of the i-th utterance in the batch.
     """
+    assert kind in ("labels", "aux_labels")
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(token_shape, best_paths.tokens)
+    tokens = k2.RaggedTensor(
+        token_shape, getattr(best_paths, kind).contiguous()
+    )
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
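
To make the labels/aux_labels distinction concrete: the best path has one arc per subsampled frame, so both alignments have one entry per frame. The labels carry the token active at each frame (blanks and repeats included), while the aux_labels typically carry a token only on the frame where it starts, with zeros elsewhere. The snippet below is an illustration with made-up numbers, not icefall code; the exact placement of the non-zero entries depends on the CTC topology used.

    # Illustration only: token ID 0 is the blank.
    labels_alignment = [0, 0, 25, 25, 25, 0, 7, 7, 0]    # with repeats
    aux_labels_alignment = [0, 0, 25, 0, 0, 0, 7, 0, 0]  # repeat-free

    def ctc_collapse(ali, blank=0):
        """Merge consecutive duplicates, then drop blanks (standard CTC rule)."""
        out, prev = [], None
        for t in ali:
            if t != blank and t != prev:
                out.append(t)
            prev = t
        return out

    # Both alignments describe the same token sequence.
    assert ctc_collapse(labels_alignment) == [25, 7]
    assert [t for t in aux_labels_alignment if t != 0] == [25, 7]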

View File

@@ -25,26 +25,15 @@
 from pathlib import Path

-import k2
-import torch
-from lhotse import load_manifest
+from lhotse import CutSet, load_manifest
 from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
-from torch.nn.utils.rnn import pad_sequence
+from lhotse.dataset.collation import collate_custom_field
 from torch.utils.data import DataLoader

-from icefall.ali import (
-    convert_alignments_to_tensor,
-    load_alignments,
-    lookup_alignments,
-)
-from icefall.decode import get_lattice, one_best_decoding
-from icefall.lexicon import Lexicon
-from icefall.utils import get_texts
-
 ICEFALL_DIR = Path(__file__).resolve().parent.parent
 egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
 lang_dir = egs_dir / "data/lang_bpe_500"
-cuts_json = egs_dir / "data/token_ali/cuts_test-clean.json.gz"
+cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"


 def data_exists():
@@ -53,10 +42,11 @@ def data_exists():
 def get_dataloader():
     cuts = load_manifest(cuts_json)
+    print(cuts[0])
     cuts = cuts.with_features_path_prefix(egs_dir)
     sampler = SingleCutSampler(
         cuts,
-        max_duration=40,
+        max_duration=10,
         shuffle=False,
     )
@@ -75,14 +65,24 @@ def get_dataloader():
 def test():
     if not data_exists():
         return
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
     dl = get_dataloader()
     for batch in dl:
         supervisions = batch["supervisions"]
         cuts = supervisions["cut"]
-        print(cuts)
+        labels_alignment, labels_alignment_length = collate_custom_field(
+            CutSet.from_cuts(cuts), "labels_alignment"
+        )
+        (
+            aux_labels_alignment,
+            aux_labels_alignment_length,
+        ) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
+        print(labels_alignment)
+        print(aux_labels_alignment)
+        print(labels_alignment_length)
+        print(aux_labels_alignment_length)
+        # print(cuts)
         break
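
Since the alignments are stored with frame_shift=0.04 (10 ms features and a subsampling factor of 4, per the comments in the store_array calls above), downstream code can map frame indices back to time. Below is a small sketch of that conversion, written under those assumptions and not taken from the commit; the helper name is hypothetical.

    # Sketch: turn a repeat-free alignment into (start_time_seconds, token_id) pairs.
    FRAME_SHIFT = 0.04  # matches frame_shift=0.04 used by store_array above

    def alignment_to_onsets(aux_labels_ali, frame_shift=FRAME_SHIFT):
        """Non-zero entries mark the first frame of a token."""
        return [(i * frame_shift, tok) for i, tok in enumerate(aux_labels_ali) if tok != 0]

    print(alignment_to_onsets([0, 0, 25, 0, 0, 7, 0]))  # [(0.08, 25), (0.2, 7)]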