diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py index a1101afe4..c1cd6e3b1 100644 --- a/egs/librispeech/ASR/transducer_stateless/alignment.py +++ b/egs/librispeech/ASR/transducer_stateless/alignment.py @@ -135,7 +135,7 @@ def force_alignment( Caution: We assume that the maximum number of sybmols per frame is 1. - That is, the model should be training using a nonzero value + That is, the model should be trained using a nonzero value for the option `--modified-transducer-prob` in train.py. Args: @@ -163,6 +163,7 @@ def force_alignment( T = encoder_out.size(1) U = len(ys) + assert 0 < U <= T encoder_out_len = torch.tensor([1]) decoder_out_len = encoder_out_len @@ -204,7 +205,7 @@ def force_alignment( for i, item in enumerate(A): if (T - 1 - t) >= (U - item.pos_u): - # horizontal transition + # horizontal transition (left -> right) new_item = AlignItem( log_prob=item.log_prob + log_probs[i][blank_id], ys=item.ys + [blank_id], @@ -213,7 +214,7 @@ def force_alignment( B.append(new_item) if item.pos_u < U: - # diagonal transition + # diagonal transition (lower left -> upper right) u = ys[item.pos_u] new_item = AlignItem( log_prob=item.log_prob + log_probs[i][u], @@ -221,13 +222,14 @@ def force_alignment( pos_u=item.pos_u + 1, ) B.append(new_item) + if len(B) > beam_size: B = B.topk(beam_size) ans = B.topk(1)[0].ys assert len(ans) == T - assert list(filter(lambda i: i != 0, ans)) == ys + assert list(filter(lambda i: i != blank_id, ans)) == ys return ans @@ -235,7 +237,7 @@ def force_alignment( def get_word_starting_frame( ali: List[int], sp: spm.SentencePieceProcessor ) -> List[int]: - """Get the starting frame of each word from the given alignments. + """Get the starting frame of each word from the given token alignments. When a word is encoded into BPE tokens, the first token starts with underscore "_", which can be used to identify the starting frame diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py index eaf102e92..48769e9d1 100755 --- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py +++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py @@ -85,7 +85,7 @@ def get_parser(): type=str, required=True, help="""Output directory. - It contains 3 generated files: + It contains 2 generated files: - token_ali_xxx.h5 - cuts_xxx.json.gz @@ -322,8 +322,5 @@ def main(): done_file.touch() -# torch.set_num_threads(1) -# torch.set_num_interop_threads(1) - if __name__ == "__main__": main()