From 9423b3899fccb321f974531f0b72a88a1518abf8 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 4 Apr 2022 22:16:46 +0800 Subject: [PATCH] Update emformer_pruned_transducer_stateless/emformer.py and upload emformer_pruned_transducer_stateless/test_emformer.py. --- .../emformer.py | 176 ++++----- .../test_emformer.py | 345 ++++++++++++++++++ 2 files changed, 408 insertions(+), 113 deletions(-) create mode 100644 egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 88b1a06fb..32498a2c1 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -9,48 +9,6 @@ from encoder_interface import EncoderInterface from subsampling import Conv2dSubsampling, VggSubsampling -def _gen_padding_mask( - utterance: torch.Tensor, - right_context: torch.Tensor, - lengths: torch.Tensor, - mems: torch.Tensor, - left_context_key: Optional[torch.Tensor] = None, -) -> Optional[torch.Tensor]: - """Generate padding mask according to the length of the tensors - contained in the key. - - Args: - utterance: (U, B, D) - right_context: (R, B, D) - lengths: (B,) - mems: (M, B, D) - left_context_key: (L, B, D) - B is the batch size, D is the feature dimension, - U is the length of the utterance, - R is the length of the right context block, - M is the length of the memory block, - L is the length of the left context block - - Returns: - padding_mask: - Padding mask for the concatenated key tensor - [mems, right_context, left_context, utterance], - sharing for all queries, with shape of (M + R + L + U, B) - """ - assert utterance.size(0) == torch.max(lengths) - B = utterance.size(1) - M = mems.size(0) - R = right_context.size(0) - L = left_context_key.size(0) if left_context_key is not None else 0 - if B == 1: - # TODO: for infer mode? - padding_mask = None - else: - lengths_concat = M + R + L + lengths - padding_mask = make_pad_mask(lengths_concat) - return padding_mask - - def _get_activation_module(activation: str) -> nn.Module: if activation == "relu": return nn.ReLU() @@ -96,11 +54,6 @@ def _gen_attention_mask_block( return torch.cat(mask_block, dim=1) -def length_down_sampling(length): - # Caution: We assume the subsampling factor is 4! - return ((length - 1) // 2 - 1) // 2 - - class EmformerAttention(nn.Module): r"""Emformer layer attention module. @@ -239,7 +192,7 @@ class EmformerAttention(nn.Module): and compute query tensor with length Q = R + U + S. 2) Concat memory, right_context, utterance, and compute key, value tensors with length KV = M + R + U; - optionally with left_context_key and left_context_val (inference mode) + optionally with left_context_key and left_context_val (inference mode), then KV = M + R + L + U. 3) Compute entire attention scores with query, key, and value, then apply attention_mask to get underlying chunk-wise attention scores. @@ -284,7 +237,7 @@ class EmformerAttention(nn.Module): ).chunk(chunks=2, dim=2) if left_context_key is not None and left_context_val is not None: - # Now compute key and value with + # This is for inference mode. 
Now compute key and value with # [mems, right context, left context, uttrance] M = memory.size(0) R = right_context.size(0) @@ -328,8 +281,8 @@ class EmformerAttention(nn.Module): outputs = self.out_proj(attention) S = summary.size(0) - output_right_context_utterance = outputs[:-S] - output_memory = outputs[-S:] + output_right_context_utterance = outputs[:Q - S] + output_memory = outputs[Q - S:] if self.tanh_on_mem: output_memory = torch.tanh(output_memory) else: @@ -370,12 +323,12 @@ class EmformerAttention(nn.Module): Memory elements, with shape (M, B, D). attention_mask (torch.Tensor): Attention mask for underlying chunk-wise attention, - with shape (Q, KV). + with shape (Q, KV), where Q = R + U + S, KV = M + R + U. Returns: A tuple containing 2 tensors: - output of right context and utterance, with shape (R + U, B, D). - - memory output, with shape (M, B, D), where M = S - 1. + - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ output_right_context_utterance, output_memory, _, _ = \ self._forward_impl( @@ -418,7 +371,7 @@ class EmformerAttention(nn.Module): right_context (torch.Tensor): Right context frames, with shape (R, B, D). summary (torch.Tensor): - Summary elements, with shape (S, B, D). + Summary element, with shape (1, B, D), or empty. memory (torch.Tensor): Memory elements, with shape (M, B, D). left_context_key (torch,Tensor): @@ -431,7 +384,7 @@ class EmformerAttention(nn.Module): Returns: A tuple containing 4 tensors: - output of right context and utterance, with shape (R + U, B, D). - - memory output, with shape (S, B, D). + - memory output, with shape (1, B, D) or (0, B, D). - attention key of left context and utterance, which would be cached for next computation, with shape (L + U, B, D). - attention value of left context and utterance, which would be @@ -476,7 +429,7 @@ class EmformerLayer(nn.Module): Number of attention heads. dim_feedforward (int): Hidden layer dimension of feedforward network. - segment_length (int): + chunk_length (int): Length of each input segment. dropout (float, optional): Dropout probability. 
(Default: 0.0) @@ -501,7 +454,7 @@ class EmformerLayer(nn.Module): d_model: int, nhead: int, dim_feedforward: int, - segment_length: int, + chunk_length: int, dropout: float = 0.0, activation: str = "relu", left_context_length: int = 0, @@ -513,7 +466,7 @@ class EmformerLayer(nn.Module): super().__init__() self.attention = EmformerAttention( - d_model=d_model, + embed_dim=d_model, nhead=nhead, dropout=dropout, weight_init_gain=weight_init_gain, @@ -522,7 +475,7 @@ class EmformerLayer(nn.Module): ) self.dropout = nn.Dropout(dropout) self.summary_op = nn.AvgPool1d( - kernel_size=segment_length, stride=segment_length, ceil_mode=True + kernel_size=chunk_length, stride=chunk_length, ceil_mode=True ) activation_module = _get_activation_module(activation) @@ -538,7 +491,7 @@ class EmformerLayer(nn.Module): self.layer_norm_output = nn.LayerNorm(d_model) self.left_context_length = left_context_length - self.segment_length = segment_length + self.chunk_length = chunk_length self.max_memory_size = max_memory_size self.d_model = d_model @@ -576,11 +529,13 @@ class EmformerLayer(nn.Module): past_length = state[3][0][0].item() past_left_context_length = min(self.left_context_length, past_length) past_memory_length = min( - self.max_memory_size, math.ceil(past_length / self.segment_length) + self.max_memory_size, math.ceil(past_length / self.chunk_length) ) - pre_memory = state[0][-past_memory_length:] - left_context_key = state[1][-past_left_context_length:] - left_context_val = state[2][-past_left_context_length:] + pre_memory = state[0][self.max_memory_size - past_memory_length:] + left_context_key = \ + state[1][self.left_context_length - past_left_context_length:] + left_context_val = \ + state[2][self.left_context_length - past_left_context_length:] return pre_memory, left_context_key, left_context_val def _pack_state( @@ -600,9 +555,9 @@ class EmformerLayer(nn.Module): new_memory = torch.cat([state[0], memory]) new_key = torch.cat([state[1], next_key]) new_val = torch.cat([state[2], next_val]) - state[0] = new_memory[-self.max_memory_size:] - state[1] = new_key[-self.left_context_length:] - state[2] = new_val[-self.left_context_length:] + state[0] = new_memory[new_memory.size(0) - self.max_memory_size:] + state[1] = new_key[new_key.size(0) - self.left_context_length:] + state[2] = new_val[new_val.size(0) - self.left_context_length:] state[3] = state[3] + update_length return state @@ -749,7 +704,8 @@ class EmformerLayer(nn.Module): memory (torch.Tensor): Memory elements, with shape (M, B, D). attention_mask (torch.Tensor): - Attention mask for underlying attention module. + Attention mask for underlying attention module, + with shape (Q, KV), where Q = R + U + S, KV = M + R + U. Returns: A tuple containing 3 tensors: @@ -819,7 +775,7 @@ class EmformerLayer(nn.Module): (Tensor, Tensor, List[torch.Tensor], Tensor): - output utterance, with shape (U, B, D); - output right_context, with shape (R, B, D); - - output memory, with shape (M, B, D); + - output memory, with shape (1, B, D) or (0, B, D). - output state. """ ( @@ -883,15 +839,6 @@ class EmformerEncoder(nn.Module): If ``true``, applies tanh to memory elements. (default: ``false``) negative_inf (float, optional): Value to use for negative infinity in attention weights. 
(default: -1e8) - - examples: - >>> emformer = emformer(512, 8, 2048, 20, 4, right_context_length=1) - >>> input = torch.rand(128, 400, 512) # batch, num_frames, feature_dim - >>> lengths = torch.randint(1, 200, (128,)) # batch - >>> output = emformer(input, lengths) - >>> input = torch.rand(128, 5, 512) - >>> lengths = torch.ones(128) * 5 - >>> output, lengths, states = emformer.infer(input, lengths, None) """ def __init__( @@ -913,7 +860,7 @@ class EmformerEncoder(nn.Module): super().__init__() self.use_memory = max_memory_size > 0 - self.memory_op = nn.AvgPool1d( + self.init_memory_op = nn.AvgPool1d( kernel_size=chunk_length, stride=chunk_length, ceil_mode=True, @@ -957,7 +904,7 @@ class EmformerEncoder(nn.Module): start = (seg_idx + 1) * self.chunk_length end = start + self.right_context_length right_context_blocks.append(x[start:end]) - right_context_blocks.append(x[-self.right_context_length:]) + right_context_blocks.append(x[T - self.right_context_length:]) return torch.cat(right_context_blocks) def _gen_attention_mask_col_widths( @@ -1095,31 +1042,34 @@ class EmformerEncoder(nn.Module): with shape (U + right_context_length, B, D). lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It is the true lengths without containing the right_context. + utterance frames for i-th batch element in x, which contains the + right_context at the end. Returns: - (Tensor, Tensor): + A tuple of 2 tensors: - output utterance frames, with shape (U, B, D). - - output lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - output_lengths, with shape (B,), without containing the + right_context at the end. """ - assert x.size(0) == torch.max(lengths).item() + \ - self.right_context_length + # assert x.size(0) == torch.max(lengths).item() right_context = self._gen_right_context(x) - utterance = x[:-self.right_context_length] + utterance = x[:x.size(0) - self.right_context_length] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + self.init_memory_op( + utterance.permute(1, 2, 0) + ).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) output = utterance for layer in self.emformer_layers: - output, right_context, memory = \ - layer(output, lengths, right_context, memory, attention_mask) + output, right_context, memory = layer( + output, output_lengths, right_context, memory, attention_mask + ) - return output, lengths + return output, output_lengths @torch.jit.export def infer( @@ -1137,11 +1087,11 @@ class EmformerEncoder(nn.Module): Args: x (torch.Tensor): Utterance frames right-padded with right context frames, - with shape (chunk_length + right_context_length, B, D). + with shape (U + right_context_length, B, D). lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It contains the right_context. + utterance frames for i-th batch element in x, which contains the + right_context at the end. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each element (List[torch.Tensor]) corresponding to each emformer layer. 
@@ -1150,8 +1100,8 @@ class EmformerEncoder(nn.Module): Returns: (Tensor, Tensor, List[List[torch.Tensor]]): - output utterance frames, with shape (U, B, D). - - output lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - output lengths, with shape (B,), without containing the + right_context at the end. - updated states from current chunk's computation. """ assert x.size(0) == self.chunk_length + self.right_context_length, ( @@ -1159,23 +1109,24 @@ class EmformerEncoder(nn.Module): f"expected size of {self.chunk_length + self.right_context_length} " f"for dimension 1 of x, but got {x.size(1)}." ) - right_context = x[-self.right_context_length:] - utterance = x[:-self.right_context_length] + right_context_start_idx = x.size(0) - self.right_context_length + right_context = x[right_context_start_idx:] + utterance = x[:right_context_start_idx] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) memory = ( - self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) output = utterance output_states: List[List[torch.Tensor]] = [] for layer_idx, layer in enumerate(self.emformer_layers): - output, right_context, output_state, memory = layer.infer( + output, right_context, memory, output_state = layer.infer( output, output_lengths, right_context, - None if states is None else states[layer_idx], memory, + None if states is None else states[layer_idx], ) output_states.append(output_state) @@ -1272,24 +1223,23 @@ class Emformer(EncoderInterface): with shape (B, U + right_context_length, D). x_lens (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It is the true lengths without containing the right_context. + utterance frames for i-th batch element in x, containing the + right_context at the end. Returns: (Tensor, Tensor): - output logits, with shape (B, U // 4, D). - - logits lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - logits lengths, with shape (B,), without containing the + right_context at the end. """ + # TODO: x.shape x = self.encoder_embed(x) x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! lengths = x_lens // 4 - assert x.size(0) == \ - lengths.max().item() + self.right_context_length // 4 - + assert x.size(0) == lengths.max().item() output, output_lengths = self.encoder(x, lengths) # (T, N, C) logits = self.encoder_output_layer(output) @@ -1316,8 +1266,8 @@ class Emformer(EncoderInterface): with shape (B, U + right_context_length, D). lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It is the true lengths without containing the right_context. + utterance frames for i-th batch element in x, containing the + right_context at the end. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each element (List[torch.Tensor]) corresponding to each emformer layer. @@ -1325,8 +1275,8 @@ class Emformer(EncoderInterface): Returns: (Tensor, Tensor): - output logits, with shape (B, U // 4, D). 
- - logits lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - logits lengths, with shape (B,), without containing the + right_context at the end. - updated states from current chunk's computation. """ x = self.encoder_embed(x) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py new file mode 100644 index 000000000..ae93a4c8f --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -0,0 +1,345 @@ +import torch + + +def test_emformer_attention_forward(): + from emformer import EmformerAttention + + B, D = 2, 256 + U, R = 12, 2 + chunk_length = 2 + attention = EmformerAttention(embed_dim=D, nhead=8) + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + Q, KV = R + U + S, M + R + U + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + summary = torch.randn(S, B, D) + memory = torch.randn(M, B, D) + attention_mask = torch.rand(Q, KV) >= 0.5 + + output_right_context_utterance, output_memory = attention( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + ) + assert output_right_context_utterance.shape == (R + U, B, D) + assert output_memory.shape == (M, B, D) + + +def test_emformer_attention_infer(): + from emformer import EmformerAttention + + B, D = 2, 256 + R, L = 4, 2 + chunk_length = 2 + U = chunk_length + attention = EmformerAttention(embed_dim=D, nhead=8) + + for use_memory in [True, False]: + if use_memory: + S, M = 1, 3 + else: + S, M = 0, 0 + + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + summary = torch.randn(S, B, D) + memory = torch.randn(M, B, D) + left_context_key = torch.randn(L, B, D) + left_context_val = torch.randn(L, B, D) + + output_right_context_utterance, output_memory, next_key, next_val = \ + attention.infer( + utterance, + lengths, + right_context, + summary, + memory, + left_context_key, + left_context_val, + ) + assert output_right_context_utterance.shape == (R + U, B, D) + assert output_memory.shape == (S, B, D) + assert next_key.shape == (L + U, B, D) + assert next_val.shape == (L + U, B, D) + + +def test_emformer_layer_forward(): + from emformer import EmformerLayer + + B, D = 2, 256 + U, R, L = 12, 2, 5 + chunk_length = 2 + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + layer = EmformerLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + left_context_length=L, + max_memory_size=M, + ) + + Q, KV = R + U + S, M + R + U + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + attention_mask = torch.rand(Q, KV) >= 0.5 + + output_utterance, output_right_context, output_memory = layer( + utterance, + lengths, + right_context, + memory, + attention_mask, + ) + assert output_utterance.shape == (U, B, D) + assert output_right_context.shape == (R, B, D) + assert output_memory.shape == (M, B, D) + + +def test_emformer_layer_infer(): + from emformer import EmformerLayer + + B, D = 2, 256 + R, L = 2, 5 + chunk_length = 2 + U = chunk_length + + for use_memory in [True, False]: + if use_memory: + 
M = 3 + else: + M = 0 + + layer = EmformerLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + left_context_length=L, + max_memory_size=M, + ) + + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + state = None + output_utterance, output_right_context, output_memory, output_state = \ + layer.infer( + utterance, + lengths, + right_context, + memory, + state, + ) + assert output_utterance.shape == (U, B, D) + assert output_right_context.shape == (R, B, D) + if use_memory: + assert output_memory.shape == (1, B, D) + else: + assert output_memory.shape == (0, B, D) + assert len(output_state) == 4 + assert output_state[0].shape == (M, B, D) + assert output_state[1].shape == (L, B, D) + assert output_state[2].shape == (L, B, D) + assert output_state[3].shape == (1, B) + + +def test_emformer_encoder_forward(): + from emformer import EmformerEncoder + + B, D = 2, 256 + U, R, L = 12, 2, 5 + chunk_length = 2 + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=2, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + ) + + x = torch.randn(U + R, B, D) + lengths = torch.randint(1, U + R + 1, (B,)) + lengths[0] = U + R + + output, output_lengths = encoder(x, lengths) + assert output.shape == (U, B, D) + assert torch.equal( + output_lengths, torch.clamp(lengths - R, min=0) + ) + + +def test_emformer_encoder_infer(): + from emformer import EmformerEncoder + + B, D = 2, 256 + R, L = 2, 5 + chunk_length = 2 + U = chunk_length + num_chunks = 3 + num_encoder_layers = 2 + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + ) + + states = None + for chunk_idx in range(num_chunks): + x = torch.randn(U + R, B, D) + lengths = torch.randint(1, U + R + 1, (B,)) + lengths[0] = U + R + output, output_lengths, states = \ + encoder.infer(x, lengths, states) + assert output.shape == (U, B, D) + assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (L, B, D) + assert state[2].shape == (L, B, D) + assert torch.equal( + state[3], (chunk_idx + 1) * U * torch.ones_like(state[3]) + ) + + +def test_emformer_forward(): + from emformer import Emformer + num_features = 80 + output_dim = 1000 + chunk_length = 16 + L, R = 32, 16 + B, D, U = 2, 256, 48 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + x = torch.randn(B, U + R, num_features) + x_lens = torch.randint(1, U + R + 1, (B,)) + x_lens[0] = U + R + logits, output_lengths = model(x, x_lens) + assert logits.shape == (B, U // 4, output_dim) + assert torch.equal( + output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0) + ) + + +def test_emformer_infer(): + from emformer import 
Emformer + num_features = 80 + output_dim = 1000 + chunk_length = 16 + U = chunk_length + L, R = 32, 16 + B, D = 2, 256 + num_chunks = 3 + num_encoder_layers = 2 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + states = None + for chunk_idx in range(num_chunks): + x = torch.randn(B, U + R, num_features) + x_lens = torch.randint(1, U + R + 1, (B,)) + x_lens[0] = U + R + logits, output_lengths, states = \ + model.infer(x, x_lens, states) + assert logits.shape == (B, U // 4, output_dim) + assert torch.equal( + output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0) + ) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (L // 4, B, D) + assert state[2].shape == (L // 4, B, D) + assert torch.equal( + state[3], + (chunk_idx + 1) * U // 4 * torch.ones_like(state[3]) + ) + + +if __name__ == "__main__": + test_emformer_attention_forward() + test_emformer_attention_infer() + test_emformer_layer_forward() + test_emformer_layer_infer() + test_emformer_encoder_forward() + test_emformer_encoder_infer() + test_emformer_forward() + test_emformer_infer()
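
For reference, the streaming call pattern exercised by test_emformer_infer above can be summarized in a short standalone sketch. This is illustrative only and not part of the patch; it assumes the Emformer constructor arguments and the infer() signature shown in the tests, and that emformer.py from this directory is importable.

import torch

from emformer import Emformer  # assumes this recipe directory is on sys.path

# Hyper-parameters mirroring test_emformer_infer above; adjust as needed.
num_features = 80
output_dim = 1000
chunk_length = 16          # utterance frames consumed per streaming step
right_context_length = 16  # look-ahead frames appended to every chunk
B = 2

model = Emformer(
    num_features=num_features,
    output_dim=output_dim,
    chunk_length=chunk_length,
    subsampling_factor=4,
    d_model=256,
    num_encoder_layers=2,
    left_context_length=32,
    right_context_length=right_context_length,
    max_memory_size=3,
    vgg_frontend=False,
)
model.eval()

states = None  # cached per-layer states, carried across chunks
with torch.no_grad():
    for chunk_idx in range(3):
        # Each call takes chunk_length utterance frames right-padded with the
        # right-context look-ahead; x_lens includes the right context.
        x = torch.randn(B, chunk_length + right_context_length, num_features)
        x_lens = torch.full(
            (B,), chunk_length + right_context_length, dtype=torch.int64
        )
        logits, logit_lens, states = model.infer(x, x_lens, states)
        # The frontend subsamples by 4, so each chunk yields
        # chunk_length // 4 output frames per batch element.
        assert logits.shape == (B, chunk_length // 4, output_dim)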