diff --git a/egs/librispeech/ASR/zipformer/ctc_align.py b/egs/librispeech/ASR/zipformer/ctc_align.py index b68d9d589..e0f878d7d 100755 --- a/egs/librispeech/ASR/zipformer/ctc_align.py +++ b/egs/librispeech/ASR/zipformer/ctc_align.py @@ -270,7 +270,7 @@ def align_one_batch( # https://github.com/pytorch/audio/blob/main/src/libtorchaudio/forced_align/gpu/compute.cu#L277 for ii in range(batch_size): labels, log_probs = forced_align( - log_probs=ctc_output[ii, : encoder_out_lens[ii]].unsqueeze(dim=0), + log_probs=ctc_output[ii:ii+1, : encoder_out_lens[ii]], targets=targets[ii, : target_lengths[ii]].unsqueeze(dim=0), input_lengths=encoder_out_lens[ii].unsqueeze(dim=0), target_lengths=target_lengths[ii].unsqueeze(dim=0), @@ -370,7 +370,7 @@ def align_dataset( except TypeError: num_batches = "?" - ignored_tokens = params.ignored_tokens + ["", ""] + ignored_tokens = set(params.ignored_tokens + ["", ""]) ignored_tokens_ints = [sp.piece_to_id(token) for token in ignored_tokens] logging.info(f"ignored tokens {ignored_tokens}")