Support causal convolution in emformer encoder layer.
parent a24eef8096
commit 1d74c5e596
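This change threads a new `causal` flag from `Emformer` down through `EmformerEncoder` and `EmformerLayer` into `ConvolutionModule`, which switches its depthwise convolution from symmetric padding to left-only padding, so no output frame depends on future input, as streaming requires. A minimal, self-contained sketch of the technique (plain PyTorch; the names here are illustrative, not icefall code):

import torch
import torch.nn as nn

channels, kernel_size, T = 4, 3, 10

# Depthwise conv with no built-in padding; causality comes from a manual left pad.
conv = nn.Conv1d(
    channels,
    channels,
    kernel_size,
    stride=1,
    padding=0,
    groups=channels,
    bias=False,
)

x = torch.randn(1, channels, T)
# Pad kernel_size - 1 zeros on the left only: T + (k - 1) - (k - 1) = T.
y = conv(nn.functional.pad(x, (kernel_size - 1, 0)))
assert y.shape == x.shape

# Causality check: zeroing future frames must not change earlier outputs.
x2 = x.clone()
x2[:, :, 5:] = 0.0
y2 = conv(nn.functional.pad(x2, (kernel_size - 1, 0)))
assert torch.allclose(y[:, :, :5], y2[:, :, :5])

With left-only padding the conv stays length-preserving while looking only backwards, which is why the diff below pairs `padding=0` in the conv with a manual pad of `kernel_size - 1` frames.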
@@ -458,6 +458,8 @@ class EmformerLayer(nn.Module):
           If ``True``, applies tanh to memory elements. (Default: ``False``)
         negative_inf (float, optional):
           Value to use for negative infinity in attention weights. (Default: -1e8)
+        causal (bool):
+          Whether to use causal convolution (default=False).
     """

     def __init__(
@@ -472,6 +474,7 @@ class EmformerLayer(nn.Module):
         max_memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        causal: bool = False,
     ):
         super().__init__()

@@ -500,7 +503,11 @@ class EmformerLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )

-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model,
+            cnn_module_kernel,
+            causal=causal,
+        )

         self.norm_ff_macaron = nn.LayerNorm(d_model)
         self.norm_ff = nn.LayerNorm(d_model)
@@ -910,6 +917,8 @@ class EmformerEncoder(nn.Module):
           If ``true``, applies tanh to memory elements. (default: ``false``)
         negative_inf (float, optional):
           Value to use for negative infinity in attention weights. (default: -1e8)
+        causal (bool):
+          Whether to use causal convolution (default=False).
     """

     def __init__(
@@ -926,6 +935,7 @@ class EmformerEncoder(nn.Module):
         max_memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        causal: bool = False,
     ):
         super().__init__()

@@ -949,6 +959,7 @@ class EmformerEncoder(nn.Module):
                 max_memory_size=max_memory_size,
                 tanh_on_mem=tanh_on_mem,
                 negative_inf=negative_inf,
+                causal=causal,
             )
             for layer_idx in range(num_encoder_layers)
         ]
@@ -1220,6 +1231,7 @@ class Emformer(EncoderInterface):
         max_memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        causal: bool = False,
     ):
         super().__init__()

@@ -1261,6 +1273,7 @@ class Emformer(EncoderInterface):
             max_memory_size=max_memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
+            causal=causal,
         )

         # TODO(fangjun): remove dropout
@@ -1366,14 +1379,22 @@ class ConvolutionModule(nn.Module):
     Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py  # noqa

     Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
+        channels (int):
+          The number of channels of conv layers.
+        kernel_size (int):
+          Kernel size of conv layers.
+        bias (bool):
+          Whether to use bias in conv layers (default=True).
+        causal (bool):
+          Whether to use causal convolution (default=False).
     """

     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        causal: bool = False,
     ) -> None:
         """Construct a ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
@@ -1388,12 +1409,19 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )

+        if causal:
+            self.left_padding = kernel_size - 1
+            padding = 0
+        else:
+            self.left_padding = 0
+            padding = (kernel_size - 1) // 2
         self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding=padding,
             groups=channels,
             bias=bias,
         )
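Both branches preserve the time length T: with `causal=True` the conv sees T + (kernel_size - 1) frames after the manual left pad and shrinks by kernel_size - 1, while with `causal=False` (and odd kernel_size) the built-in padding adds (kernel_size - 1) // 2 on each side. A quick shape check of that arithmetic (illustrative sketch, not icefall code):

import torch
import torch.nn as nn

k, T = 31, 100
x = torch.randn(1, 8, T)

causal = nn.Conv1d(8, 8, k, stride=1, padding=0, groups=8)
noncausal = nn.Conv1d(8, 8, k, stride=1, padding=(k - 1) // 2, groups=8)

# causal: T + (k - 1) - (k - 1) == T
assert causal(nn.functional.pad(x, (k - 1, 0))).shape[-1] == T
# non-causal (k odd): T + 2 * ((k - 1) // 2) - (k - 1) == T
assert noncausal(x).shape[-1] == T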
@@ -1426,6 +1454,10 @@ class ConvolutionModule(nn.Module):
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

         # 1D Depthwise Conv
+        if self.left_padding > 0:
+            # manually pad self.left_padding zeros to the left
+            # to make depthwise_conv causal
+            x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0)
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
         x = x.permute(0, 2, 1)
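For reference, `nn.functional.pad` consumes the pad tuple from the last dimension backwards, so `(self.left_padding, 0)` pads only the time axis, inserting zeros before t = 0. A small demonstration (sketch):

import torch
import torch.nn as nn

x = torch.ones(2, 4, 5)  # (batch, channels, time)
y = nn.functional.pad(x, (3, 0), "constant", 0.0)  # 3 zeros before t = 0
assert y.shape == (2, 4, 8)
assert (y[:, :, :3] == 0).all()
assert (y[:, :, 3:] == 1).all()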
@@ -103,6 +103,7 @@ def test_emformer_layer_forward():
         cnn_module_kernel=3,
         left_context_length=L,
         max_memory_size=M,
+        causal=True,
     )

     Q, KV = R + U + S, M + R + U
@@ -147,6 +148,7 @@ def test_emformer_layer_infer():
         cnn_module_kernel=3,
         left_context_length=L,
         max_memory_size=M,
+        causal=True,
     )

     utterance = torch.randn(U, B, D)
@@ -203,6 +205,7 @@ def test_emformer_encoder_forward():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
+        causal=True,
     )

     x = torch.randn(U + R, B, D)
@@ -239,6 +242,7 @@ def test_emformer_encoder_infer():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
+        causal=True,
     )

     states = None
@@ -284,6 +288,7 @@ def test_emformer_forward():
         right_context_length=R,
         max_memory_size=M,
         vgg_frontend=False,
+        causal=True,
     )
     x = torch.randn(B, U + R + 3, num_features)
     x_lens = torch.randint(1, U + R + 3 + 1, (B,))
@@ -324,6 +329,7 @@ def test_emformer_infer():
         right_context_length=R,
         max_memory_size=M,
         vgg_frontend=False,
+        causal=True,
     )
     states = None
     for chunk_idx in range(num_chunks):
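The `*_infer` tests above drive the model chunk by chunk. Causal convolution is what makes exact chunk-wise streaming possible: a left-padded conv can replay the full-utterance output by caching the last kernel_size - 1 frames as state between chunks. A standalone sketch of that equivalence (plain PyTorch, not the icefall test code):

import torch
import torch.nn as nn

k, T, chunk = 3, 12, 4
conv = nn.Conv1d(2, 2, k, padding=0, groups=2, bias=False)

x = torch.randn(1, 2, T)
full = conv(nn.functional.pad(x, (k - 1, 0)))  # offline, whole utterance

cache = torch.zeros(1, 2, k - 1)  # streaming state: last k - 1 frames
outs = []
for i in range(0, T, chunk):
    seg = torch.cat([cache, x[:, :, i : i + chunk]], dim=-1)
    outs.append(conv(seg))  # emits exactly `chunk` frames
    cache = seg[:, :, -(k - 1):]
assert torch.allclose(full, torch.cat(outs, dim=-1))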
@@ -137,6 +137,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of entries in the memory for the Emformer",
     )

+    parser.add_argument(
+        "--causal-conv",
+        type=bool,
+        default=True,
+        help="Whether to use causal convolution.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
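One caveat with the flag as written: argparse's `type=bool` applies `bool()` to the raw string, so any non-empty value, including `--causal-conv False`, parses as True. icefall recipes typically route boolean flags through a `str2bool` helper; a sketch of that pattern (the helper is defined inline here as an assumption, in case `icefall.utils.str2bool` differs):

import argparse

def str2bool(v: str) -> bool:
    # Equivalent of the helper icefall ships; inline so the sketch is self-contained.
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--causal-conv",
    type=str2bool,
    default=True,
    help="Whether to use causal convolution.",
)
assert parser.parse_args(["--causal-conv", "False"]).causal_conv is False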
@@ -377,6 +384,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         chunk_length=params.chunk_length,
         right_context_length=params.right_context_length,
         max_memory_size=params.memory_size,
+        causal=params.causal_conv,
     )
     return encoder
