Support causal convolution in emformer encoder layer.

yaozengwei 2022-04-11 12:28:15 +08:00
parent a24eef8096
commit 1d74c5e596
3 changed files with 53 additions and 7 deletions


@@ -458,6 +458,8 @@ class EmformerLayer(nn.Module):
If ``True``, applies tanh to memory elements. (Default: ``False``)
negative_inf (float, optional):
Value to use for negative infinity in attention weights. (Default: -1e8)
causal (bool):
Whether to use causal convolution. (Default: ``False``)
"""
def __init__(
@@ -472,6 +474,7 @@ class EmformerLayer(nn.Module):
max_memory_size: int = 0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
causal: bool = False,
):
super().__init__()
@@ -500,7 +503,11 @@ class EmformerLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model,
cnn_module_kernel,
causal=causal,
)
self.norm_ff_macaron = nn.LayerNorm(d_model)
self.norm_ff = nn.LayerNorm(d_model)
@@ -910,6 +917,8 @@ class EmformerEncoder(nn.Module):
If ``true``, applies tanh to memory elements. (default: ``false``)
negative_inf (float, optional):
Value to use for negative infinity in attention weights. (default: -1e8)
causal (bool):
Whether to use causal convolution. (default: ``False``)
"""
def __init__(
@@ -926,6 +935,7 @@ class EmformerEncoder(nn.Module):
max_memory_size: int = 0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
causal: bool = False,
):
super().__init__()
@@ -949,6 +959,7 @@ class EmformerEncoder(nn.Module):
max_memory_size=max_memory_size,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
causal=causal,
)
for layer_idx in range(num_encoder_layers)
]
@@ -1220,6 +1231,7 @@ class Emformer(EncoderInterface):
max_memory_size: int = 0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
causal: bool = False,
):
super().__init__()
@@ -1261,6 +1273,7 @@ class Emformer(EncoderInterface):
max_memory_size=max_memory_size,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
causal=causal,
)
# TODO(fangjun): remove dropout
@@ -1366,14 +1379,22 @@ class ConvolutionModule(nn.Module):
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
channels (int):
The number of channels of conv layers.
kernel_size (int):
Kernel size of conv layers.
bias (bool):
Whether to use bias in conv layers (default=True).
causal (bool):
Whether to use causal convolution (default=False).
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = False,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
@@ -1388,12 +1409,19 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
if causal:
self.left_padding = kernel_size - 1
padding = 0
else:
self.left_padding = 0
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
padding=padding,
groups=channels,
bias=bias,
)
@@ -1426,6 +1454,10 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.left_padding > 0:
# manually pad self.left_padding zeros to the left
# to make depthwise_conv causal
x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
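
The padding arithmetic in the causal branch can be sanity-checked with a small standalone PyTorch snippet (not part of this commit; the sizes below are illustrative). Left-padding the input by kernel_size - 1 and running the depthwise conv with padding=0 keeps the output length equal to the input length, just like the symmetric padding (kernel_size - 1) // 2 in the non-causal branch, while each output frame depends only on current and past input frames:

import torch
import torch.nn as nn

# Illustrative sizes; channels and kernel_size mirror the ConvolutionModule
# arguments, but the values here are arbitrary.
channels, kernel_size, time = 4, 3, 10
x = torch.randn(1, channels, time)

causal_conv = nn.Conv1d(channels, channels, kernel_size, padding=0, groups=channels)
noncausal_conv = nn.Conv1d(
    channels, channels, kernel_size, padding=(kernel_size - 1) // 2, groups=channels
)

# Causal branch: pad kernel_size - 1 zeros on the left only, then convolve
# with padding=0; the number of output frames equals the number of input frames.
y_causal = causal_conv(nn.functional.pad(x, (kernel_size - 1, 0)))
y_noncausal = noncausal_conv(x)
assert y_causal.shape == y_noncausal.shape == x.shape

# Perturbing a future frame leaves all earlier causal outputs unchanged.
x2 = x.clone()
x2[:, :, -1] += 1.0
y2 = causal_conv(nn.functional.pad(x2, (kernel_size - 1, 0)))
assert torch.allclose(y_causal[:, :, :-1], y2[:, :, :-1])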


@@ -103,6 +103,7 @@ def test_emformer_layer_forward():
cnn_module_kernel=3,
left_context_length=L,
max_memory_size=M,
causal=True,
)
Q, KV = R + U + S, M + R + U
@@ -147,6 +148,7 @@ def test_emformer_layer_infer():
cnn_module_kernel=3,
left_context_length=L,
max_memory_size=M,
causal=True,
)
utterance = torch.randn(U, B, D)
@@ -203,6 +205,7 @@ def test_emformer_encoder_forward():
left_context_length=L,
right_context_length=R,
max_memory_size=M,
causal=True,
)
x = torch.randn(U + R, B, D)
@@ -239,6 +242,7 @@ def test_emformer_encoder_infer():
left_context_length=L,
right_context_length=R,
max_memory_size=M,
causal=True,
)
states = None
@@ -284,6 +288,7 @@ def test_emformer_forward():
right_context_length=R,
max_memory_size=M,
vgg_frontend=False,
causal=True,
)
x = torch.randn(B, U + R + 3, num_features)
x_lens = torch.randint(1, U + R + 3 + 1, (B,))
@@ -324,6 +329,7 @@ def test_emformer_infer():
right_context_length=R,
max_memory_size=M,
vgg_frontend=False,
causal=True,
)
states = None
for chunk_idx in range(num_chunks):


@@ -137,6 +137,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of entries in the memory for the Emformer",
)
parser.add_argument(
"--causal-conv",
type=bool,
default=True,
help="Whether use causal convolution.",
)
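
Note that argparse's type=bool converts any non-empty string to True, so passing --causal-conv False on the command line would still enable causal convolution. Other icefall training scripts typically parse boolean flags with a string-to-bool helper; below is a minimal sketch of the same argument using it (not part of this commit, and it assumes icefall.utils.str2bool is importable in this script):

from icefall.utils import str2bool  # assumed to be available, as in other icefall scripts

parser.add_argument(
    "--causal-conv",
    type=str2bool,
    default=True,
    help="Whether to use causal convolution.",
)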
def get_parser():
parser = argparse.ArgumentParser(
@@ -377,6 +384,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
chunk_length=params.chunk_length,
right_context_length=params.right_context_length,
max_memory_size=params.memory_size,
causal=params.causal_conv,
)
return encoder