mirror of https://github.com/k2-fsa/icefall.git
Save framewise alignments with/without repeats.
commit c17527433d (parent 62ada37d4e)
@@ -24,13 +24,12 @@ Usage:
     --avg 10 \
     --max-duration 300 \
     --dataset train-clean-100 \
-    --out-dir data/token-ali
+    --out-dir data/ali
 """
 
 import argparse
 import logging
 from pathlib import Path
-from typing import List, Tuple
 
 import k2
 import numpy as np
@@ -49,7 +48,6 @@ from icefall.utils import (
     encode_supervisions,
     get_alignments,
     get_env_info,
-    save_alignments,
     setup_logger,
 )
 
@@ -94,16 +92,21 @@ def get_parser():
         type=str,
         required=True,
         help="""Output directory.
-        It contains the following generated files:
+        It contains 3 generated files:
 
-        - xxx.h5
+        - labels_xxx.h5
+        - aux_labels_xxx.h5
         - cuts_xxx.json.gz
 
         where xxx is the value of `--dataset`. For instance, if
-        `--dataset` is `train-clean-100`, it will contain two files:
+        `--dataset` is `train-clean-100`, it will contain 3 files:
 
-        - `train-clean-100.h5`
+        - `labels_train-clean-100.h5`
+        - `aux_labels_train-clean-100.h5`
         - `cuts_train-clean-100.json.gz`
+
+        Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
+        alignment. The difference is that labels_xxx.h5 contains repeats.
         """,
     )
 
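As a quick illustration (not part of the diff), the three names in the help text above are derived from `--out-dir` and `--dataset` with the same f-strings used in `main()` further down:

    from pathlib import Path

    out_dir = Path("data/ali")   # value of --out-dir
    dataset = "train-clean-100"  # value of --dataset

    labels_file = out_dir / f"labels_{dataset}.h5"          # framewise alignment, with repeats
    aux_labels_file = out_dir / f"aux_labels_{dataset}.h5"  # framewise alignment, without repeats
    manifest_file = out_dir / f"cuts_{dataset}.json.gz"     # cut manifest referencing both arrays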
@@ -149,7 +152,8 @@ def get_params() -> AttributeDict:
 def compute_alignments(
     model: torch.nn.Module,
     dl: torch.utils.data.DataLoader,
-    writer: FeaturesWriter,
+    labels_writer: FeaturesWriter,
+    aux_labels_writer: FeaturesWriter,
     params: AttributeDict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
 ) -> CutSet:
@@ -165,8 +169,10 @@ compute_alignments(
       graph_compiler:
         It converts token IDs to decoding graphs.
     Returns:
-      Return a CutSet. Each cut has a custom field `token_alignment`
-      of type `lhotse.array.TemporalArray`.
+      Return a CutSet. Each cut has two custom fields: `labels_alignment`
+      and `aux_labels_alignment`, containing framewise alignment information.
+      Both are of type `lhotse.array.TemporalArray`. The difference between
+      the two alignments is that `labels_alignment` contains repeats.
     """
     try:
         num_batches = len(dl)
@@ -204,7 +210,6 @@ compute_alignments(
 
         token_ids = graph_compiler.texts_to_ids(texts)
         decoding_graph = graph_compiler.compile(token_ids)
-        decoding_graph.tokens = decoding_graph.aux_labels.clone()
 
         dense_fsa_vec = k2.DenseFsaVec(
             nnet_output,
@@ -223,21 +228,32 @@
             use_double_scores=params.use_double_scores,
         )
 
-        ali_ids = get_alignments(best_path)
-        assert len(ali_ids) == len(cut_list)
-        for cut, ali in zip(cut_list, ali_ids):
-
-            cut.token_alignment = writer.store_array(
+        labels_ali = get_alignments(best_path, kind="labels")
+        aux_labels_ali = get_alignments(best_path, kind="aux_labels")
+        assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
+        for cut, labels, aux_labels in zip(
+            cut_list, labels_ali, aux_labels_ali
+        ):
+            cut.labels_alignment = labels_writer.store_array(
                 key=cut.id,
-                value=np.asarray(ali, dtype=np.int32),
-                frame_shift=0.04,  # frame shift is 0.01s, subsampling_factor is 4
+                value=np.asarray(labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
                 temporal_dim=0,
                 start=0,
             )
+            cut.aux_labels_alignment = aux_labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(aux_labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
 
         cuts += cut_list
 
-        num_cuts += len(ali_ids)
+        num_cuts += len(cut_list)
 
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
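A note on the `frame_shift=0.04` stored above: the features use a 0.01 s frame shift and the model subsamples by a factor of 4, so each alignment entry covers 0.04 s. A tiny sketch of the arithmetic (the helper name is mine, not from the commit):

    FRAME_SHIFT = 0.01       # seconds per feature frame
    SUBSAMPLING_FACTOR = 4   # model output frame rate is 1/4 of the input

    def alignment_index_to_seconds(i: int) -> float:
        # Start time of the i-th entry in labels/aux_labels.
        return i * FRAME_SHIFT * SUBSAMPLING_FACTOR  # i * 0.04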
@@ -271,15 +287,18 @@ def main():
     out_dir = Path(params.out_dir)
     out_dir.mkdir(exist_ok=True)
 
-    out_ali_filename = out_dir / f"{params.dataset}.h5"
-    if out_ali_filename.exists():
-        logging.info(f"{out_ali_filename} exists - skipping")
-        return
-
+    out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
+    out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
     out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
-    if out_manifest_filename.exists():
-        logging.info(f"{out_manifest_filename} exists - skipping")
-        return
+
+    for f in (
+        out_labels_ali_filename,
+        out_aux_labels_ali_filename,
+        out_manifest_filename,
+    ):
+        if f.exists():
+            logging.info(f"{f} exists - skipping")
+            return
 
     lexicon = Lexicon(params.lang_dir)
     max_token_id = max(lexicon.tokens)
@@ -352,21 +371,24 @@ def main():
         dl = librispeech.valid_dataloaders(dev_other_cuts)
 
     logging.info(f"Processing {params.dataset}")
-    with NumpyHdf5Writer(out_ali_filename) as writer:
-        cut_set = compute_alignments(
-            model=model,
-            dl=dl,
-            writer=writer,
-            params=params,
-            graph_compiler=graph_compiler,
-        )
+    with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
+        with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
+            cut_set = compute_alignments(
+                model=model,
+                dl=dl,
+                labels_writer=labels_writer,
+                aux_labels_writer=aux_labels_writer,
+                params=params,
+                graph_compiler=graph_compiler,
+            )
 
-    cut_set.to_json(out_manifest_filename)
+    cut_set.to_file(out_manifest_filename)
 
     logging.info(
-        f"For dataset {params.dataset}, its alignments are "
-        f"saved to {out_ali_filename} and the cut manifest file "
-        f"is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
+        f"For dataset {params.dataset}, its alignments with repeats are "
+        f"saved to {out_labels_ali_filename}, the alignments without repeats "
+        f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
+        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
     )
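A minimal read-back sketch (not in the commit): assuming the files written above exist, the per-cut alignments can be recovered through lhotse's custom-field mechanism; `load_custom` is lhotse's generic accessor for arrays stored via `store_array`:

    from lhotse import load_manifest

    cuts = load_manifest("data/ali/cuts_train-clean-100.json.gz")
    cut = cuts[0]

    labels = cut.load_custom("labels_alignment")          # one ID per subsampled frame, with repeats
    aux_labels = cut.load_custom("aux_labels_alignment")  # same length, without repeats
    assert len(labels) == len(aux_labels)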
@@ -19,7 +19,6 @@ import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import List, Union
 
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
@@ -305,11 +305,8 @@ def get_texts(
     return aux_labels.tolist()
 
 
-def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
-    """Extract the token IDs (from best_paths.tokens) from the best-path FSAs.
-
-    Caution:
-      There are no repeats in `best_paths.tokens`.
+def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
+    """Extract labels or aux_labels from the best-path FSAs.
 
     Args:
       best_paths:
@@ -317,15 +314,21 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
         containing multiple FSAs, which is expected to be the result
         of k2.shortest_path (otherwise the returned values won't
         be meaningful).
+      kind:
+        Possible values are: "labels" and "aux_labels". Caution: When it is
+        "labels", the resulting alignments contain repeats.
     Returns:
       Returns a list of lists of int, containing the token sequences we
       decoded. For `ans[i]`, its length equals the number of frames
       after subsampling of the i-th utterance in the batch.
     """
+    assert kind in ("labels", "aux_labels")
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(token_shape, best_paths.tokens)
+    tokens = k2.RaggedTensor(
+        token_shape, getattr(best_paths, kind).contiguous()
+    )
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
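To make the `kind` distinction concrete, here is a hypothetical 6-frame utterance (token IDs invented for illustration; 0 is the blank of the CTC topology):

    # frame:           0  1   2   3  4  5
    labels_ali =     [0, 0, 25, 25, 7, 0]  # kind="labels": token 25 repeats on every frame it spans
    aux_labels_ali = [0, 0, 25,  0, 7, 0]  # kind="aux_labels": token 25 appears only where it starts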
@@ -25,26 +25,15 @@
 
 from pathlib import Path
 
-import k2
 import torch
-from lhotse import load_manifest
+from lhotse import CutSet, load_manifest
 from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
-from torch.nn.utils.rnn import pad_sequence
+from lhotse.dataset.collation import collate_custom_field
 from torch.utils.data import DataLoader
 
-from icefall.ali import (
-    convert_alignments_to_tensor,
-    load_alignments,
-    lookup_alignments,
-)
-from icefall.decode import get_lattice, one_best_decoding
-from icefall.lexicon import Lexicon
-from icefall.utils import get_texts
-
 ICEFALL_DIR = Path(__file__).resolve().parent.parent
 egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
-lang_dir = egs_dir / "data/lang_bpe_500"
-cuts_json = egs_dir / "data/token_ali/cuts_test-clean.json.gz"
+cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"
 
 
 def data_exists():
@@ -53,10 +42,11 @@ def data_exists():
 
 def get_dataloader():
     cuts = load_manifest(cuts_json)
+    print(cuts[0])
     cuts = cuts.with_features_path_prefix(egs_dir)
     sampler = SingleCutSampler(
         cuts,
-        max_duration=40,
+        max_duration=10,
         shuffle=False,
     )
 
@@ -75,14 +65,24 @@ def get_dataloader():
 def test():
     if not data_exists():
         return
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
     dl = get_dataloader()
     for batch in dl:
         supervisions = batch["supervisions"]
         cuts = supervisions["cut"]
-        print(cuts)
+        labels_alignment, labels_alignment_length = collate_custom_field(
+            CutSet.from_cuts(cuts), "labels_alignment"
+        )
+
+        (
+            aux_labels_alignment,
+            aux_labels_alignment_length,
+        ) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
+
+        print(labels_alignment)
+        print(aux_labels_alignment)
+        print(labels_alignment_length)
+        print(aux_labels_alignment_length)
+        # print(cuts)
         break
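For reference, `collate_custom_field` returns a padded (B, T_max) tensor together with per-cut lengths, so padding can be stripped per cut. A self-contained sketch with invented values (not from the commit):

    import torch

    # Illustrative stand-ins for the two tensors printed in test() above (B=2 cuts):
    labels_alignment = torch.tensor([[3, 3, 7, 0], [5, 0, 0, 0]], dtype=torch.int32)
    labels_alignment_length = torch.tensor([3, 1], dtype=torch.int32)

    for ali, n in zip(labels_alignment, labels_alignment_length):
        print(ali[:n])  # the alignment with collation padding stripped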