From 630626a092234f65bd827b4d7689da7bc791c84d Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 26 Jun 2022 21:39:06 +0800 Subject: [PATCH] support position encoding --- .flake8 | 3 +- .../emformer.py | 386 ++++++++++++++++-- .../test_emformer.py | 44 +- 3 files changed, 394 insertions(+), 39 deletions(-) diff --git a/.flake8 b/.flake8 index 9dd8d6207..d67fc1542 100644 --- a/.flake8 +++ b/.flake8 @@ -9,8 +9,7 @@ per-file-ignores = egs/*/ASR/pruned_transducer_stateless*/*.py: E501, egs/*/ASR/*/optim.py: E501, egs/*/ASR/*/scaling.py: E501, - egs/librispeech/ASR/conv_emformer_transducer_stateless/*.py: E501, E203 - egs/librispeech/ASR/conv_emformer_transducer_stateless2/*.py: E501, E203 + egs/librispeech/ASR/conv_emformer_transducer_stateless*/emformer.py: E501, E203 # invalid escape sequence (cause by tex formular), W605 icefall/utils.py: E501, W605 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py index e3a598b0e..06fc880df 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -434,6 +433,10 @@ class EmformerAttention(nn.Module): r"""Emformer layer attention module. Args: + chunk_length (int): + Length of chunk. + right_context_length (int): + Length of right context. embed_dim (int): Embedding dimension. nhead (int): @@ -448,6 +451,8 @@ class EmformerAttention(nn.Module): def __init__( self, + chunk_length: int, + right_context_length: int, embed_dim: int, nhead: int, dropout: float = 0.0, @@ -455,6 +460,8 @@ class EmformerAttention(nn.Module): negative_inf: float = -1e8, ): super().__init__() + self.chunk_length = chunk_length + self.right_context_length = right_context_length if embed_dim % nhead != 0: raise ValueError( @@ -477,6 +484,26 @@ class EmformerAttention(nn.Module): embed_dim, embed_dim, bias=True, initial_scale=0.25 ) + # linear transformation for positional encoding. 
+ self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+ # these two learnable biases are used in matrix c and matrix d
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
+ self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+ self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+ self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+ self._reset_parameters()
+
+ def _pos_bias_u(self):
+ return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+ def _pos_bias_v(self):
+ return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+ def _reset_parameters(self) -> None:
+ nn.init.normal_(self.pos_bias_u, std=0.01)
+ nn.init.normal_(self.pos_bias_v, std=0.01)
+
 def _gen_attention_probs(
 self,
 attention_weights: torch.Tensor,
@@ -539,6 +566,8 @@ class EmformerAttention(nn.Module):
 right_context: torch.Tensor,
 memory: torch.Tensor,
 attention_mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ rel_pos: torch.Tensor,
 padding_mask: Optional[torch.Tensor] = None,
 left_context_key: Optional[torch.Tensor] = None,
 left_context_val: Optional[torch.Tensor] = None,
@@ -556,26 +585,79 @@ class EmformerAttention(nn.Module):
 torch.cat([memory, right_context, utterance])
 ).chunk(chunks=2, dim=2)

+ is_streaming_infer = False
 if left_context_key is not None and left_context_val is not None:
 # now compute key and value with
 # [memory, right context, left context, utterance]
- # this is used in inference mode
+ # this is used in streaming inference mode
+ is_streaming_infer = True
 key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
 value = torch.cat(
 [value[: M + R], left_context_val, value[M + R :]]
 )
 Q = query.size(0)
- # KV = key.size(0)
+ KV = key.size(0)

