diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py index dd0665326..eaf102e92 100755 --- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py +++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py @@ -32,7 +32,6 @@ import logging from pathlib import Path from typing import List -import k2 import numpy as np import sentencepiece as spm import torch @@ -88,19 +87,14 @@ def get_parser(): help="""Output directory. It contains 3 generated files: - - labels_xxx.h5 - - aux_labels_xxx.h5 + - token_ali_xxx.h5 - cuts_xxx.json.gz where xxx is the value of `--dataset`. For instance, if - `--dataset` is `train-clean-100`, it will contain 3 files: + `--dataset` is `train-clean-100`, it will contain 2 files: - - `labels_train-clean-100.h5` - - `aux_labels_train-clean-100.h5` + - `token_ali_train-clean-100.h5` - `cuts_train-clean-100.json.gz` - - Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise - alignment. The difference is that labels_xxx.h5 contains repeats. 
""", ) @@ -179,7 +173,6 @@ def compute_alignments( ys_list: List[List[int]] = sp.encode(texts, out_type=int) ali_list = [] - word_begin_time_list = [] for i in range(batch_size): # fmt: off encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] @@ -208,7 +201,7 @@ def compute_alignments( num_cuts += len(cut_list) - if batch_idx % 100 == 0: + if batch_idx % 2 == 0: batch_str = f"{batch_idx}/{num_batches}" logging.info( @@ -255,13 +248,10 @@ def main(): out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5" out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" - for f in ( - out_ali_filename, - out_manifest_filename, - ): - if f.exists(): - logging.info(f"{f} exists - skipping") - return + done_file = out_dir / f".{params.dataset}.done" + if done_file.is_file(): + logging.info(f"{done_file} exists - skipping") + exit() logging.info("About to create model") model = get_transducer_model(params) @@ -329,6 +319,7 @@ def main(): f"saved to {out_ali_filename} and the cut manifest " f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" ) + done_file.touch() # torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py new file mode 100755 index 000000000..ffb270ae7 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script shows how to get word starting time +from framewise token alignment. + +Usage: + ./transducer_stateless/compute_ali.py \ + --exp-dir ./transducer_stateless/exp \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --max-duration 300 \ + --dataset train-clean-100 \ + --out-dir data/ali + +And then you can run: + + ./transducer_stateless/test_compute_ali.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --ali-dir data/ali \ + --dataset train-clean-100 +""" +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from alignment import get_word_begin_frame +from lhotse import CutSet, load_manifest +from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler +from lhotse.dataset.collation import collate_custom_field + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--ali-dir", + type=Path, + default="./data/ali", + help="It specifies the directory where alignments can be found.", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset: + Possible values are: + - test-clean.
+ - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + cuts_json = args.ali_dir / f"cuts_{args.dataset}.json.gz" + + logging.info(f"Loading {cuts_json}") + cuts = load_manifest(cuts_json) + + sampler = SingleCutSampler( + cuts, + max_duration=30, + shuffle=False, + ) + + dataset = K2SpeechRecognitionDataset(return_cuts=True) + + dl = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=1, + persistent_workers=False, + ) + + frame_shift = 10 # ms + subsampling_factor = 4 + + frame_shift_in_second = frame_shift * subsampling_factor / 1000.0 + + # key: cut.id + # value: a list of pairs (word, time_in_second) + word_begin_time_dict = {} + for batch in dl: + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + token_alignment, token_alignment_length = collate_custom_field( + CutSet.from_cuts(cuts), "token_alignment" + ) + + for i in range(len(cuts)): + assert ( + (cuts[i].features.num_frames - 1) // 2 - 1 + ) // 2 == token_alignment_length[i] + + word_begin_frame = get_word_begin_frame( + token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp + ) + word_begin_time = [ + "{:.2f}".format(i * frame_shift_in_second) + for i in word_begin_frame + ] + + words = supervisions["text"][i].split() + + assert len(word_begin_frame) == len(words) + word_begin_time_dict[cuts[i].id] = list(zip(words, word_begin_time)) + + # This is a demo script and we exit here after processing + # one batch. 
+ # You can find word starting time in the dict "word_begin_time_dict" + for cut_id, word_time in word_begin_time_dict.items(): + print(f"{cut_id}\n{word_time}\n") + break + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main()