mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Support causal convolution in emformer encoder layer.
This commit is contained in:
parent
a24eef8096
commit
1d74c5e596
@ -458,6 +458,8 @@ class EmformerLayer(nn.Module):
|
||||
If ``True``, applies tanh to memory elements. (Default: ``False``)
|
||||
negative_inf (float, optional):
|
||||
Value to use for negative infinity in attention weights. (Default: -1e8)
|
||||
causal (bool):
|
||||
Whether to use causal convolution (default=False).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -472,6 +474,7 @@ class EmformerLayer(nn.Module):
|
||||
max_memory_size: int = 0,
|
||||
tanh_on_mem: bool = False,
|
||||
negative_inf: float = -1e8,
|
||||
causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -500,7 +503,11 @@ class EmformerLayer(nn.Module):
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
self.conv_module = ConvolutionModule(
|
||||
d_model,
|
||||
cnn_module_kernel,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(d_model)
|
||||
self.norm_ff = nn.LayerNorm(d_model)
|
||||
@ -910,6 +917,8 @@ class EmformerEncoder(nn.Module):
|
||||
If ``true``, applies tanh to memory elements. (default: ``false``)
|
||||
negative_inf (float, optional):
|
||||
Value to use for negative infinity in attention weights. (default: -1e8)
|
||||
causal (bool):
|
||||
Whether to use causal convolution (default=False).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -926,6 +935,7 @@ class EmformerEncoder(nn.Module):
|
||||
max_memory_size: int = 0,
|
||||
tanh_on_mem: bool = False,
|
||||
negative_inf: float = -1e8,
|
||||
causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -949,6 +959,7 @@ class EmformerEncoder(nn.Module):
|
||||
max_memory_size=max_memory_size,
|
||||
tanh_on_mem=tanh_on_mem,
|
||||
negative_inf=negative_inf,
|
||||
causal=causal,
|
||||
)
|
||||
for layer_idx in range(num_encoder_layers)
|
||||
]
|
||||
@ -1220,6 +1231,7 @@ class Emformer(EncoderInterface):
|
||||
max_memory_size: int = 0,
|
||||
tanh_on_mem: bool = False,
|
||||
negative_inf: float = -1e8,
|
||||
causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -1261,6 +1273,7 @@ class Emformer(EncoderInterface):
|
||||
max_memory_size=max_memory_size,
|
||||
tanh_on_mem=tanh_on_mem,
|
||||
negative_inf=negative_inf,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
# TODO(fangjun): remove dropout
|
||||
@ -1366,14 +1379,22 @@ class ConvolutionModule(nn.Module):
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernel size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
|
||||
channels (int):
|
||||
The number of channels of conv layers.
|
||||
kernel_size (int):
|
||||
Kernel size of conv layers.
|
||||
bias (bool):
|
||||
Whether to use bias in conv layers (default=True).
|
||||
causal (bool):
|
||||
Whether to use causal convolution (default=False).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
bias: bool = True,
|
||||
causal: bool = False,
|
||||
) -> None:
|
||||
"""Construct a ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
@ -1388,12 +1409,19 @@ class ConvolutionModule(nn.Module):
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if causal:
|
||||
self.left_padding = kernel_size - 1
|
||||
padding = 0
|
||||
else:
|
||||
self.left_padding = 0
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
padding=padding,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
@ -1426,6 +1454,10 @@ class ConvolutionModule(nn.Module):
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
if self.left_padding > 0:
|
||||
# manually pad self.left_padding zeros on the left to
|
||||
# make depthwise_conv causal
|
||||
x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0)
|
||||
x = self.depthwise_conv(x)
|
||||
# x is (batch, channels, time)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
@ -103,6 +103,7 @@ def test_emformer_layer_forward():
|
||||
cnn_module_kernel=3,
|
||||
left_context_length=L,
|
||||
max_memory_size=M,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
Q, KV = R + U + S, M + R + U
|
||||
@ -147,6 +148,7 @@ def test_emformer_layer_infer():
|
||||
cnn_module_kernel=3,
|
||||
left_context_length=L,
|
||||
max_memory_size=M,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
utterance = torch.randn(U, B, D)
|
||||
@ -203,6 +205,7 @@ def test_emformer_encoder_forward():
|
||||
left_context_length=L,
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
x = torch.randn(U + R, B, D)
|
||||
@ -239,6 +242,7 @@ def test_emformer_encoder_infer():
|
||||
left_context_length=L,
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
states = None
|
||||
@ -284,6 +288,7 @@ def test_emformer_forward():
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
vgg_frontend=False,
|
||||
causal=True,
|
||||
)
|
||||
x = torch.randn(B, U + R + 3, num_features)
|
||||
x_lens = torch.randint(1, U + R + 3 + 1, (B,))
|
||||
@ -324,6 +329,7 @@ def test_emformer_infer():
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
vgg_frontend=False,
|
||||
causal=True,
|
||||
)
|
||||
states = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
|
@ -137,6 +137,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="Number of entries in the memory for the Emformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal-conv",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether use causal convolution.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -377,6 +384,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
chunk_length=params.chunk_length,
|
||||
right_context_length=params.right_context_length,
|
||||
max_memory_size=params.memory_size,
|
||||
causal=params.causal_conv,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user