- reshaped_query, reshaped_key, reshaped_value = [
- tensor.contiguous()
- .view(-1, B * self.nhead, self.head_dim)
+ reshaped_key = (
+ key.contiguous()
+ .view(KV, B, self.nhead, self.head_dim)
+ .permute(1, 2, 0, 3)
+ ) # (B, nhead, KV, head_dim)
+ reshaped_value = (
+ value.contiguous()
+ .view(KV, B * self.nhead, self.head_dim)
 .transpose(0, 1)
- for tensor in [query, key, value]
- ] # (B * nhead, Q or KV, head_dim)
- attention_weights = torch.bmm(
- reshaped_query * scaling, reshaped_key.transpose(1, 2)
- ) # (B * nhead, Q, KV)
+ ) # (B * nhead, KV, head_dim)
+ query = (
+ (query * scaling).contiguous().view(Q, B, self.nhead, self.head_dim)
+ )
+ # (B, nhead, Q, head_dim)
+ query_with_bias_u = (query + self._pos_bias_u()).permute(1, 2, 0, 3)
+ query_with_bias_v = (query + self._pos_bias_v()).permute(1, 2, 0, 3)
+
+ PE = pos_emb.size(0)
+ # pos_emb contains flipped positive part and negative part
+ # for relative position i - j between query (i) and key (j)
+ if is_streaming_infer:
+ # i is the first frame in current chunk (query)
+ # j is the last frame in right context (key)
+ # Note: R is equal to self.right_context_length here
+ min_neg_abs = U + R - 1
+ # i is the last frame in right context (query)
+ # j is the first frame in the past context that memory bank can cover (key) # noqa
+ max_pos_abs = U + R + M * self.chunk_length - 1
+ else:
+ # i is the first frame in utterance (query)
+ # j is the last frame in the last chunk's right context (key)
+ min_neg_abs = U + self.right_context_length - 1
+ # i is the last frame in the last chunk's right context (query)
+ # j is the first frame in the utterance (key)
+ max_pos_abs = U + self.right_context_length - 1
+ assert PE ==
min_neg_abs + max_pos_abs + 1 + pos_emb = ( + self.linear_pos(pos_emb) + .view(1, PE, self.nhead, self.head_dim) + .transpose(1, 2) + ) # (1, nhead, PE, head_dim) + + # content-based matrix-ac + matrix_ac = torch.matmul( + query_with_bias_u, reshaped_key.transpose(-2, -1) + ) # (B, nhead, Q, KV) + + # position-based matrix-bd + # (B, nhead, Q, PE) + matrix_bd = torch.matmul(query_with_bias_v, pos_emb.transpose(-2, -1)) + # gather position-related scores using pre-computed relative position + assert rel_pos.shape == (Q, KV) + rel_pos = rel_pos.unsqueeze(0).unsqueeze(1).expand(B, self.nhead, Q, KV) + matrix_bd = torch.gather( + matrix_bd, + dim=-1, + index=rel_pos, + ) # (B, nhead, Q, KV) + + attention_weights = (matrix_ac + matrix_bd).view(B * self.nhead, Q, KV) # compute attention probabilities attention_probs = self._gen_attention_probs( @@ -600,6 +682,8 @@ class EmformerAttention(nn.Module): right_context: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, + pos_emb: torch.Tensor, + rel_pos: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: # TODO: Modify docs. @@ -647,6 +731,11 @@ class EmformerAttention(nn.Module): attention_mask (torch.Tensor): Pre-computed attention mask to simulate underlying chunk-wise attention, with shape (Q, KV). + pos_emb (torch.Tensor): + Position embedding, with shape (PE, D), + where PE = 2 * (U + right_context_length) - 1. + rel_pos (torch.Tensor): + Relative positions, with shape (Q, KV). padding_mask (torch.Tensor): Padding mask of key tensor, with shape (B, KV). @@ -654,10 +743,12 @@ class EmformerAttention(nn.Module): Output of right context and utterance, with shape (R + U, B, D). """ output_right_context_utterance, _, _ = self._forward_impl( - utterance, - right_context, - memory, - attention_mask, + utterance=utterance, + right_context=right_context, + memory=memory, + attention_mask=attention_mask, + pos_emb=pos_emb, + rel_pos=rel_pos, padding_mask=padding_mask, ) return output_right_context_utterance @@ -667,6 +758,8 @@ class EmformerAttention(nn.Module): self, utterance: torch.Tensor, right_context: torch.Tensor, + pos_emb: torch.Tensor, + rel_pos: torch.Tensor, memory: torch.Tensor, left_context_key: torch.Tensor, left_context_val: torch.Tensor, @@ -700,6 +793,11 @@ class EmformerAttention(nn.Module): right_context (torch.Tensor): Right context frames, with shape (R, B, D), where R = right_context_length. + pos_emb (torch.Tensor): + Position embedding, with shape (PE, D), + where PE = 2 * (U + R) + M * chunk_length - 1. + rel_pos (torch.Tensor): + Relative positions, with shape (Q, KV). memory (torch.Tensor): Memory vectors, with shape (M, B, D), or empty tensor. 
left_context_key (torch.Tensor):
@@ -733,10 +831,12 @@ class EmformerAttention(nn.Module):
 )
 output_right_context_utterance, key, value = self._forward_impl(
- utterance,
- right_context,
- memory,
- attention_mask,
+ utterance=utterance,
+ right_context=right_context,
+ memory=memory,
+ attention_mask=attention_mask,
+ pos_emb=pos_emb,
+ rel_pos=rel_pos,
 padding_mask=padding_mask,
 left_context_key=left_context_key,
 left_context_val=left_context_val,
@@ -796,6 +896,8 @@ class EmformerEncoderLayer(nn.Module):
 super().__init__()

 self.attention = EmformerAttention(
+ chunk_length=chunk_length,
+ right_context_length=right_context_length,
 embed_dim=d_model,
 nhead=nhead,
 dropout=dropout,
@@ -898,6 +1000,8 @@ class EmformerEncoderLayer(nn.Module):
 right_context_utterance: torch.Tensor,
 R: int,
 attention_mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ rel_pos: torch.Tensor,
 padding_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Apply attention module in training and validation mode."""
@@ -917,15 +1021,18 @@ class EmformerEncoderLayer(nn.Module):
 right_context=right_context,
 memory=memory,
 attention_mask=attention_mask,
+ pos_emb=pos_emb,
+ rel_pos=rel_pos,
 padding_mask=padding_mask,
 )
-
 return output_right_context_utterance

 def _apply_attention_module_infer(
 self,
 right_context_utterance: torch.Tensor,
 R: int,
+ pos_emb: torch.Tensor,
+ rel_pos: torch.Tensor,
 attn_cache: List[torch.Tensor],
 padding_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
@@ -962,6 +1069,8 @@ class EmformerEncoderLayer(nn.Module):
 ) = self.attention.infer(
 utterance=utterance,
 right_context=right_context,
+ pos_emb=pos_emb,
+ rel_pos=rel_pos,
 memory=pre_memory,
 left_context_key=left_context_key,
 left_context_val=left_context_val,
@@ -977,6 +1086,8 @@ class EmformerEncoderLayer(nn.Module):
 utterance: torch.Tensor,
 right_context: torch.Tensor,
 attention_mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ rel_pos: torch.Tensor,
 padding_mask: Optional[torch.Tensor] = None,
 warmup: float = 1.0,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -996,6 +1107,11 @@ class EmformerEncoderLayer(nn.Module):
 attention_mask (torch.Tensor):
 Attention mask for underlying attention module,
 with shape (Q, KV), where Q = R + U, KV = M + R + U.
+ pos_emb (torch.Tensor):
+ Position embedding, with shape (PE, D),
+ where PE = 2 * (U + right_context_length) - 1.
+ rel_pos (torch.Tensor):
+ Relative positions, with shape (Q, KV).
 padding_mask (torch.Tensor):
 Padding mask of key tensor, with shape (B, KV).

@@ -1025,7 +1141,12 @@ class EmformerEncoderLayer(nn.Module):
 # emformer attention module
 src_att = self._apply_attention_module_forward(
- src, R, attention_mask, padding_mask=padding_mask
+ right_context_utterance=src,
+ R=R,
+ attention_mask=attention_mask,
+ pos_emb=pos_emb,
+ rel_pos=rel_pos,
+ padding_mask=padding_mask,
 )
 src = src + self.dropout(src_att)

@@ -1050,6 +1171,8 @@ class EmformerEncoderLayer(nn.Module):
 self,
 utterance: torch.Tensor,
 right_context: torch.Tensor,
+ pos_emb: torch.Tensor,
+ rel_pos: torch.Tensor,
 attn_cache: List[torch.Tensor],
 conv_cache: torch.Tensor,
 padding_mask: Optional[torch.Tensor] = None,
@@ -1067,6 +1190,12 @@ class EmformerEncoderLayer(nn.Module):
 Utterance frames, with shape (U, B, D).
 right_context (torch.Tensor):
 Right context frames, with shape (R, B, D).
+ pos_emb (torch.Tensor):
+ Position embedding, with shape (PE, D),
+ where PE = 2 * (U + R) + M * chunk_length - 1.
+ rel_pos (torch.Tensor): + Relative positions, with shape (Q, KV), + where Q = R + U, KV = M + R + L + U. attn_cache (List[torch.Tensor]): Cached attention tensors generated in preceding computation, including memory, key and value of left context. @@ -1090,7 +1219,12 @@ class EmformerEncoderLayer(nn.Module): # emformer attention module src_att, attn_cache = self._apply_attention_module_infer( - src, R, attn_cache, padding_mask=padding_mask + right_context_utterance=src, + R=R, + pos_emb=pos_emb, + rel_pos=rel_pos, + attn_cache=attn_cache, + padding_mask=padding_mask, ) src = src + self.dropout(src_att) @@ -1187,6 +1321,7 @@ class EmformerEncoder(nn.Module): self.use_memory = memory_size > 0 + self.encoder_pos = RelPositionalEncoding(d_model, dropout) self.emformer_layers = nn.ModuleList( [ EmformerEncoderLayer( @@ -1215,7 +1350,9 @@ class EmformerEncoder(nn.Module): self.memory_size = memory_size self.cnn_module_kernel = cnn_module_kernel - def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: + def _gen_right_context( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] num_chunks = math.ceil( @@ -1235,9 +1372,9 @@ class EmformerEncoder(nn.Module): indexes, torch.arange(T - self.right_context_length, T).unsqueeze(0), ] - ) - right_context_blocks = x[indexes.reshape(-1)] - return right_context_blocks + ).reshape(-1) + right_context_blocks = x[indexes] + return right_context_blocks, indexes def _gen_attention_mask_col_widths( self, chunk_idx: int, U: int @@ -1381,10 +1518,33 @@ class EmformerEncoder(nn.Module): - output_lengths, with shape (B,), without containing the right_context at the end. """ - U = x.size(0) - self.right_context_length + x, pos_emb = self.encoder_pos(x, pos_len=x.size(0), neg_len=x.size(0)) - right_context = self._gen_right_context(x) + U = x.size(0) - self.right_context_length + right_context, right_context_indexes = self._gen_right_context(x) + utterance_indexes = torch.arange(0, U) utterance = x[:U] + num_chunks = math.ceil(U / self.chunk_length) + memory_indexes = ( + torch.arange( + self.chunk_length // 2, + (num_chunks - 1) * self.chunk_length, + self.chunk_length, + ) + if num_chunks > 1 + else torch.empty(0).to(dtype=utterance_indexes.dtype) + ) + query_indexes = torch.cat( + [right_context_indexes, utterance_indexes] + ).to(device=x.device) + key_indexes = torch.cat( + [memory_indexes, right_context_indexes, utterance_indexes] + ).to(device=x.device) + # calculate relative position and flip sign + rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)) + # shift to start from zero + rel_pos = rel_pos - rel_pos.min() + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) @@ -1394,9 +1554,11 @@ class EmformerEncoder(nn.Module): output = utterance for layer in self.emformer_layers: output, right_context = layer( - output, - right_context, - attention_mask, + utterance=output, + right_context=right_context, + attention_mask=attention_mask, + pos_emb=pos_emb, + rel_pos=rel_pos, padding_mask=padding_mask, warmup=warmup, ) @@ -1445,6 +1607,7 @@ class EmformerEncoder(nn.Module): """ assert num_processed_frames.shape == (x.size(1),) + # check the shapes of states attn_caches = states[0] assert len(attn_caches) == self.num_encoder_layers, len(attn_caches) for i in range(len(attn_caches)): @@ -1473,6 +1636,11 @@ class EmformerEncoder(nn.Module): self.cnn_module_kernel - 1, ), 
conv_caches[i].shape

+ tot_past_length = self.memory_size * self.chunk_length
+ x, pos_emb = self.encoder_pos(
+ x, pos_len=x.size(0) + tot_past_length, neg_len=x.size(0)
+ )
+
 right_context = x[-self.right_context_length :]
 utterance = x[: -self.right_context_length]
 output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
@@ -1504,6 +1672,36 @@ class EmformerEncoder(nn.Module):
 dim=1,
 )

+ # calculate relative position
+ memory_indexes = torch.arange(
+ self.chunk_length // 2, tot_past_length, self.chunk_length
+ )
+ left_context_indexes = torch.arange(
+ tot_past_length - self.left_context_length, tot_past_length
+ )
+ utterance_indexes = torch.arange(
+ tot_past_length, tot_past_length + utterance.size(0)
+ )
+ right_context_indexes = torch.arange(
+ tot_past_length + utterance.size(0),
+ tot_past_length + utterance.size(0) + right_context.size(0),
+ )
+ query_indexes = torch.cat(
+ [right_context_indexes, utterance_indexes]
+ ).to(device=x.device)
+ key_indexes = torch.cat(
+ [
+ memory_indexes,
+ right_context_indexes,
+ left_context_indexes,
+ utterance_indexes,
+ ]
+ ).to(device=x.device)
+ # calculate relative position and flip sign
+ rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+ # shift to start from zero
+ rel_pos = rel_pos - rel_pos.min()
+
 output = utterance
 output_attn_caches: List[List[torch.Tensor]] = []
 output_conv_caches: List[torch.Tensor] = []
@@ -1514,8 +1712,10 @@ class EmformerEncoder(nn.Module):
 output_attn_cache,
 output_conv_cache,
 ) = layer.infer(
- output,
- right_context,
+ utterance=output,
+ right_context=right_context,
+ pos_emb=pos_emb,
+ rel_pos=rel_pos,
 padding_mask=padding_mask,
 attn_cache=attn_caches[layer_idx],
 conv_cache=conv_caches[layer_idx],
@@ -1597,6 +1797,10 @@ class Emformer(EncoderInterface):
 raise NotImplementedError(
 "right_context_length must be 0 or a multiple of subsampling_factor." # noqa
 )
+ if memory_size > 0 and memory_size * chunk_length < left_context_length:
+ raise NotImplementedError(
+ "memory_size * chunk_length must not be smaller than left_context_length." # noqa
+ )

 # self.encoder_embed converts the input of shape (N, T, num_features)
 # to the shape (N, T//subsampling_factor, d_model).
@@ -1822,3 +2026,119 @@ class Conv2dSubsampling(nn.Module):
 x = self.out_norm(x)
 x = self.out_balancer(x)
 return x
+
+
+class RelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module.
+
+ See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" # noqa
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py # noqa
+
+ Suppose:
+ i -> position of query,
+ j -> position of key(value),
+ we use positive relative position embedding when key(value) is to the
+ left of query (i.e., i > j) and negative embedding otherwise.
+
+ Args:
+ d_model: Embedding dimension.
+ dropout: Dropout rate.
+ max_len: Maximum input length.
+ """ + + def __init__( + self, d_model: int, dropout: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout) + self.pe = None + self.pos_len = max_len + self.neg_len = max_len + self.gen_pe_positive() + self.gen_pe_negative() + + def gen_pe_positive(self) -> None: + """Generate the positive positional encodings.""" + pe_positive = torch.zeros(self.pos_len, self.d_model) + position_positive = torch.arange( + 0, self.pos_len, dtype=torch.float32 + ).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe_positive[:, 0::2] = torch.sin(position_positive * div_term) + pe_positive[:, 1::2] = torch.cos(position_positive * div_term) + # Reserve the order of positive indices and concat both positive and + # negative indices. This is used to support the shifting trick + # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" # noqa + self.pe_positive = torch.flip(pe_positive, [0]) + + def gen_pe_negative(self) -> None: + """Generate the negative positional encodings.""" + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i torch.Tensor: + """Get positional encoding given positive length and negative length.""" + if self.pe_positive.dtype != dtype or str( + self.pe_positive.device + ) != str(device): + self.pe_positive = self.pe_positive.to(dtype=dtype, device=device) + if self.pe_negative.dtype != dtype or str( + self.pe_negative.device + ) != str(device): + self.pe_negative = self.pe_negative.to(dtype=dtype, device=device) + pe = torch.cat( + [ + self.pe_positive[self.pos_len - pos_len :], + self.pe_negative[1:neg_len], + ], + dim=0, + ) + return pe + + def forward( + self, + x: torch.Tensor, + pos_len: int, + neg_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Scale input x and get positional encoding. + Args: + x (torch.Tensor): Input tensor (`*`). + + Returns: + torch.Tensor: + Encoded tensor of shape (`*`). + torch.Tensor: + Position embedding of shape (pos_len + neg_len - 1, `*`). 
+ """ + if pos_len > self.pos_len: + self.pos_len = pos_len + self.gen_pe_positive() + if neg_len > self.neg_len: + self.neg_len = neg_len + self.gen_pe_negative() + pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype) + return self.dropout(x), self.dropout(pos_emb) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless3/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless3/test_emformer.py index 8cde6205b..91c50ea3e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless3/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless3/test_emformer.py @@ -114,8 +114,12 @@ def test_state_stack_unstack(): for _ in range(num_encoder_layers) ] states = [attn_caches, conv_caches] - x = torch.randn(batch_size, 23, num_features) - x_lens = torch.full((batch_size,), 23) + x = torch.randn( + batch_size, chunk_length + right_context_length + 3, num_features + ) + x_lens = torch.full( + (batch_size,), chunk_length + right_context_length + 3 + ) num_processed_frames = torch.full((batch_size,), 0) y, y_lens, states = model.infer( x, x_lens, num_processed_frames=num_processed_frames, states=states @@ -172,8 +176,10 @@ def test_torchscript_consistency_infer(): for _ in range(num_encoder_layers) ] states = [attn_caches, conv_caches] - x = torch.randn(batch_size, 23, num_features) - x_lens = torch.full((batch_size,), 23) + x = torch.randn( + batch_size, chunk_length + right_context_length + 3, num_features + ) + x_lens = torch.full((batch_size,), chunk_length + right_context_length + 3) num_processed_frames = torch.full((batch_size,), 0) y, y_lens, out_states = model.infer( x, x_lens, num_processed_frames=num_processed_frames, states=states @@ -187,8 +193,38 @@ def test_torchscript_consistency_infer(): assert torch.allclose(y, sc_y) +def test_emformer_forward_shape(): + num_features = 80 + chunk_length = 32 + encoder_dim = 512 + num_encoder_layers = 2 + kernel_size = 31 + left_context_length = 32 + right_context_length = 8 + memory_size = 32 + batch_size = 2 + + model = Emformer( + num_features=num_features, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=encoder_dim, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + memory_size=memory_size, + ) + U = 2 * chunk_length + x = torch.randn(batch_size, U + right_context_length + 3, num_features) + x_lens = torch.full((batch_size,), U + right_context_length + 3) + output, output_lengths = model(x, x_lens) + assert output.shape == (batch_size, U >> 2, encoder_dim) + + if __name__ == "__main__": test_convolution_module_forward() test_convolution_module_infer() test_state_stack_unstack() test_torchscript_consistency_infer() + test_emformer_forward_shape()