From d50e7734a6ec273695f0ac7d88fd6b64f46c2ae2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Mar 2022 12:57:11 +0800 Subject: [PATCH] Add more documentation. --- .../ASR/transducer_stateless/alignment.py | 129 ++++++++++++++---- .../ASR/transducer_stateless/compute_ali.py | 68 +++++---- 2 files changed, 144 insertions(+), 53 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py index 60c737163..492a6fc51 100644 --- a/egs/librispeech/ASR/transducer_stateless/alignment.py +++ b/egs/librispeech/ASR/transducer_stateless/alignment.py @@ -16,63 +16,88 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Iterator, List, Optional +import sentencepiece as spm import torch from model import Transducer -# TODO(fangjun): Add more documentation - -# The force alignment problem can be formulated as find +# The force alignment problem can be formulated as finding # a path in a rectangular lattice, where the path starts # from the lower left corner and ends at the upper right -# corner. The horizontal axis of the lattice is `t` -# and the vertical axis is `u`. +# corner. The horizontal axis of the lattice is `t` (representing +# acoustic frame indexes) and the vertical axis is `u` (representing +# BPE tokens of the transcript). # +# Beam search is used to find the path with the +# highest log probability. +# +# It assumes that the maximum number of symbols that can be +# emitted per frame is 1. You can use `--modified-transducer-prob` +# from train.py to train a model that satisfies this assumption. + + # AlignItem is a node in the lattice, where its # len(ys) equals to `t` and pos_u is the u coordinate # in the lattice. 
@dataclass class AlignItem: + # log prob of this item log_prob: float + + # It contains framewise token alignment ys: List[int] + + # It equals the number of non-zero entries in ys pos_u: int class AlignItemList: def __init__(self, items: Optional[List[AlignItem]] = None): + """ + Args: + items: + A list of AlignItem + """ if items is None: items = [] self.data = items - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.data) - def __len__(self): + def __len__(self) -> int: + """Return the number of AlignItem in this object.""" return len(self.data) def __getitem__(self, i: int) -> AlignItem: + """Return the i-th item in this object.""" return self.data[i] def append(self, item: AlignItem) -> None: + """Append an item to the end of this object.""" self.data.append(item) - def get_active_items(self, T: int, U: int) -> "AlignItemList": - ans = [] - for item in self: - t = len(item.ys) - if U - item.pos_u > T - t: - continue - ans.append(item) - - return AlignItemList(ans) - def get_decoder_input( self, ys: List[int], context_size: int, blank_id: int, ) -> List[List[int]]: + """Get input for the decoder for each item in this object. + + Args: + ys: + The transcript of the utterance in BPE tokens. + context_size: + Context size of the NN decoder model. + blank_id: + The ID of the blank symbol. + Returns: + Return a list of lists of ints. `ans[i]` contains the decoder + input for the i-th item in this object and its length + is `context_size`. + """ ans: List[List[int]] = [] buf = [blank_id] * context_size + ys for item in self: @@ -82,6 +107,18 @@ class AlignItemList: return ans def topk(self, k: int) -> "AlignItemList": + """Return the top-k items. + + Items are ordered by their log probs in descending order + and the top-k items are returned. + + Args: + k: + Size of top-k. + Returns: + Return a new AlignItemList that contains the top-k items + in this object. Caution: It uses shallow copy. 
+ """ items = list(self) items = sorted(items, key=lambda i: i.log_prob, reverse=True) return AlignItemList(items[:k]) @@ -93,24 +130,28 @@ def force_alignment( ys: List[int], beam_size: int = 4, ) -> List[int]: - """ + """Compute the force alignment of an utterance given its transcript + in BPE tokens and the corresponding acoustic output from the encoder. + + Caution: + We assume that the maximum number of symbols per frame is 1. + That is, the model should be trained using a nonzero value + for the option `--modified-transducer-prob` in train.py. + Args: model: The transducer model. encoder_out: - A tensor of shape (N, T, C). Support only for N==1 now. + A tensor of shape (N, T, C). Support only for N==1 at present. ys: - A list of token IDs. We require that len(ys) <= T. - beam: + A list of BPE token IDs. We require that len(ys) <= T. + beam_size: Size of the beam used in beam search. Returns: Return a list of int such that - len(ans) == T - After removing blanks from ans, we have ans == ys. """ - import pdb - - pdb.set_trace() assert encoder_out.ndim == 3, encoder_out.ndim assert encoder_out.size(0) == 1, encoder_out.size(0) assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1)) @@ -135,7 +176,6 @@ def force_alignment( # current_encoder_out is of shape (1, 1, encoder_out_dim) # fmt: on - # A = B.get_active_items() A = B # shallow copy B = AlignItemList() @@ -184,4 +224,39 @@ def force_alignment( if len(B) > beam_size: B = B.topk(beam_size) - return B.topk(1)[0].ys + ans = B.topk(1)[0].ys + + assert len(ans) == T + assert list(filter(lambda i: i != 0, ans)) == ys + + return ans + + +def get_word_begin_frame( + ali: List[int], sp: spm.SentencePieceProcessor +) -> List[int]: + """Get the beginning of each word from the given alignments. + + When a word is encoded into BPE tokens, the first token starts + with underscore "_" (the character U+2581, not ASCII "_"), which can be used to identify the beginning + of a word. + + Args: + ali: + Framewise token alignment. 
It can be the return value of + :func:`force_alignment`. + sp: + The sentencepiece model. + Returns: + Return a list of int representing the starting frame of each word + in the alignment. + Caution: + You have to take into account the model subsampling factor when + converting the starting frame into time. + """ + underscore = b"\xe2\x96\x81".decode() # '_' + ans = [] + for i in range(len(ali)): + if sp.id_to_piece(ali[i]).startswith(underscore): + ans.append(i) + return ans diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py index 4813154b2..dd0665326 100755 --- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py +++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py @@ -137,20 +137,10 @@ def get_parser(): return parser -def get_word_begin_time(ali: List[int], sp: spm.SentencePieceProcessor): - underscore = b"\xe2\x96\x81".decode() # '_' - ans = [] - for i in range(len(ali)): - print(sp.id_to_piece(ali[i])) - if sp.id_to_piece(ali[i]).startswith(underscore): - print("yes") - ans.append(i * 0.04) - return ans - - def compute_alignments( model: torch.nn.Module, dl: torch.utils.data, + ali_writer: FeaturesWriter, params: AttributeDict, sp: spm.SentencePieceProcessor, ): @@ -202,10 +192,30 @@ def compute_alignments( beam_size=params.beam_size, ) ali_list.append(ali) - word_begin_time_list.append(get_word_begin_time(ali, sp)) - import pdb + assert len(ali_list) == len(cut_list) - pdb.set_trace() + for cut, ali in zip(cut_list, ali_list): + cut.token_alignment = ali_writer.store_array( + key=cut.id, + value=np.asarray(ali, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + + cuts += cut_list + + num_cuts += len(cut_list) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + return CutSet.from_cuts(cuts) 
@torch.no_grad() @@ -242,13 +252,11 @@ def main(): out_dir = Path(params.out_dir) out_dir.mkdir(exist_ok=True) - out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" - out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" + out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5" out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" for f in ( - out_labels_ali_filename, - out_aux_labels_ali_filename, + out_ali_filename, out_manifest_filename, ): if f.exists(): @@ -305,18 +313,26 @@ def main(): logging.info(f"Processing {params.dataset}") - cut_set = compute_alignments( - model=model, - dl=dl, - # labels_writer=labels_writer, - # aux_labels_writer=aux_labels_writer, - params=params, - sp=sp, + with NumpyHdf5Writer(out_ali_filename) as ali_writer: + cut_set = compute_alignments( + model=model, + dl=dl, + ali_writer=ali_writer, + params=params, + sp=sp, + ) + + cut_set.to_file(out_manifest_filename) + + logging.info( + f"For dataset {params.dataset}, its framewise token alignments are " + f"saved to {out_ali_filename} and the cut manifest " + f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" ) -# torch.set_num_interop_threads(1) # torch.set_num_threads(1) +# torch.set_num_interop_threads(1) if __name__ == "__main__": main()