update the docs of Emformer class in emformer.py

This commit is contained in:
yaozengwei 2022-04-08 10:59:39 +08:00
parent 374eacdd5c
commit 2d1b90f758

View File

@ -1,14 +1,14 @@
import math
from typing import List, Optional, Tuple
import warnings
from typing import List, Optional, Tuple
import torch
from torch import nn
from icefall.utils import make_pad_mask
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask
def _get_activation_module(activation: str) -> nn.Module:
if activation == "relu":
@ -1213,12 +1213,12 @@ class Emformer(EncoderInterface):
B: batch size;
D: feature dimension;
U: length of utterance.
T: length of utterance, including the right context frames.
Args:
x (torch.Tensor):
Utterance frames right-padded with right context frames,
with shape (B, U + right_context_length, D).
with shape (B, T, D).
x_lens (torch.Tensor):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, containing the
@ -1226,7 +1226,8 @@ class Emformer(EncoderInterface):
Returns:
(Tensor, Tensor):
- output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D).
- output logits, with shape (B, T', D), where
T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
- logits lengths, with shape (B,), excluding the
right_context at the end.
"""
@ -1257,12 +1258,12 @@ class Emformer(EncoderInterface):
B: batch size;
D: feature dimension;
U: length of utterance.
T: length of utterance, including the right context frames.
Args:
x (torch.Tensor):
Utterance frames right-padded with right context frames,
with shape (B, U + right_context_length, D).
with shape (B, T, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, containing the
@ -1273,7 +1274,8 @@ class Emformer(EncoderInterface):
(default: None)
Returns:
(Tensor, Tensor):
- output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D).
- output logits, with shape (B, T', D), where
T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
- logits lengths, with shape (B,), excluding the
right_context at the end.
- updated states from current chunk's computation.