Support cache of left context for causal convolution.

This commit is contained in:
yaozengwei 2022-04-12 20:13:51 +08:00
parent 651745b220
commit c2808f8541
3 changed files with 134 additions and 53 deletions

View File

@ -601,24 +601,8 @@ class EmformerLayer(nn.Module):
)
return right_context_utterance
def _apply_conv_module(
self,
right_context_utterance: torch.Tensor,
right_context_end_idx: int,
) -> torch.Tensor:
"""Apply convolution module on utterance."""
utterance = right_context_utterance[right_context_end_idx:]
right_context = right_context_utterance[:right_context_end_idx]
residual = utterance
utterance = self.norm_conv(utterance)
utterance = residual + self.dropout(self.conv_module(utterance))
right_context_utterance = torch.cat([right_context, utterance])
return right_context_utterance
def _apply_feed_forward_module(
self,
right_context_utterance: torch.Tensor,
self, right_context_utterance: torch.Tensor
) -> torch.Tensor:
"""Apply feed forward module."""
residual = right_context_utterance
@ -628,6 +612,39 @@ class EmformerLayer(nn.Module):
)
return right_context_utterance
def _apply_conv_module_forward(
self,
right_context_utterance: torch.Tensor,
right_context_end_idx: int,
) -> torch.Tensor:
"""Apply convolution module on utterance in non-infer mode."""
utterance = right_context_utterance[right_context_end_idx:]
right_context = right_context_utterance[:right_context_end_idx]
residual = utterance
utterance = self.norm_conv(utterance)
utterance, _ = self.conv_module(utterance)
utterance = residual + self.dropout(utterance)
right_context_utterance = torch.cat([right_context, utterance])
return right_context_utterance
def _apply_conv_module_infer(
self,
right_context_utterance: torch.Tensor,
right_context_end_idx: int,
conv_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Apply convolution module on utterance in infer mode."""
utterance = right_context_utterance[right_context_end_idx:]
right_context = right_context_utterance[:right_context_end_idx]
residual = utterance
utterance = self.norm_conv(utterance)
utterance, conv_cache = self.conv_module(utterance, conv_cache)
utterance = residual + self.dropout(utterance)
right_context_utterance = torch.cat([right_context, utterance])
return right_context_utterance, conv_cache
def _apply_attention_module_forward(
self,
right_context_utterance: torch.Tensor,
@ -790,7 +807,7 @@ class EmformerLayer(nn.Module):
attention_mask,
)
right_context_utterance = self._apply_conv_module(
right_context_utterance = self._apply_conv_module_forward(
right_context_utterance, right_context_end_idx
)
@ -812,6 +829,7 @@ class EmformerLayer(nn.Module):
right_context: torch.Tensor,
memory: torch.Tensor,
state: Optional[List[torch.Tensor]] = None,
conv_cache: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""Forward pass for inference.
@ -841,6 +859,8 @@ class EmformerLayer(nn.Module):
state (List[torch.Tensor], optional):
List of tensors representing layer internal state generated in
preceding computation. (default=None)
conv_cache (torch.Tensor, optional):
Cache tensor of left context for causal convolution.
Returns:
(Tensor, Tensor, List[torch.Tensor], Tensor):
@ -848,6 +868,7 @@ class EmformerLayer(nn.Module):
- output right_context, with shape (R, B, D);
- output memory, with shape (1, B, D) or (0, B, D).
- output state.
- updated conv_cache.
"""
right_context_utterance = torch.cat([right_context, utterance])
right_context_end_idx = right_context.size(0)
@ -868,8 +889,10 @@ class EmformerLayer(nn.Module):
state,
)
right_context_utterance = self._apply_conv_module(
right_context_utterance, right_context_end_idx
right_context_utterance, conv_cache = self._apply_conv_module_infer(
right_context_utterance,
right_context_end_idx,
conv_cache,
)
right_context_utterance = self._apply_feed_forward_module(
@ -885,6 +908,7 @@ class EmformerLayer(nn.Module):
output_right_context,
output_memory,
output_state,
conv_cache,
)
@ -1156,7 +1180,10 @@ class EmformerEncoder(nn.Module):
x: torch.Tensor,
lengths: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
conv_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[
torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
]:
"""Forward pass for streaming inference.
B: batch size;
@ -1173,15 +1200,18 @@ class EmformerEncoder(nn.Module):
right_context at the end.
states (List[List[torch.Tensor]], optional):
Cached states from proceeding chunk's computation, where each
element (List[torch.Tensor]) corresponding to each emformer layer.
element (List[torch.Tensor]) corresponds to each emformer layer.
(default: None)
conv_caches (List[torch.Tensor], optional):
Cached tensors of left context for causal convolution, where each
element (Tensor) corresponds to each convolutional layer.
Returns:
(Tensor, Tensor, List[List[torch.Tensor]]):
(Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
- output utterance frames, with shape (U, B, D).
- output lengths, with shape (B,), without containing the
right_context at the end.
- updated states from current chunk's computation.
- updated convolution caches from current chunk.
"""
assert x.size(0) == self.chunk_length + self.right_context_length, (
"Per configured chunk_length and right_context_length, "
@ -1199,17 +1229,26 @@ class EmformerEncoder(nn.Module):
)
output = utterance
output_states: List[List[torch.Tensor]] = []
output_conv_caches: List[torch.Tensor] = []
for layer_idx, layer in enumerate(self.emformer_layers):
output, right_context, memory, output_state = layer.infer(
(
output,
right_context,
memory,
output_state,
output_conv_cache,
) = layer.infer(
output,
output_lengths,
right_context,
memory,
None if states is None else states[layer_idx],
None if conv_caches is None else conv_caches[layer_idx],
)
output_states.append(output_state)
output_conv_caches.append(output_conv_cache)
return output, output_lengths, output_states
return output, output_lengths, output_states, output_conv_caches
class Emformer(EncoderInterface):
@ -1328,6 +1367,7 @@ class Emformer(EncoderInterface):
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
conv_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
"""Forward pass for streaming inference.
@ -1345,8 +1385,11 @@ class Emformer(EncoderInterface):
right_context at the end.
states (List[List[torch.Tensor]], optional):
Cached states from proceeding chunk's computation, where each
element (List[torch.Tensor]) corresponding to each emformer layer.
element (List[torch.Tensor]) corresponds to each emformer layer.
(default: None)
conv_caches (List[torch.Tensor], optional):
Cached tensors of left context for causal convolution, where each
element (Tensor) corresponds to each convolutional layer.
Returns:
(Tensor, Tensor):
- output logits, with shape (B, T', D), where
@ -1354,6 +1397,7 @@ class Emformer(EncoderInterface):
- logits lengths, with shape (B,), without containing the
right_context at the end.
- updated states from current chunk's computation.
- updated convolution caches from current chunk.
"""
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -1364,14 +1408,17 @@ class Emformer(EncoderInterface):
x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item()
output, output_lengths, output_states = self.encoder.infer(
x, x_lens, states
) # (T, N, C)
(
output,
output_lengths,
output_states,
output_conv_caches,
) = self.encoder.infer(x, x_lens, states, conv_caches)
logits = self.encoder_output_layer(output)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, output_lengths, output_states
return logits, output_lengths, output_states, output_conv_caches
class ConvolutionModule(nn.Module):
@ -1437,28 +1484,50 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
x (torch.Tensor):
Input tensor (#time, batch, channels).
cache (torch.Tensor, optional):
Cached tensor for left padding (#batch, channels, cache_time).
Returns:
Tensor: Output tensor (#time, batch, channels).
A tuple of 2 tensors:
- output tensor (#time, batch, channels).
- updated cache tensor (#batch, channels, cache_time).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.left_padding > 0:
# manualy padding self.lorder zeros to the left
# make depthwise_conv causal
x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0)
if cache is None:
x = nn.functional.pad(
x, (self.left_padding, 0), "constant", 0.0
)
else:
assert cache.size(0) == x.size(0) # equal batch
assert cache.size(1) == x.size(1) # equal channel
assert cache.size(2) == self.left_padding
x = torch.cat([cache, x], dim=2)
new_cache = x[:, :, x.size(2) - self.left_padding :] # noqa
else:
# It's better we just return None if no cache is requried,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = None
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
@ -1469,7 +1538,7 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
return x.permute(2, 0, 1), new_cache
class Swish(torch.nn.Module):

View File

@ -133,6 +133,7 @@ def test_emformer_layer_infer():
R, L = 2, 5
chunk_length = 2
U = chunk_length
K = 3
for use_memory in [True, False]:
if use_memory:
@ -145,7 +146,7 @@ def test_emformer_layer_infer():
nhead=8,
dim_feedforward=1024,
chunk_length=chunk_length,
cnn_module_kernel=3,
cnn_module_kernel=K,
left_context_length=L,
max_memory_size=M,
causal=True,
@ -157,17 +158,15 @@ def test_emformer_layer_infer():
right_context = torch.randn(R, B, D)
memory = torch.randn(M, B, D)
state = None
conv_cache = None
(
output_utterance,
output_right_context,
output_memory,
output_state,
output_conv_cache,
) = layer.infer(
utterance,
lengths,
right_context,
memory,
state,
utterance, lengths, right_context, memory, state, conv_cache
)
assert output_utterance.shape == (U, B, D)
assert output_right_context.shape == (R, B, D)
@ -180,6 +179,7 @@ def test_emformer_layer_infer():
assert output_state[1].shape == (L, B, D)
assert output_state[2].shape == (L, B, D)
assert output_state[3].shape == (1, B)
assert output_conv_cache.shape == (B, D, K - 1)
def test_emformer_encoder_forward():
@ -226,6 +226,7 @@ def test_emformer_encoder_infer():
U = chunk_length
num_chunks = 3
num_encoder_layers = 2
K = 3
for use_memory in [True, False]:
if use_memory:
@ -238,7 +239,7 @@ def test_emformer_encoder_infer():
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=3,
cnn_module_kernel=K,
left_context_length=L,
right_context_length=R,
max_memory_size=M,
@ -246,11 +247,14 @@ def test_emformer_encoder_infer():
)
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
x = torch.randn(U + R, B, D)
lengths = torch.randint(1, U + R + 1, (B,))
lengths[0] = U + R
output, output_lengths, states = encoder.infer(x, lengths, states)
output, output_lengths, states, conv_caches = encoder.infer(
x, lengths, states, conv_caches
)
assert output.shape == (U, B, D)
assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
assert len(states) == num_encoder_layers
@ -262,6 +266,8 @@ def test_emformer_encoder_infer():
assert torch.equal(
state[3], (chunk_idx + 1) * U * torch.ones_like(state[3])
)
for conv_cache in conv_caches:
assert conv_cache.shape == (B, D, K - 1)
def test_emformer_forward():
@ -312,6 +318,7 @@ def test_emformer_infer():
B, D = 2, 256
num_chunks = 3
num_encoder_layers = 2
K = 3
for use_memory in [True, False]:
if use_memory:
M = 3
@ -324,7 +331,7 @@ def test_emformer_infer():
subsampling_factor=4,
d_model=D,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=3,
cnn_module_kernel=K,
left_context_length=L,
right_context_length=R,
max_memory_size=M,
@ -332,11 +339,14 @@ def test_emformer_infer():
causal=True,
)
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
x = torch.randn(B, U + R + 3, num_features)
x_lens = torch.randint(1, U + R + 3 + 1, (B,))
x_lens[0] = U + R + 3
logits, output_lengths, states = model.infer(x, x_lens, states)
logits, output_lengths, states, conv_caches = model.infer(
x, x_lens, states, conv_caches
)
assert logits.shape == (B, U // 4, output_dim)
assert torch.equal(
output_lengths,
@ -352,6 +362,8 @@ def test_emformer_infer():
state[3],
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
)
for conv_cache in conv_caches:
assert conv_cache.shape == (B, D, K - 1)
if __name__ == "__main__":

View File

@ -139,7 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--causal-conv",
type=bool,
type=str2bool,
default=True,
help="Whether use causal convolution.",
)