diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
index c55a73d68..032ecb77d 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
@@ -458,6 +458,8 @@ class EmformerLayer(nn.Module):
         If ``True``, applies tanh to memory elements. (Default: ``False``)
       negative_inf (float, optional):
         Value to use for negative infinity in attention weights. (Default: -1e8)
+      causal (bool):
+        Whether to use causal convolution (default=False).
     """
 
     def __init__(
@@ -472,6 +474,7 @@ class EmformerLayer(nn.Module):
         max_memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        causal: bool = False,
     ):
         super().__init__()
 
@@ -500,7 +503,11 @@ class EmformerLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model,
+            cnn_module_kernel,
+            causal=causal,
+        )
 
         self.norm_ff_macaron = nn.LayerNorm(d_model)
         self.norm_ff = nn.LayerNorm(d_model)
@@ -910,6 +917,8 @@ class EmformerEncoder(nn.Module):
         If ``true``, applies tanh to memory elements. (default: ``false``)
       negative_inf (float, optional):
         Value to use for negative infinity in attention weights. (default: -1e8)
+      causal (bool):
+        Whether to use causal convolution (default=False).
     """
 
     def __init__(
@@ -926,6 +935,7 @@ class EmformerEncoder(nn.Module):
         max_memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        causal: bool = False,
     ):
         super().__init__()
 
@@ -949,6 +959,7 @@ class EmformerEncoder(nn.Module):
                     max_memory_size=max_memory_size,
                     tanh_on_mem=tanh_on_mem,
                     negative_inf=negative_inf,
+                    causal=causal,
                 )
                 for layer_idx in range(num_encoder_layers)
             ]
@@ -1220,6 +1231,7 @@ class Emformer(EncoderInterface):
         max_memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        causal: bool = False,
     ):
         super().__init__()
 
@@ -1261,6 +1273,7 @@ class Emformer(EncoderInterface):
             max_memory_size=max_memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
+            causal=causal,
         )
 
         # TODO(fangjun): remove dropout
@@ -1366,14 +1379,22 @@ class ConvolutionModule(nn.Module):
     Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py  # noqa
 
     Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-
+      channels (int):
+        The number of channels of conv layers.
+      kernel_size (int):
+        Kernel size of conv layers.
+      bias (bool):
+        Whether to use bias in conv layers (default=True).
+      causal (bool):
+        Whether to use causal convolution (default=False).
     """
 
     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        causal: bool = False,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
@@ -1388,12 +1409,19 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
+
+        if causal:
+            self.left_padding = kernel_size - 1
+            padding = 0
+        else:
+            self.left_padding = 0
+            padding = (kernel_size - 1) // 2
         self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding=padding,
             groups=channels,
             bias=bias,
         )
@@ -1426,6 +1454,10 @@ class ConvolutionModule(nn.Module):
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
 
         # 1D Depthwise Conv
+        if self.left_padding > 0:
+            # Manually pad self.left_padding zeros to the left
+            # to make depthwise_conv causal.
+            x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0)
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
         x = x.permute(0, 2, 1)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
index 1f735637f..7685bfb26 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
@@ -103,6 +103,7 @@ def test_emformer_layer_forward():
         cnn_module_kernel=3,
         left_context_length=L,
         max_memory_size=M,
+        causal=True,
     )
 
     Q, KV = R + U + S, M + R + U
@@ -147,6 +148,7 @@ def test_emformer_layer_infer():
         cnn_module_kernel=3,
         left_context_length=L,
         max_memory_size=M,
+        causal=True,
     )
 
     utterance = torch.randn(U, B, D)
@@ -203,6 +205,7 @@ def test_emformer_encoder_forward():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
+        causal=True,
     )
 
     x = torch.randn(U + R, B, D)
@@ -239,6 +242,7 @@ def test_emformer_encoder_infer():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
+        causal=True,
     )
 
     states = None
@@ -284,6 +288,7 @@ def test_emformer_forward():
         right_context_length=R,
         max_memory_size=M,
         vgg_frontend=False,
+        causal=True,
     )
     x = torch.randn(B, U + R + 3, num_features)
     x_lens = torch.randint(1, U + R + 3 + 1, (B,))
@@ -324,6 +329,7 @@ def test_emformer_infer():
         right_context_length=R,
         max_memory_size=M,
         vgg_frontend=False,
+        causal=True,
     )
     states = None
     for chunk_idx in range(num_chunks):
diff --git a/egs/librispeech/ASR/conv_emformer_transducer/train.py b/egs/librispeech/ASR/conv_emformer_transducer/train.py
index bdb541ac6..d0126bb94 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/train.py
@@ -137,6 +137,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of entries in the memory for the Emformer",
     )
 
+    parser.add_argument(
+        "--causal-conv",
+        type=str2bool,
+        default=True,
+        help="Whether to use causal convolution.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -377,6 +384,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         chunk_length=params.chunk_length,
         right_context_length=params.right_context_length,
         max_memory_size=params.memory_size,
+        causal=params.causal_conv,
    )
     return encoder
 