Compute word starting time from framewise token alignment.

This commit is contained in:
Fangjun Kuang 2022-03-07 15:16:54 +08:00
parent d50e7734a6
commit 75936a5fae
2 changed files with 174 additions and 18 deletions

View File

@ -32,7 +32,6 @@ import logging
from pathlib import Path from pathlib import Path
from typing import List from typing import List
import k2
import numpy as np import numpy as np
import sentencepiece as spm import sentencepiece as spm
import torch import torch
@ -88,19 +87,14 @@ def get_parser():
help="""Output directory. help="""Output directory.
It contains 3 generated files: It contains 3 generated files:
- labels_xxx.h5 - token_ali_xxx.h5
- aux_labels_xxx.h5
- cuts_xxx.json.gz - cuts_xxx.json.gz
where xxx is the value of `--dataset`. For instance, if where xxx is the value of `--dataset`. For instance, if
`--dataset` is `train-clean-100`, it will contain 3 files: `--dataset` is `train-clean-100`, it will contain 2 files:
- `labels_train-clean-100.h5` - `token_ali_train-clean-100.h5`
- `aux_labels_train-clean-100.h5`
- `cuts_train-clean-100.json.gz` - `cuts_train-clean-100.json.gz`
Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
alignment. The difference is that labels_xxx.h5 contains repeats.
""", """,
) )
@ -179,7 +173,6 @@ def compute_alignments(
ys_list: List[List[int]] = sp.encode(texts, out_type=int) ys_list: List[List[int]] = sp.encode(texts, out_type=int)
ali_list = [] ali_list = []
word_begin_time_list = []
for i in range(batch_size): for i in range(batch_size):
# fmt: off # fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -208,7 +201,7 @@ def compute_alignments(
num_cuts += len(cut_list) num_cuts += len(cut_list)
if batch_idx % 100 == 0: if batch_idx % 2 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(
@ -255,13 +248,10 @@ def main():
out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5" out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5"
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
for f in ( done_file = out_dir / f".{params.dataset}.done"
out_ali_filename, if done_file.is_file():
out_manifest_filename, logging.info(f"{done_file} exists - skipping")
): exit()
if f.exists():
logging.info(f"{f} exists - skipping")
return
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
@ -329,6 +319,7 @@ def main():
f"saved to {out_ali_filename} and the cut manifest " f"saved to {out_ali_filename} and the cut manifest "
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
) )
done_file.touch()
# torch.set_num_threads(1) # torch.set_num_threads(1)

View File

@ -0,0 +1,165 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script shows how to get word starting time
from framewise token alignment.
Usage:
./transducer_stateless/compute_ali.py \
--exp-dir ./transducer_stateless/exp \
--bpe-model ./data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--max-duration 300 \
--dataset train-clean-100 \
--out-dir data/ali
And the you can run:
./transducer_stateless/test_compute_ali.py \
--bpe-model ./data/lang_bpe_500/bpe.model \
--ali-dir data/ali \
--dataset train-clean-100
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from alignment import get_word_begin_frame
from lhotse import CutSet, load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from lhotse.dataset.collation import collate_custom_field
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--ali-dir",
type=Path,
default="./data/ali",
help="It specifies the directory where alignments can be found.",
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="""The name of the dataset:
Possible values are:
- test-clean.
- test-other
- train-clean-100
- train-clean-360
- train-other-500
- dev-clean
- dev-other
""",
)
return parser
def main():
args = get_parser().parse_args()
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
cuts_json = args.ali_dir / f"cuts_{args.dataset}.json.gz"
logging.info(f"Loading {cuts_json}")
cuts = load_manifest(cuts_json)
sampler = SingleCutSampler(
cuts,
max_duration=30,
shuffle=False,
)
dataset = K2SpeechRecognitionDataset(return_cuts=True)
dl = torch.utils.data.DataLoader(
dataset,
sampler=sampler,
batch_size=None,
num_workers=1,
persistent_workers=False,
)
frame_shift = 10 # ms
subsampling_factor = 4
frame_shift_in_second = frame_shift * subsampling_factor / 1000.0
# key: cut.id
# value: a list of pairs (word, time_in_second)
word_begin_time_dict = {}
for batch in dl:
supervisions = batch["supervisions"]
cuts = supervisions["cut"]
token_alignment, token_alignment_length = collate_custom_field(
CutSet.from_cuts(cuts), "token_alignment"
)
for i in range(len(cuts)):
assert (
(cuts[i].features.num_frames - 1) // 2 - 1
) // 2 == token_alignment_length[i]
word_begin_frame = get_word_begin_frame(
token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
)
word_begin_time = [
"{:.2f}".format(i * frame_shift_in_second)
for i in word_begin_frame
]
words = supervisions["text"][i].split()
assert len(word_begin_frame) == len(words)
word_begin_time_dict[cuts[i].id] = list(zip(words, word_begin_time))
# This is a demo script and we exit here after processing
# one batch.
# You can find word starting time in the dict "word_begin_time_dict"
for cut_id, word_time in word_begin_time_dict.items():
print(f"{cut_id}\n{word_time}\n")
break
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()