update the docs of Emformer class in emformer.py

yaozengwei 2022-04-08 10:59:39 +08:00
parent 374eacdd5c
commit 2d1b90f758


@@ -1,14 +1,14 @@
 import math
-from typing import List, Optional, Tuple
 import warnings
+from typing import List, Optional, Tuple
 import torch
-from torch import nn
-from icefall.utils import make_pad_mask
+import torch.nn as nn
 from encoder_interface import EncoderInterface
 from subsampling import Conv2dSubsampling, VggSubsampling
+from icefall.utils import make_pad_mask

 def _get_activation_module(activation: str) -> nn.Module:
     if activation == "relu":
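The hunk above shows only the signature and the first branch of _get_activation_module. As a rough sketch of what such a name-to-module mapping typically looks like (the branches other than "relu" are assumptions, not taken from this commit):

import torch.nn as nn

def _get_activation_module(activation: str) -> nn.Module:
    # Map an activation name to the corresponding torch.nn module.
    # Only the "relu" branch is visible in the diff above; the other
    # branches are illustrative assumptions.
    if activation == "relu":
        return nn.ReLU()
    elif activation == "gelu":
        return nn.GELU()
    elif activation == "silu":
        return nn.SiLU()
    else:
        raise ValueError(f"Unsupported activation: {activation}")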
@@ -1213,12 +1213,12 @@ class Emformer(EncoderInterface):
         B: batch size;
         D: feature dimension;
-        U: length of utterance.
+        T: length of utterance.

         Args:
           x (torch.Tensor):
             Utterance frames right-padded with right context frames,
-            with shape (B, U + right_context_length, D).
+            with shape (B, T, D).
           x_lens (torch.Tensor):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, containing the
@@ -1226,7 +1226,8 @@ class Emformer(EncoderInterface):
         Returns:
           (Tensor, Tensor):
-            - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D).
+            - output logits, with shape (B, T', D), where
+              T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
             - logits lengths, with shape (B,), without containing the
               right_context at the end.
         """
@@ -1257,12 +1258,12 @@ class Emformer(EncoderInterface):
         B: batch size;
         D: feature dimension;
-        U: length of utterance.
+        T: length of utterance.

         Args:
           x (torch.Tensor):
             Utterance frames right-padded with right context frames,
-            with shape (B, U + right_context_length, D).
+            with shape (B, T, D).
           lengths (torch.Tensor):
             With shape (B,) and i-th element representing number of valid
             utterance frames for i-th batch element in x, containing the
@@ -1273,7 +1274,8 @@ class Emformer(EncoderInterface):
             (default: None)
         Returns:
           (Tensor, Tensor):
-            - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D).
+            - output logits, with shape (B, T', D), where
+              T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
             - logits lengths, with shape (B,), without containing the
               right_context at the end.
             - updated states from current chunk's computation.
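The streaming entry point described by this docstring takes the current chunk, its valid lengths, and the cached states from the previous chunk, and returns updated states along with the logits. A minimal usage sketch, assuming an already constructed Emformer instance named model, a method name of infer (the diff does not show the method's name), and made-up chunk and feature sizes:

import torch

chunk_frames = 20   # utterance frames + right-context frames (assumed)
feature_dim = 80    # assumed feature dimension
states = None       # no cached history before the first chunk

with torch.no_grad():
    for _ in range(5):  # pretend five chunks arrive from a stream
        x = torch.randn(1, chunk_frames, feature_dim)
        lengths = torch.full((1,), chunk_frames, dtype=torch.int64)
        logits, logits_lens, states = model.infer(x, lengths, states)
        # logits_lens excludes the right-context frames at the end,
        # and states is fed back in for the next chunk.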