#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
|
|
This script shows how to get word starting time
|
|
from framewise token alignment.
|
|
|
|
Usage:
|
|
./transducer_stateless/compute_ali.py \
|
|
--exp-dir ./transducer_stateless/exp \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--epoch 20 \
|
|
--avg 10 \
|
|
--max-duration 300 \
|
|
--dataset train-clean-100 \
|
|
--out-dir data/ali
|
|
|
|
And the you can run:
|
|
|
|
./transducer_stateless/test_compute_ali.py \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--ali-dir data/ali \
|
|
--dataset train-clean-100
|
|
"""
|
|
import argparse
import logging
from pathlib import Path

import sentencepiece as spm
import torch
from alignment import get_word_starting_frames
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
from lhotse.dataset.collation import collate_custom_field


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--ali-dir",
        type=Path,
        default="./data/ali",
        help="The directory where alignments can be found.",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="""The name of the dataset. Possible values are:
        - test-clean
        - test-other
        - train-clean-100
        - train-clean-360
        - train-other-500
        - dev-clean
        - dev-other
        """,
    )

    return parser


def main():
    args = get_parser().parse_args()

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    cuts_jsonl = args.ali_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"

    logging.info(f"Loading {cuts_jsonl}")
    cuts = load_manifest_lazy(cuts_jsonl)

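    # Bucket cuts of similar duration together so that batches contain
    # little padding. max_duration caps the total audio (in seconds)
    # per batch.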
    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=30,
        num_buckets=30,
        shuffle=False,
    )

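    # return_cuts=True makes the dataset put the original Cut objects into
    # batch["supervisions"]["cut"]; we need them below to read the feature
    # frame counts and the "token_alignment" custom field. batch_size=None
    # because the sampler already yields whole batches.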
    dataset = K2SpeechRecognitionDataset(return_cuts=True)

    dl = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=None,
        num_workers=1,
        persistent_workers=False,
    )

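    # Each acoustic frame has a shift of 10 ms and the encoder subsamples
    # by a factor of 4, so consecutive entries in the token alignment are
    # 10 * 4 = 40 ms (0.04 s) apart; e.g., alignment frame 7 starts at
    # 7 * 0.04 = 0.28 s.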
    frame_shift = 10  # ms
    subsampling_factor = 4

    frame_shift_in_second = frame_shift * subsampling_factor / 1000.0

    # key: cut.id
    # value: a list of pairs (word, time_in_second)
    word_starting_time_dict = {}
    for batch in dl:
        supervisions = batch["supervisions"]
        cuts = supervisions["cut"]

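        # collate_custom_field pads the per-frame "token_alignment" arrays
        # (written to the cuts by compute_ali.py) to a common length and
        # returns them as a batch tensor, together with the valid length
        # of each row.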
        token_alignment, token_alignment_length = collate_custom_field(
            CutSet.from_cuts(cuts), "token_alignment"
        )

        for i in range(len(cuts)):
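            # The encoder turns T input frames into ((T - 1) // 2 - 1) // 2
            # output frames (two stride-2 reductions, i.e., roughly 4x
            # subsampling), so the stored alignment must have exactly that
            # many frames.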
            assert (
                (cuts[i].features.num_frames - 1) // 2 - 1
            ) // 2 == token_alignment_length[i]

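            # get_word_starting_frames (from the local alignment module)
            # returns the frame indices at which a new word begins. Word
            # boundaries are detectable because SentencePiece marks
            # word-initial pieces with "▁"; e.g. (illustrative), an alignment
            # decoding to [▁HE, LLO, ▁WORLD] yields the frames of ▁HE and
            # ▁WORLD.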
            word_starting_frames = get_word_starting_frames(
                token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
            )
            word_starting_time = [
                "{:.2f}".format(frame * frame_shift_in_second)
                for frame in word_starting_frames
            ]

            words = supervisions["text"][i].split()

            assert len(word_starting_frames) == len(words)
            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))

        # This is a demo script, so we exit after processing one batch.
        # The word starting times are in "word_starting_time_dict".
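        # Each entry maps a cut ID to a list of (word, start_time) pairs,
        # e.g. (illustrative): [("HE", "0.00"), ("HOPED", "0.36"), ...].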
        for cut_id, word_time in word_starting_time_dict.items():
            print(f"{cut_id}\n{word_time}\n")
        break


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()