add conv module

This commit is contained in:
yaozengwei 2022-05-13 22:23:39 +08:00
parent 2cfb2f58f0
commit 3360dc5afc
2 changed files with 343 additions and 0 deletions


@@ -153,9 +153,298 @@ class RelPositionalEncoding(torch.nn.Module):
return self.dropout(x), self.dropout(pos_emb)
class ConvolutionModule(nn.Module):
"""ConvolutionModule.
Modified from https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa
Args:
chunk_length (int):
Length of each chunk.
right_context_length (int):
Length of right context.
channels (int):
The number of channels of conv layers.
kernel_size (int):
Kernel size of conv layers.
bias (bool):
Whether to use bias in conv layers (default=True).
"""
def __init__(
self,
chunk_length: int,
right_context_length: int,
channels: int,
kernel_size: int,
bias: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.chunk_length = chunk_length
self.right_context_length = right_context_length
self.channels = channels
self.pointwise_conv1 = ScaledConv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
# After pointwise_conv1 we put x through a gated linear unit
# (nn.functional.glu).
# For most layers the normal rms value of channels of x seems to be in
# the range 1 to 4, but sometimes, for some reason, for layer 0 the rms
# ends up being very large, between 50 and 100 for different channels.
# This will cause very peaky and sparse derivatives for the sigmoid
# gating function, which will tend to make the loss function not learn
# effectively. (for most layers the average absolute values are in the
# range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
# at the output of pointwise_conv1 is around 0.35 to 0.45 for
# different layers, which likely breaks down as 0.5 for the "linear"
# half and 0.2 to 0.3 for the part that goes into the sigmoid.)
# The idea is that if we constrain the rms values to a reasonable range
# via a constraint of max_abs=10.0, it will be in a better position to
# start learning something, i.e. to latch onto the correct range.
self.deriv_balancer1 = ActivationBalancer(
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
# make it causal by padding cached (kernel_size - 1) frames on the left
self.cache_size = kernel_size - 1
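# e.g. with kernel_size=31, the 30 most recent frames are cached and
# prepended to the current chunk, so the depthwise conv below needs no
# right padding and never looks at future utterance frames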
self.depthwise_conv = ScaledConv1d(
channels,
channels,
kernel_size,
stride=1,
padding=0,
groups=channels,
bias=bias,
)
self.deriv_balancer2 = ActivationBalancer(
channel_dim=1, min_positive=0.05, max_positive=1.0
)
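# DoubleSwish (from scaling.py) is a Swish-like activation, roughly
# x * sigmoid(x - 1), i.e. an approximation of swish(swish(x))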
self.activation = DoubleSwish()
self.pointwise_conv2 = ScaledConv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
initial_scale=0.25,
)
def _split_right_context(
self,
pad_utterance: torch.Tensor,
right_context: torch.Tensor,
) -> torch.Tensor:
"""
Args:
pad_utterance:
Its shape is (cache_size + U, B, D).
right_context:
Its shape is (R, B, D).
Returns:
Right context segments, each padded on the left with the corresponding cache_size frames from pad_utterance.
Its shape is (num_segs * B, D, cache_size + right_context_length).
"""
U_, B, D = pad_utterance.size()
R = right_context.size(0)
assert self.right_context_length != 0
assert R % self.right_context_length == 0
num_chunks = R // self.right_context_length
right_context = right_context.reshape(
num_chunks, self.right_context_length, B, D
)
right_context = right_context.permute(0, 2, 1, 3).reshape(
num_chunks * B, self.right_context_length, D
)
padding = []
for idx in range(num_chunks):
end_idx = min(U_, self.cache_size + (idx + 1) * self.chunk_length)
start_idx = end_idx - self.cache_size
padding.append(pad_utterance[start_idx:end_idx])
padding = torch.cat(padding, dim=1).permute(1, 0, 2)
# (num_segs * B, cache_size, D)
pad_right_context = torch.cat([padding, right_context], dim=1)
# (num_segs * B, cache_size + right_context_length, D)
return pad_right_context.permute(0, 2, 1)
def _merge_right_context(
self, right_context: torch.Tensor, B: int
) -> torch.Tensor:
"""
Args:
right_context:
Right context segments.
Its shape is (num_segs * B, D, right_context_length).
B:
Batch size.
Returns:
A tensor of shape (B, D, R), where
R = num_segs * right_context_length.
"""
right_context = right_context.reshape(
-1, B, self.channels, self.right_context_length
)
right_context = right_context.permute(1, 2, 0, 3)
right_context = right_context.reshape(B, self.channels, -1)
return right_context
def forward(
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
cache: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Causal convolution module applied on both utterance and right_context.
Args:
utterance (torch.Tensor):
Utterance tensor of shape (U, B, D).
right_context (torch.Tensor):
Right context tensor of shape (R, B, D).
cache (torch.Tensor, optional):
Cached tensor for left padding of shape (B, D, cache_size).
Returns:
A tuple of 3 tensors:
- output utterance of shape (U, B, D).
- output right_context of shape (R, B, D).
- updated cache tensor of shape (B, D, cache_size).
"""
U, B, D = utterance.size()
R, _, _ = right_context.size()
# point-wise conv and GLU mechanism
x = torch.cat([right_context, utterance], dim=0) # (R + U, B, D)
x = x.permute(1, 2, 0) # (B, D, R + U)
x = self.pointwise_conv1(x) # (B, 2 * D, R + U)
x = self.deriv_balancer1(x)
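# glu on dim=1 splits the 2 * D channels into halves a and b and returns
# a * sigmoid(b), so the channel dimension goes back to D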
x = nn.functional.glu(x, dim=1) # (B, D, R + U)
utterance = x[:, :, R:] # (B, D, U)
right_context = x[:, :, :R] # (B, D, R)
if cache is None:
cache = torch.zeros(
B, D, self.cache_size, device=x.device, dtype=x.dtype
)
else:
assert cache.shape == (B, D, self.cache_size), cache.shape
pad_utterance = torch.cat(
[cache, utterance], dim=2
) # (B, D, cache + U)
# update cache
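# the last (kernel_size - 1) frames of [cache, utterance] become the
# left context for the next call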
new_cache = pad_utterance[:, :, -self.cache_size :]
# depth-wise conv on utterance
utterance = self.depthwise_conv(pad_utterance) # (B, D, U)
if self.right_context_length > 0:
# depth-wise conv on right_context
pad_right_context = self._split_right_context(
pad_utterance.permute(2, 0, 1), right_context.permute(2, 0, 1)
) # (num_segs * B, D, cache_size + right_context_length)
right_context = self.depthwise_conv(
pad_right_context
) # (num_segs * B, D, right_context_length)
right_context = self._merge_right_context(
right_context, B
) # (B, D, R)
x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U)
x = self.deriv_balancer2(x)
x = self.activation(x)
# point-wise conv
x = self.pointwise_conv2(x) # (B, D, R + U)
right_context = x[:, :, :R] # (B, D, R)
utterance = x[:, :, R:] # (B, D, U)
return (
utterance.permute(2, 0, 1),
right_context.permute(2, 0, 1),
new_cache,
)
def infer(
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
cache: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Causal convolution module applied on both utterance and right_context.
Args:
utterance (torch.Tensor):
Utterance tensor of shape (U, B, D).
right_context (torch.Tensor):
Right context tensor of shape (R, B, D).
cache (torch.Tensor, optional):
Cached tensor for left padding of shape (B, D, cache_size).
Returns:
A tuple of 3 tensors:
- output utterance of shape (U, B, D).
- output right_context of shape (R, B, D).
- updated cache tensor of shape (B, D, cache_size).
"""
U, B, D = utterance.size()
R, _, _ = right_context.size()
# point-wise conv
x = torch.cat([utterance, right_context], dim=0) # (U + R, B, D)
x = x.permute(1, 2, 0) # (B, D, U + R)
x = self.pointwise_conv1(x) # (B, 2 * D, U + R)
x = self.deriv_balancer1(x)
x = nn.functional.glu(x, dim=1) # (B, D, U + R)
if cache is None:
cache = torch.zeros(
B, D, self.cache_size, device=x.device, dtype=x.dtype
)
else:
assert cache.shape == (B, D, self.cache_size), cache.shape
x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R)
# update cache
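# keep the (kernel_size - 1) frames just before the right context, i.e.
# the most recent frames of [cache, utterance]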
new_cache = x[:, :, -R - self.cache_size:-R]
# 1-D depth-wise conv
x = self.depthwise_conv(x) # (B, D, U + R)
x = self.deriv_balancer2(x)
x = self.activation(x)
# point-wise conv
x = self.pointwise_conv2(x) # (B, D, U + R)
utterance = x[:, :, :U] # (B, D, U)
right_context = x[:, :, U:] # (B, D, R)
return (
utterance.permute(2, 0, 1),
right_context.permute(2, 0, 1),
new_cache,
)
class EmformerAttention(nn.Module):
r"""Emformer layer attention module.
Relative positional encoding is applied in this module, which is different
from https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py # noqa
Args:
embed_dim (int):
Embedding dimension.


@@ -102,7 +102,61 @@ def test_emformer_attention_infer():
assert next_val.shape == (L + U, B, D)
def test_convolution_module_forward():
from emformer import ConvolutionModule
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length, right_context_length, D, kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
cache = torch.randn(B, D, kernel_size - 1)
utterance, right_context, new_cache = conv_module(
utterance, right_context, cache
)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
assert new_cache.shape == (B, D, kernel_size - 1)
def test_convolution_module_infer():
from emformer import ConvolutionModule
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 1
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length, right_context_length, D, kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
cache = torch.randn(B, D, kernel_size - 1)
utterance, right_context, new_cache = conv_module.infer(
utterance, right_context, cache
)
assert utterance.shape == (U, B, D)
assert right_context.shape == (R, B, D)
assert new_cache.shape == (B, D, kernel_size - 1)
if __name__ == "__main__": if __name__ == "__main__":
test_rel_positional_encoding() test_rel_positional_encoding()
test_emformer_attention_forward() test_emformer_attention_forward()
test_emformer_attention_infer() test_emformer_attention_infer()
test_convolution_module_forward()
test_convolution_module_infer()