diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index edba2e0b3..91bb571c5 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -1,14 +1,14 @@ import math -from typing import List, Optional, Tuple import warnings +from typing import List, Optional, Tuple import torch -from torch import nn - -from icefall.utils import make_pad_mask +import torch.nn as nn from encoder_interface import EncoderInterface from subsampling import Conv2dSubsampling, VggSubsampling +from icefall.utils import make_pad_mask + def _get_activation_module(activation: str) -> nn.Module: if activation == "relu": @@ -1213,12 +1213,12 @@ class Emformer(EncoderInterface): B: batch size; D: feature dimension; - U: length of utterance. + T: length of utterance, including the right context frames at the end. Args: x (torch.Tensor): Utterance frames right-padded with right context frames, - with shape (B, U + right_context_length, D). + with shape (B, T, D). x_lens (torch.Tensor): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, containing the @@ -1226,7 +1226,8 @@ class Emformer(EncoderInterface): Returns: (Tensor, Tensor): - - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D). + - output logits, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. - logits lengths, with shape (B,), without containing the right_context at the end. """ @@ -1257,12 +1258,12 @@ class Emformer(EncoderInterface): B: batch size; D: feature dimension; - U: length of utterance. + T: length of utterance, including the right context frames at the end. Args: x (torch.Tensor): Utterance frames right-padded with right context frames, - with shape (B, U + right_context_length, D). + with shape (B, T, D). 
lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, containing the @@ -1273,7 +1274,8 @@ class Emformer(EncoderInterface): (default: None) Returns: (Tensor, Tensor): - - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D). + - output logits, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. - logits lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation.