diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 3b2678ec4..1bbf7bbcf 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -60,8 +60,11 @@ log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "Stage -1: Download LM" - [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm - ./local/download_lm.py --out-dir=$dl_dir/lm + mkdir -p $dl_dir/lm + if [ ! -e $dl_dir/lm/.done ]; then + ./local/download_lm.py --out-dir=$dl_dir/lm + touch $dl_dir/lm/.done + fi fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then @@ -91,7 +94,10 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # We assume that you have downloaded the LibriSpeech corpus # to $dl_dir/LibriSpeech mkdir -p data/manifests - lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests + if [ ! -e data/manifests/.librispeech.done ]; then + lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests + touch data/manifests/.librispeech.done + fi fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -99,19 +105,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # We assume that you have downloaded the musan corpus # to data/musan mkdir -p data/manifests - lhotse prepare musan $dl_dir/musan data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for librispeech" mkdir -p data/fbank - ./local/compute_fbank_librispeech.py + if [ ! -e data/fbank/.librispeech.done ]; then + ./local/compute_fbank_librispeech.py + touch data/fbank/.librispeech.done + fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute fbank for musan" mkdir -p data/fbank - ./local/compute_fbank_musan.py + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e075a2d03..2af2f5e8a 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -180,14 +180,14 @@ class LibriSpeechAsrDataModule: ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: - logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) transforms.append( CutMix( cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True diff --git a/egs/librispeech/ASR/transducer_stateless/README.md b/egs/librispeech/ASR/transducer_stateless/README.md index 964bddfab..978fa2ada 100644 --- a/egs/librispeech/ASR/transducer_stateless/README.md +++ b/egs/librispeech/ASR/transducer_stateless/README.md @@ -20,3 +20,120 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --max-duration 250 \ --lr-factor 2.5 ``` + +## How to get framewise token alignment + +Assume that you already have a trained model. If not, you can either +train one by yourself or download a pre-trained model from hugging face: + + +**Caution**: If you are going to use your own trained model, remember +to set `--modified-transducer-prob` to a nonzero value since the +force alignment code assumes that `--max-sym-per-frame` is 1. + + +The following shows how to get framewise token alignment using the above +pre-trained model. + +```bash +git clone https://github.com/k2-fsa/icefall +cd icefall/egs/librispeech/ASR +mkdir tmp +sudo apt-get install git-lfs +git lfs install +git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 ./tmp/ + +ln -s $PWD/tmp/exp/pretrained.pt $PWD/tmp/epoch-999.pt + +./transducer_stateless/compute_ali.py \ + --exp-dir ./tmp/exp \ + --bpe-model ./tmp/data/lang_bpe_500/bpe.model \ + --epoch 999 \ + --avg 1 \ + --max-duration 100 \ + --dataset dev-clean \ + --out-dir data/ali +``` + +After running the above commands, you will find the following two files +in the folder `./data/ali`: + +``` +-rw-r--r-- 1 xxx xxx 412K Mar 7 15:45 cuts_dev-clean.json.gz +-rw-r--r-- 1 xxx xxx 2.9M Mar 7 15:45 token_ali_dev-clean.h5 +``` + +You can find usage examples in `./test_compute_ali.py` about +extracting framewise token alignment information from the above +two files. + +## How to get word starting time from framewise token alignment + +Assume you have run the above commands to get framewise token alignment +using a pre-trained model from `tmp/exp/epoch-999.pt`. You can use the following +commands to obtain word starting time. + +```bash +./transducer_stateless/test_compute_ali.py \ + --bpe-model ./tmp/data/lang_bpe_500/bpe.model \ + --ali-dir data/ali \ + --dataset dev-clean +``` + +**Caution**: Since the frame shift is 10ms and the subsampling factor +of the model is 4, the time resolution is 0.04 second. + +**Note**: The script `test_compute_ali.py` is for illustration only +and it processes only one batch and then exits. + +You will get the following output: + +``` +5694-64029-0022-1998-0 +[('THE', '0.20'), ('LEADEN', '0.36'), ('HAIL', '0.72'), ('STORM', '1.00'), ('SWEPT', '1.48'), ('THEM', '1.88'), ('OFF', '2.00'), ('THE', '2.24'), ('FIELD', '2.36'), ('THEY', '3.20'), ('FELL', '3.36'), ('BACK', '3.64'), ('AND', '3.92'), ('RE', '4.04'), ('FORMED', '4.20')] + +3081-166546-0040-308-0 +[('IN', '0.32'), ('OLDEN', '0.60'), ('DAYS', '1.00'), ('THEY', '1.40'), ('WOULD', '1.56'), ('HAVE', '1.76'), ('SAID', '1.92'), ('STRUCK', '2.60'), ('BY', '3.16'), ('A', '3.36'), ('BOLT', '3.44'), ('FROM', '3.84'), ('HEAVEN', '4.04')] + +2035-147960-0016-1283-0 +[('A', '0.44'), ('SNAKE', '0.52'), ('OF', '0.84'), ('HIS', '0.96'), ('SIZE', '1.12'), ('IN', '1.60'), ('FIGHTING', '1.72'), ('TRIM', '2.12'), ('WOULD', '2.56'), ('BE', '2.76'), ('MORE', '2.88'), ('THAN', '3.08'), ('ANY', '3.28'), ('BOY', '3.56'), ('COULD', '3.88'), ('HANDLE', '4.04')] + +2428-83699-0020-1734-0 +[('WHEN', '0.28'), ('THE', '0.48'), ('TRAP', '0.60'), ('DID', '0.88'), ('APPEAR', '1.08'), ('IT', '1.80'), ('LOOKED', '1.96'), ('TO', +'2.24'), ('ME', '2.36'), ('UNCOMMONLY', '2.52'), ('LIKE', '3.16'), ('AN', '3.40'), ('OPEN', '3.56'), ('SPRING', '3.92'), ('CART', '4.28')] + +8297-275154-0026-2108-0 +[('LET', '0.44'), ('ME', '0.72'), ('REST', '0.92'), ('A', '1.32'), ('LITTLE', '1.40'), ('HE', '1.80'), ('PLEADED', '2.00'), ('IF', '3.04'), ("I'M", '3.28'), ('NOT', '3.52'), ('IN', '3.76'), ('THE', '3.88'), ('WAY', '4.00')] + +652-129742-0007-1002-0 +[('SURROUND', '0.28'), ('WITH', '0.80'), ('A', '0.92'), ('GARNISH', '1.00'), ('OF', '1.44'), ('COOKED', '1.56'), ('AND', '1.88'), ('DICED', '4.16'), ('CARROTS', '4.28'), ('TURNIPS', '4.44'), ('GREEN', '4.60'), ('PEAS', '4.72')] +``` + + +For the row: +``` +5694-64029-0022-1998-0 +[('THE', '0.20'), ('LEADEN', '0.36'), ('HAIL', '0.72'), ('STORM', '1.00'), ('SWEPT', '1.48'), +('THEM', '1.88'), ('OFF', '2.00'), ('THE', '2.24'), ('FIELD', '2.36'), ('THEY', '3.20'), ('FELL', '3.36'), +('BACK', '3.64'), ('AND', '3.92'), ('RE', '4.04'), ('FORMED', '4.20')] +``` + +- `5694-64029-0022-1998-0` is the cut ID. +- `('THE', '0.20')` means the word `THE` starts at 0.20 second. +- `('LEADEN', '0.36')` means the word `LEADEN` starts at 0.36 second. + + +You can compare the above word starting time with the one +from + +``` +5694-64029-0022 ",THE,LEADEN,HAIL,STORM,SWEPT,THEM,OFF,THE,FIELD,,THEY,FELL,BACK,AND,RE,FORMED," "0.230,0.360,0.670,1.010,1.440,1.860,1.990,2.230,2.350,2.870,3.230,3.390,3.660,3.960,4.060,4.160,4.850,4.9" +``` + +We reformat it below for readability: + +``` +5694-64029-0022 ",THE,LEADEN,HAIL,STORM,SWEPT,THEM,OFF,THE,FIELD,,THEY,FELL,BACK,AND,RE,FORMED," +"0.230,0.360,0.670,1.010,1.440,1.860,1.990,2.230,2.350,2.870,3.230,3.390,3.660,3.960,4.060,4.160,4.850,4.9" + the leaden hail storm swept them off the field sil they fell back and re formed sil +``` diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py new file mode 100644 index 000000000..f143611ea --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/alignment.py @@ -0,0 +1,268 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Iterator, List, Optional + +import sentencepiece as spm +import torch +from model import Transducer + +# The force alignment problem can be formulated as finding +# a path in a rectangular lattice, where the path starts +# from the lower left corner and ends at the upper right +# corner. The horizontal axis of the lattice is `t` (representing +# acoustic frame indexes) and the vertical axis is `u` (representing +# BPE tokens of the transcript). +# +# The notations `t` and `u` are from the paper +# https://arxiv.org/pdf/1211.3711.pdf +# +# Beam search is used to find the path with the +# highest log probabilities. +# +# It assumes the maximum number of symbols that can be +# emitted per frame is 1. You can use `--modified-transducer-prob` +# from `./train.py` to train a model that satisfies this assumption. + + +# AlignItem is the ending node of a path originated from the starting node. +# len(ys) equals to `t` and pos_u is the u coordinate +# in the lattice. +@dataclass +class AlignItem: + # total log prob of the path that ends at this item. + # The path is originated from the starting node. + log_prob: float + + # It contains framewise token alignment + ys: List[int] + + # It equals to the number of non-zero entries in ys + pos_u: int + + +class AlignItemList: + def __init__(self, items: Optional[List[AlignItem]] = None): + """ + Args: + items: + A list of AlignItem + """ + if items is None: + items = [] + self.data = items + + def __iter__(self) -> Iterator: + return iter(self.data) + + def __len__(self) -> int: + """Return the number of AlignItem in this object.""" + return len(self.data) + + def __getitem__(self, i: int) -> AlignItem: + """Return the i-th item in this object.""" + return self.data[i] + + def append(self, item: AlignItem) -> None: + """Append an item to the end of this object.""" + self.data.append(item) + + def get_decoder_input( + self, + ys: List[int], + context_size: int, + blank_id: int, + ) -> List[List[int]]: + """Get input for the decoder for each item in this object. + + Args: + ys: + The transcript of the utterance in BPE tokens. + context_size: + Context size of the NN decoder model. + blank_id: + The ID of the blank symbol. + Returns: + Return a list-of-list int. `ans[i]` contains the decoder + input for the i-th item in this object and its lengths + is `context_size`. + """ + ans: List[List[int]] = [] + buf = [blank_id] * context_size + ys + for item in self: + # fmt: off + ans.append(buf[item.pos_u:(item.pos_u + context_size)]) + # fmt: on + return ans + + def topk(self, k: int) -> "AlignItemList": + """Return the top-k items. + + Items are ordered by their log probs in descending order + and the top-k items are returned. + + Args: + k: + Size of top-k. + Returns: + Return a new AlignItemList that contains the top-k items + in this object. Caution: It uses shallow copy. + """ + items = list(self) + items = sorted(items, key=lambda i: i.log_prob, reverse=True) + return AlignItemList(items[:k]) + + +def force_alignment( + model: Transducer, + encoder_out: torch.Tensor, + ys: List[int], + beam_size: int = 4, +) -> List[int]: + """Compute the force alignment of an utterance given its transcript + in BPE tokens and the corresponding acoustic output from the encoder. + + Caution: + We assume that the maximum number of sybmols per frame is 1. + That is, the model should be trained using a nonzero value + for the option `--modified-transducer-prob` in train.py. + + Args: + model: + The transducer model. + encoder_out: + A tensor of shape (N, T, C). Support only for N==1 at present. + ys: + A list of BPE token IDs. We require that len(ys) <= T. + beam_size: + Size of the beam used in beam search. + Returns: + Return a list of int such that + - len(ans) == T + - After removing blanks from ans, we have ans == ys. + """ + assert encoder_out.ndim == 3, encoder_out.ndim + assert encoder_out.size(0) == 1, encoder_out.size(0) + assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1)) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + U = len(ys) + assert 0 < U <= T + + encoder_out_len = torch.tensor([1]) + decoder_out_len = encoder_out_len + + start = AlignItem(log_prob=0.0, ys=[], pos_u=0) + B = AlignItemList([start]) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # current_encoder_out is of shape (1, 1, encoder_out_dim) + # fmt: on + + A = B # shallow copy + B = AlignItemList() + + decoder_input = A.get_decoder_input( + ys=ys, context_size=context_size, blank_id=blank_id + ) + decoder_input = torch.tensor(decoder_input, device=device) + # decoder_input is of shape (num_active_items, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_active_items, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + + # logits is of shape (num_active_items, vocab_size) + log_probs = logits.log_softmax(dim=-1).tolist() + + for i, item in enumerate(A): + if (T - 1 - t) >= (U - item.pos_u): + # horizontal transition (left -> right) + new_item = AlignItem( + log_prob=item.log_prob + log_probs[i][blank_id], + ys=item.ys + [blank_id], + pos_u=item.pos_u, + ) + B.append(new_item) + + if item.pos_u < U: + # diagonal transition (lower left -> upper right) + u = ys[item.pos_u] + new_item = AlignItem( + log_prob=item.log_prob + log_probs[i][u], + ys=item.ys + [u], + pos_u=item.pos_u + 1, + ) + B.append(new_item) + + if len(B) > beam_size: + B = B.topk(beam_size) + + ans = B.topk(1)[0].ys + + assert len(ans) == T + assert list(filter(lambda i: i != blank_id, ans)) == ys + + return ans + + +def get_word_starting_frames( + ali: List[int], sp: spm.SentencePieceProcessor +) -> List[int]: + """Get the starting frame of each word from the given token alignments. + + When a word is encoded into BPE tokens, the first token starts + with underscore "_", which can be used to identify the starting frame + of a word. + + Args: + ali: + Framewise token alignment. It can be the return value of + :func:`force_alignment`. + sp: + The sentencepiece model. + Returns: + Return a list of int representing the starting frame of each word + in the alignment. + Caution: + You have to take into account the model subsampling factor when + converting the starting frame into time. + """ + underscore = b"\xe2\x96\x81".decode() # '_' + ans = [] + for i in range(len(ali)): + if sp.id_to_piece(ali[i]).startswith(underscore): + ans.append(i) + return ans diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py new file mode 100755 index 000000000..48769e9d1 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./transducer_stateless/compute_ali.py \ + --exp-dir ./transducer_stateless/exp \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --max-duration 300 \ + --dataset train-clean-100 \ + --out-dir data/ali +""" + +import argparse +import logging +from pathlib import Path +from typing import List + +import numpy as np +import sentencepiece as spm +import torch +from alignment import force_alignment +from asr_datamodule import LibriSpeechAsrDataModule +from lhotse import CutSet +from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--out-dir", + type=str, + required=True, + help="""Output directory. + It contains 2 generated files: + + - token_ali_xxx.h5 + - cuts_xxx.json.gz + + where xxx is the value of `--dataset`. For instance, if + `--dataset` is `train-clean-100`, it will contain 2 files: + + - `token_ali_train-clean-100.h5` + - `cuts_train-clean-100.json.gz` + """, + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean. + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + return parser + + +def compute_alignments( + model: torch.nn.Module, + dl: torch.utils.data, + ali_writer: FeaturesWriter, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +): + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + num_cuts = 0 + + device = model.device + cuts = [] + + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + + cut_list = supervisions["cut"] + for cut in cut_list: + assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" + + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + + batch_size = encoder_out.size(0) + + texts = supervisions["text"] + + ys_list: List[List[int]] = sp.encode(texts, out_type=int) + + ali_list = [] + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + + ali = force_alignment( + model=model, + encoder_out=encoder_out_i, + ys=ys_list[i], + beam_size=params.beam_size, + ) + ali_list.append(ali) + assert len(ali_list) == len(cut_list) + + for cut, ali in zip(cut_list, ali_list): + cut.token_alignment = ali_writer.store_array( + key=cut.id, + value=np.asarray(ali, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + + cuts += cut_list + + num_cuts += len(cut_list) + + if batch_idx % 2 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + return CutSet.from_cuts(cuts) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + args.enable_spec_aug = False + args.enable_musan = False + args.return_cuts = True + args.concatenate_cuts = False + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ali") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"Computing alignments for {params.dataset} - started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + out_dir = Path(params.out_dir) + out_dir.mkdir(exist_ok=True) + + out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5" + out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" + + done_file = out_dir / f".{params.dataset}.done" + if done_file.is_file(): + logging.info(f"{done_file} exists - skipping") + exit() + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to(device) + model.eval() + model.device = device + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + logging.info(f"Processing {params.dataset}") + + with NumpyHdf5Writer(out_ali_filename) as ali_writer: + cut_set = compute_alignments( + model=model, + dl=dl, + ali_writer=ali_writer, + params=params, + sp=sp, + ) + + cut_set.to_file(out_manifest_filename) + + logging.info( + f"For dataset {params.dataset}, its framewise token alignments are " + f"saved to {out_ali_filename} and the cut manifest " + f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" + ) + done_file.touch() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py new file mode 100755 index 000000000..99d5b3788 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script shows how to get word starting time +from framewise token alignment. + +Usage: + ./transducer_stateless/compute_ali.py \ + --exp-dir ./transducer_stateless/exp \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --max-duration 300 \ + --dataset train-clean-100 \ + --out-dir data/ali + +And the you can run: + + ./transducer_stateless/test_compute_ali.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --ali-dir data/ali \ + --dataset train-clean-100 +""" +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from alignment import get_word_starting_frames +from lhotse import CutSet, load_manifest +from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler +from lhotse.dataset.collation import collate_custom_field + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--ali-dir", + type=Path, + default="./data/ali", + help="It specifies the directory where alignments can be found.", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset: + Possible values are: + - test-clean. + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + cuts_json = args.ali_dir / f"cuts_{args.dataset}.json.gz" + + logging.info(f"Loading {cuts_json}") + cuts = load_manifest(cuts_json) + + sampler = SingleCutSampler( + cuts, + max_duration=30, + shuffle=False, + ) + + dataset = K2SpeechRecognitionDataset(return_cuts=True) + + dl = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=1, + persistent_workers=False, + ) + + frame_shift = 10 # ms + subsampling_factor = 4 + + frame_shift_in_second = frame_shift * subsampling_factor / 1000.0 + + # key: cut.id + # value: a list of pairs (word, time_in_second) + word_starting_time_dict = {} + for batch in dl: + supervisions = batch["supervisions"] + cuts = supervisions["cut"] + + token_alignment, token_alignment_length = collate_custom_field( + CutSet.from_cuts(cuts), "token_alignment" + ) + + for i in range(len(cuts)): + assert ( + (cuts[i].features.num_frames - 1) // 2 - 1 + ) // 2 == token_alignment_length[i] + + word_starting_frames = get_word_starting_frames( + token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp + ) + word_starting_time = [ + "{:.2f}".format(i * frame_shift_in_second) + for i in word_starting_frames + ] + + words = supervisions["text"][i].split() + + assert len(word_starting_frames) == len(words) + word_starting_time_dict[cuts[i].id] = list( + zip(words, word_starting_time) + ) + + # This is a demo script and we exit here after processing + # one batch. + # You can find word starting time in the dict "word_starting_time_dict" + for cut_id, word_time in word_starting_time_dict.items(): + print(f"{cut_id}\n{word_time}\n") + break + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main()