diff --git a/egs/librispeech/ASR/incremental_transf/.conformer.py.swp b/egs/librispeech/ASR/incremental_transf/.conformer.py.swp
index d8c62d0cc..c70f2b921 100644
Binary files a/egs/librispeech/ASR/incremental_transf/.conformer.py.swp and b/egs/librispeech/ASR/incremental_transf/.conformer.py.swp differ
diff --git a/egs/librispeech/ASR/incremental_transf/conformer.py b/egs/librispeech/ASR/incremental_transf/conformer.py
index 04f3b78ed..ca1137f31 100644
--- a/egs/librispeech/ASR/incremental_transf/conformer.py
+++ b/egs/librispeech/ASR/incremental_transf/conformer.py
@@ -510,67 +510,11 @@ class Tempformer(EncoderInterface):
     def forward(
         self,
         x
     ):
-        """
-        Args:
-          x:
-            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
-          x_lens:
-            A tensor of shape (batch_size,) containing the number of frames in
-            `x` before padding.
-        Returns:
-          Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, d_model)
-            - lengths, a tensor of shape (batch_size,) containing the number
-              of frames in `embeddings` before padding.
-        """
-        x = self.encoder_embed(x)
-        x, pos_emb = self.encoder_pos(x)
-        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-
-        lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        assert x.size(0) == lengths.max().item()
-        src_key_padding_mask = make_pad_mask(lengths)
-
-        if self.dynamic_chunk_training:
-            assert (
-                self.causal
-            ), "Causal convolution is required for streaming conformer."
-            max_len = x.size(0)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
-            else:
-                chunk_size = chunk_size % self.short_chunk_size + 1
-
-            mask = ~subsequent_chunk_mask(
-                size=x.size(0),
-                chunk_size=chunk_size,
-                num_left_chunks=self.num_left_chunks,
-                device=x.device,
-            )
-            x = self.encoder(
-                x,
-                pos_emb,
-                mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-            )  # (T, N, C)
-        else:
-            x, layer_outputs = self.encoder(
-                x,
-                pos_emb,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-            )  # (T, N, C)
-
-        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-        layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
-
-        if get_layer_output:
-            return x, lengths, layer_outputs
-        else:
-            return x, lengths
+        layer_outputs = []
+        for enum, encoder in enumerate(self.encoder_layers):
+            layer_outputs.append(encoder(x[enum]))
+        return layer_outputs
 
 
 class ConformerEncoderLayer(nn.Module):
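Editor's note (sketch, not part of the diff): the rewritten Tempformer.forward drops the embedding, length computation, and chunk-masking logic, and instead applies one encoder per layer to a per-layer input slice, returning the list of per-layer outputs. The sketch below only reproduces that control flow; the class name TempformerSketch, the nn.Linear stand-in encoders, and the tensor shapes are illustrative assumptions, and only the loop over self.encoder_layers comes from the diff.

# Minimal, self-contained sketch of the new per-layer forward.
import torch
import torch.nn as nn


class TempformerSketch(nn.Module):
    def __init__(self, num_layers: int = 3, d_model: int = 8):
        super().__init__()
        # Assumption: self.encoder_layers is an nn.ModuleList holding one
        # encoder module per layer, as referenced by the new forward().
        self.encoder_layers = nn.ModuleList(
            nn.Linear(d_model, d_model) for _ in range(num_layers)
        )

    def forward(self, x):
        # `x` is indexed per layer: x[i] is the input fed to encoder i.
        layer_outputs = []
        for enum, encoder in enumerate(self.encoder_layers):
            layer_outputs.append(encoder(x[enum]))
        return layer_outputs


# Usage: one tensor per encoder layer goes in, a list of per-layer
# outputs comes back.
model = TempformerSketch(num_layers=3, d_model=8)
inputs = [torch.randn(4, 2, 8) for _ in range(3)]
outputs = model(inputs)
assert len(outputs) == 3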