diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 673053c55..3abcf588e 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -241,34 +241,6 @@ def ctc_greedy_search( return hyps, scores -def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: - """Make mask tensor containing indices of padded part. - - See description of make_non_pad_mask. - - Args: - lengths (torch.Tensor): Batch of lengths (B,). - Returns: - torch.Tensor: Mask tensor containing indices of padded part. - - Examples: - >>> lengths = [5, 3, 2] - >>> make_pad_mask(lengths) - masks = [[0, 0, 0, 0 ,0], - [0, 0, 0, 1, 1], - [0, 0, 1, 1, 1]] - """ - batch_size = lengths.size(0) - max_len = max_len if max_len > 0 else lengths.max().item() - seq_range = torch.arange( - 0, max_len, dtype=torch.int64, device=lengths.device - ) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - seq_length_expand = lengths.unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - return mask - - def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: new_hyp: List[int] = [] cur = 0