Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 22:54:18 +00:00)

Compute word starting time from framewise token alignment.

This commit is contained in:
parent d50e7734a6
commit 75936a5fae
@@ -32,7 +32,6 @@ import logging
from pathlib import Path
from typing import List

import k2
import numpy as np
import sentencepiece as spm
import torch
@@ -88,19 +87,14 @@ def get_parser():
        help="""Output directory.
        It contains 3 generated files:

            - labels_xxx.h5
            - aux_labels_xxx.h5
            - token_ali_xxx.h5
            - cuts_xxx.json.gz

        where xxx is the value of `--dataset`. For instance, if
        `--dataset` is `train-clean-100`, it will contain 3 files:
        `--dataset` is `train-clean-100`, it will contain 2 files:

            - `labels_train-clean-100.h5`
            - `aux_labels_train-clean-100.h5`
            - `token_ali_train-clean-100.h5`
            - `cuts_train-clean-100.json.gz`

        Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
        alignment. The difference is that labels_xxx.h5 contains repeats.
        """,
    )
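(Illustration, not part of this commit: the note above says that labels_xxx.h5 and aux_labels_xxx.h5 both hold framewise alignments and differ only in whether repeats are kept. A framewise sequence with repeats can be collapsed back to the emitted token sequence by removing consecutive duplicates and then dropping blanks; the snippet below is a minimal sketch that assumes the blank token has id 0 and uses purely hypothetical names.)

def collapse_framewise_labels(framewise):
    """Remove consecutive repeats, then drop blanks (assumed to be id 0)."""
    collapsed = []
    prev = None
    for token in framewise:
        if token != prev:
            collapsed.append(token)
        prev = token
    return [token for token in collapsed if token != 0]

# 2 blank frames, 3 frames of token 25, 1 blank frame, 2 frames of token 7:
print(collapse_framewise_labels([0, 0, 25, 25, 25, 0, 7, 7]))  # -> [25, 7]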
@@ -179,7 +173,6 @@ def compute_alignments(
    ys_list: List[List[int]] = sp.encode(texts, out_type=int)

    ali_list = []
    word_begin_time_list = []
    for i in range(batch_size):
        # fmt: off
        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -208,7 +201,7 @@

        num_cuts += len(cut_list)

        if batch_idx % 100 == 0:
        if batch_idx % 2 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
@@ -255,13 +248,10 @@ def main():
    out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5"
    out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"

    for f in (
        out_ali_filename,
        out_manifest_filename,
    ):
        if f.exists():
            logging.info(f"{f} exists - skipping")
            return
    done_file = out_dir / f".{params.dataset}.done"
    if done_file.is_file():
        logging.info(f"{done_file} exists - skipping")
        exit()

    logging.info("About to create model")
    model = get_transducer_model(params)
@@ -329,6 +319,7 @@ def main():
        f"saved to {out_ali_filename} and the cut manifest "
        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
    )
    done_file.touch()


# torch.set_num_threads(1)
165 egs/librispeech/ASR/transducer_stateless/test_compute_ali.py (Executable file)
@@ -0,0 +1,165 @@
#!/usr/bin/env python3
# Copyright    2022  Xiaomi Corp.  (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script shows how to get word starting time
from framewise token alignment.

Usage:
        ./transducer_stateless/compute_ali.py \
                --exp-dir ./transducer_stateless/exp \
                --bpe-model ./data/lang_bpe_500/bpe.model \
                --epoch 20 \
                --avg 10 \
                --max-duration 300 \
                --dataset train-clean-100 \
                --out-dir data/ali

And then you can run:

        ./transducer_stateless/test_compute_ali.py \
            --bpe-model ./data/lang_bpe_500/bpe.model \
            --ali-dir data/ali \
            --dataset train-clean-100
"""
import argparse
import logging
from pathlib import Path

import sentencepiece as spm
import torch
from alignment import get_word_begin_frame
from lhotse import CutSet, load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from lhotse.dataset.collation import collate_custom_field


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--ali-dir",
        type=Path,
        default="./data/ali",
        help="It specifies the directory where alignments can be found.",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="""The name of the dataset:
        Possible values are:
            - test-clean
            - test-other
            - train-clean-100
            - train-clean-360
            - train-other-500
            - dev-clean
            - dev-other
        """,
    )

    return parser


def main():
    args = get_parser().parse_args()

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    cuts_json = args.ali_dir / f"cuts_{args.dataset}.json.gz"

    logging.info(f"Loading {cuts_json}")
    cuts = load_manifest(cuts_json)

    sampler = SingleCutSampler(
        cuts,
        max_duration=30,
        shuffle=False,
    )

    dataset = K2SpeechRecognitionDataset(return_cuts=True)

    dl = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=None,
        num_workers=1,
        persistent_workers=False,
    )

    frame_shift = 10  # ms
    subsampling_factor = 4

    frame_shift_in_second = frame_shift * subsampling_factor / 1000.0
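    # Each alignment frame corresponds to subsampling_factor * frame_shift
    # = 4 * 10 ms = 40 ms of audio, so multiplying a frame index by
    # frame_shift_in_second (0.04) converts it to seconds.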

    # key: cut.id
    # value: a list of pairs (word, time_in_second)
    word_begin_time_dict = {}
    for batch in dl:
        supervisions = batch["supervisions"]
        cuts = supervisions["cut"]

        token_alignment, token_alignment_length = collate_custom_field(
            CutSet.from_cuts(cuts), "token_alignment"
        )
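        # token_alignment is a padded (batch, num_frames) tensor of framewise
        # token ids; token_alignment_length gives the valid length of each row.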

        for i in range(len(cuts)):
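            # The alignment has one entry per encoder output frame; the
            # encoder subsamples the features by roughly 4x and the
            # expression below mirrors its exact output length.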
            assert (
                (cuts[i].features.num_frames - 1) // 2 - 1
            ) // 2 == token_alignment_length[i]

            word_begin_frame = get_word_begin_frame(
                token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
            )
            word_begin_time = [
                "{:.2f}".format(i * frame_shift_in_second)
                for i in word_begin_frame
            ]

            words = supervisions["text"][i].split()

            assert len(word_begin_frame) == len(words)
            word_begin_time_dict[cuts[i].id] = list(zip(words, word_begin_time))

        # This is a demo script and we exit here after processing
        # one batch.
        # You can find word starting time in the dict "word_begin_time_dict"
        for cut_id, word_time in word_begin_time_dict.items():
            print(f"{cut_id}\n{word_time}\n")
        break


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
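(Illustration, not part of this commit: get_word_begin_frame is imported from transducer_stateless/alignment.py and its implementation is not shown in this diff. The sketch below only outlines the idea and may differ from the real code; it assumes blank id 0 and a SentencePiece BPE model that marks word-initial pieces with the "▁" prefix.)

from typing import List

import sentencepiece as spm


def word_begin_frames_sketch(
    token_ali: List[int], sp: spm.SentencePieceProcessor
) -> List[int]:
    """Return the frame index of every token that starts a new word.

    Illustrative sketch only: assumes blank id 0 and that word-initial
    BPE pieces carry the "▁" prefix.
    """
    begin_frames = []
    for frame, token in enumerate(token_ali):
        if token == 0:
            continue  # blank: nothing is emitted at this frame
        if sp.id_to_piece(token).startswith("▁"):
            begin_frames.append(frame)
    return begin_frames

Multiplying each returned frame index by frame_shift_in_second (0.04 s with the values above) gives the word starting times that main() zips with the words of the transcript.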