From bcc5923ab92c93bf38829f7d5de84d84c9050eb1 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 28 Mar 2023 23:24:24 +0800 Subject: [PATCH] Support batch-wise forced-alignment (#970) * support batch-wise forced-alignment based on beam search * add length_norm to HypothesisList.topk() * Use Hypothesis and HypothesisList instead --- .../beam_search.py | 17 +- .../pruned_transducer_stateless7/alignment.py | 206 +++++++++++ .../compute_ali.py | 345 ++++++++++++++++++ .../test_compute_ali.py | 130 +++++++ icefall/utils.py | 2 +- 5 files changed, 696 insertions(+), 4 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index bd2d6e258..999d793a4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -829,11 +829,22 @@ class HypothesisList(object): ans.add(hyp) # shallow copy return ans - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + """ hyps = list(self._data.items()) - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] ans = HypothesisList(dict(hyps)) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py new file mode 100644 index 000000000..76cd56bbb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -0,0 +1,206 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +import k2 +import torch + +from beam_search import Hypothesis, HypothesisList, get_hyps_shape + +# The force alignment problem can be formulated as finding +# a path in a rectangular lattice, where the path starts +# from the lower left corner and ends at the upper right +# corner. The horizontal axis of the lattice is `t` (representing +# acoustic frame indexes) and the vertical axis is `u` (representing +# BPE tokens of the transcript). +# +# The notations `t` and `u` are from the paper +# https://arxiv.org/pdf/1211.3711.pdf +# +# Beam search is used to find the path with the highest log probabilities. 
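+#
+# As an illustration, suppose there are T = 4 acoustic frames and the
+# transcript has two tokens [A, B]. One valid path is:
+#
+#   t=0: emit A,  t=1: emit blank,  t=2: emit B,  t=3: emit blank
+#
+# Each frame emits either a blank or the next transcript token; when a token
+# is emitted, the current frame index is recorded as its timestamp, so the
+# path above yields the token-level frame indexes [0, 2]. A blank is only
+# considered when enough frames remain to emit the rest of the transcript,
+# which guarantees that every surviving path has emitted the full transcript
+# by the last frame.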
+# +# It assumes the maximum number of symbols that can be +# emitted per frame is 1. + + +def batch_force_alignment( + model: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_list: List[List[int]], + beam_size: int = 4, +) -> List[int]: + """Compute the force alignment of a batch of utterances given their transcripts + in BPE tokens and the corresponding acoustic output from the encoder. + + Caution: + This function is modified from `modified_beam_search` in beam_search.py. + We assume that the maximum number of sybmols per frame is 1. + + Args: + model: + The transducer model. + encoder_out: + A tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + ys_list: + A list of BPE token IDs list. We require that for each utterance i, + len(ys_list[i]) <= encoder_out_lens[i]. + beam_size: + Size of the beam used in beam search. + + Returns: + Return a list of frame indexes list for each utterance i, + where len(ans[i]) == len(ys_list[i]). + """ + assert encoder_out.ndim == 3, encoder_out.ndim + assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list)) + assert encoder_out.size(0) > 0, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + sorted_indices = packed_encoder_out.sorted_indices.tolist() + encoder_out_lens = encoder_out_lens.tolist() + ys_lens = [len(ys) for ys in ys_list] + sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices] + sorted_ys_lens = [ys_lens[i] for i in sorted_indices] + sorted_ys_list = [ys_list[i] for i in sorted_indices] + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size] + sorted_ys_lens = sorted_ys_lens[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor 
+ # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs.reshape(-1) + ) # [batch][num_hyps*vocab_size] + + for i in range(batch_size): + for h, hyp in enumerate(A[i]): + pos_u = len(hyp.timestamp) + idx_offset = h * vocab_size + if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u): + # emit blank token + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + blank_id], + ys=hyp.ys[:], + timestamp=hyp.timestamp[:], + ) + B[i].add(new_hyp) + if pos_u < sorted_ys_lens[i]: + # emit non-blank token + new_token = sorted_ys_list[i][pos_u] + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + new_token], + ys=hyp.ys + [new_token], + timestamp=hyp.timestamp + [t], + ) + B[i].add(new_hyp) + + if len(B[i]) > beam_size: + B[i] = B[i].topk(beam_size, length_norm=True) + + B = B + finalized_B + sorted_hyps = [b.get_most_probable() for b in B] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + hyps = [sorted_hyps[i] for i in unsorted_indices] + ans = [] + for i, hyp in enumerate(hyps): + assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i]) + ans.append(hyp.timestamp) + + return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py new file mode 100755 index 000000000..8bcb56d62 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The script gets forced-alignments based on the modified_beam_search decoding method. +Both token-level alignments and word-level alignments are saved to the new cuts manifests. + +It loads a checkpoint and uses it to get the forced-alignments. 
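+The resulting cuts are written to librispeech_cuts_{dataset}.jsonl.gz under
+--cuts-out-dir, and each supervision in them carries an `alignment` field
+with both "token" and "word" entries.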
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from alignment import batch_force_alignment +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp +from lhotse import CutSet +from lhotse.serialization import SequentialJsonlWriter +from lhotse.supervision import AlignmentItem + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--cuts-out-dir", + type=str, + default="data/fbank_ali_beam_search", + help="The dir to save the new cuts manifests with alignments", + ) + + add_model_arguments(parser) + + return parser + + +def align_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]: + """Get forced-alignments for one batch. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + token_list: + A list of token list. + word_list: + A list of word list. + token_time_list: + A list of timestamps list for tokens. + word_time_list. + A list of timestamps list for words. 
+ + where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list), + len(token_list[i]) == len(token_time_list[i]), + and len(word_list[i]) == len(word_time_list[i]) + + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + texts = supervisions["text"] + ys_list: List[List[int]] = sp.encode(texts, out_type=int) + + frame_indexes = batch_force_alignment( + model, encoder_out, encoder_out_lens, ys_list, params.beam_size + ) + + token_list = [] + word_list = [] + token_time_list = [] + word_time_list = [] + for i in range(encoder_out.size(0)): + tokens = sp.id_to_piece(ys_list[i]) + words = texts[i].split() + token_time = convert_timestamp( + frame_indexes[i], params.subsampling_factor, params.frame_shift_ms + ) + word_time = parse_timestamp(tokens, token_time) + assert len(word_time) == len(words), (len(word_time), len(words)) + + token_list.append(tokens) + word_list.append(words) + token_time_list.append(token_time) + word_time_list.append(word_time) + + return token_list, word_list, token_time_list, word_time_list + + +def align_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + writer: SequentialJsonlWriter, +) -> None: + """Get forced-alignments for the dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + writer: + Writer to save the cuts with alignments. + """ + log_interval = 20 + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
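+
+    # Each batch gives us, per utterance, the BPE tokens, the words, and
+    # their start times in seconds (see align_one_batch). The loop below
+    # attaches them to the single supervision of each cut as lhotse
+    # AlignmentItem objects and writes the cut to the output manifest.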
+    for batch_idx, batch in enumerate(dl):
+        token_list, word_list, token_time_list, word_time_list = align_one_batch(
+            params=params, model=model, sp=sp, batch=batch
+        )
+
+        cut_list = batch["supervisions"]["cut"]
+        for cut, token, word, token_time, word_time in zip(
+            cut_list, token_list, word_list, token_time_list, word_time_list
+        ):
+            assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
+            token_ali = [
+                AlignmentItem(
+                    symbol=token[i],
+                    start=round(token_time[i], ndigits=3),
+                    duration=None,
+                )
+                for i in range(len(token))
+            ]
+            word_ali = [
+                AlignmentItem(
+                    symbol=word[i], start=round(word_time[i], ndigits=3), duration=None
+                )
+                for i in range(len(word))
+            ]
+            cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali}
+            writer.write(cut, flush=True)
+
+        num_cuts += len(cut_list)
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + cuts_out_dir = Path(params.cuts_out_dir) + cuts_out_dir.mkdir(parents=True, exist_ok=True) + cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + + with CutSet.open_writer(cuts_out_path) as writer: + align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer) + + logging.info( + f"For dataset {params.dataset}, the cut manifest with framewise token alignments " + f"and word alignments are saved to {cuts_out_path}" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py new file mode 100755 index 000000000..081f7ba1a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script compares the word-level alignments generated based on modified_beam_search decoding +(in ./pruned_transducer_stateless7/compute_ali.py) to the reference alignments generated +by torchaudio framework (in ./add_alignments.sh). 
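+For every word it computes the absolute difference between the two start times
+and reports the mean and standard deviation of these differences over the
+whole dataset.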
+ +Usage: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search + +And the you can run: + +./pruned_transducer_stateless7/test_compute_ali.py \ + --cuts-out-dir ./data/fbank_ali_test \ + --cuts-ref-dir ./data/fbank_ali_torch \ + --dataset train-clean-100 +""" +import argparse +import logging +from pathlib import Path + +import torch +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--cuts-out-dir", + type=Path, + default="./data/fbank_ali", + help="The dir that saves the generated cuts manifests with alignments", + ) + + parser.add_argument( + "--cuts-ref-dir", + type=Path, + default="./data/fbank_ali_torch", + help="The dir that saves the reference cuts manifests with alignments", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset: + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + + cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + + logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}") + cuts_out = load_manifest(cuts_out_jsonl) + cuts_ref = load_manifest(cuts_ref_jsonl) + cuts_ref = cuts_ref.sort_like(cuts_out) + + all_time_diffs = [] + for cut_out, cut_ref in zip(cuts_out, cuts_ref): + time_out = [ + ali.start + for ali in cut_out.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + time_ref = [ + ali.start + for ali in cut_ref.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + assert len(time_out) == len(time_ref), (len(time_out), len(time_ref)) + diff = [ + round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref) + ] + all_time_diffs += diff + + all_time_diffs = torch.tensor(all_time_diffs) + logging.info( + f"For the word-level alignments abs difference on dataset {args.dataset}, " + f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/utils.py b/icefall/utils.py index 5d86472b5..1fd9156bd 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1378,7 +1378,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: List of timestamp of each word. """ start_token = b"\xe2\x96\x81".decode() # '_' - assert len(tokens) == len(timestamp) + assert len(tokens) == len(timestamp), (len(tokens), len(timestamp)) ans = [] for i in range(len(tokens)): flag = False
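For reference, the manifests produced by compute_ali.py can be consumed with
lhotse as in the following minimal sketch (the manifest path assumes the
default --cuts-out-dir and --dataset test-clean; adjust it to your setup):

    from lhotse import load_manifest

    cuts = load_manifest(
        "data/fbank_ali_beam_search/librispeech_cuts_test-clean.jsonl.gz"
    )
    for cut in cuts:
        # compute_ali.py stores exactly one supervision per cut and attaches
        # both "token" and "word" alignments to it.
        sup = cut.supervisions[0]
        for ali in sup.alignment["word"]:
            # Each AlignmentItem carries the word and its start time in
            # seconds; duration is left as None by compute_ali.py.
            print(cut.id, ali.symbol, ali.start)
        break  # only inspect the first cut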