Support cache of left context for causal convolution.

https://github.com/k2-fsa/icefall.git
parent 651745b220
commit c2808f8541
@@ -601,24 +601,8 @@ class EmformerLayer(nn.Module):
         )
         return right_context_utterance

-    def _apply_conv_module(
-        self,
-        right_context_utterance: torch.Tensor,
-        right_context_end_idx: int,
-    ) -> torch.Tensor:
-        """Apply convolution module on utterance."""
-        utterance = right_context_utterance[right_context_end_idx:]
-        right_context = right_context_utterance[:right_context_end_idx]
-
-        residual = utterance
-        utterance = self.norm_conv(utterance)
-        utterance = residual + self.dropout(self.conv_module(utterance))
-        right_context_utterance = torch.cat([right_context, utterance])
-        return right_context_utterance
-
     def _apply_feed_forward_module(
-        self,
-        right_context_utterance: torch.Tensor,
+        self, right_context_utterance: torch.Tensor
     ) -> torch.Tensor:
         """Apply feed forward module."""
         residual = right_context_utterance
@@ -628,6 +612,39 @@ class EmformerLayer(nn.Module):
         )
         return right_context_utterance

+    def _apply_conv_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        right_context_end_idx: int,
+    ) -> torch.Tensor:
+        """Apply convolution module on utterance in non-infer mode."""
+        utterance = right_context_utterance[right_context_end_idx:]
+        right_context = right_context_utterance[:right_context_end_idx]
+
+        residual = utterance
+        utterance = self.norm_conv(utterance)
+        utterance, _ = self.conv_module(utterance)
+        utterance = residual + self.dropout(utterance)
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance
+
+    def _apply_conv_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        right_context_end_idx: int,
+        conv_cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply convolution module on utterance in infer mode."""
+        utterance = right_context_utterance[right_context_end_idx:]
+        right_context = right_context_utterance[:right_context_end_idx]
+
+        residual = utterance
+        utterance = self.norm_conv(utterance)
+        utterance, conv_cache = self.conv_module(utterance, conv_cache)
+        utterance = residual + self.dropout(utterance)
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance, conv_cache
+
     def _apply_attention_module_forward(
         self,
         right_context_utterance: torch.Tensor,
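The forward/infer split above mirrors the two regimes of a causal convolution: training sees the whole utterance and can zero-pad on the left, while streaming must remember the tail of the previous chunk. A cache of kernel_size - 1 frames suffices because a causal convolution with kernel K reads only the current frame plus the K - 1 frames before it. A minimal check of that equivalence (toy tensors; K and the chunk split are illustrative, not from this commit):

import torch
import torch.nn.functional as F

K = 3                      # kernel size: a causal conv needs K - 1 left frames
w = torch.randn(1, 1, K)   # one in/out channel
x = torch.randn(1, 1, 8)   # full sequence of 8 frames

# Full-sequence causal conv: zero-pad K - 1 frames on the left.
full = F.conv1d(F.pad(x, (K - 1, 0)), w)

# Stream in two chunks of 4 frames, carrying the last K - 1 inputs as cache.
c1, c2 = x[:, :, :4], x[:, :, 4:]
y1 = F.conv1d(F.pad(c1, (K - 1, 0)), w)   # first chunk: zero cache
cache = c1[:, :, -(K - 1):]               # last K - 1 input frames
y2 = F.conv1d(torch.cat([cache, c2], dim=2), w)

assert torch.allclose(torch.cat([y1, y2], dim=2), full, atol=1e-6)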
@@ -790,7 +807,7 @@ class EmformerLayer(nn.Module):
             attention_mask,
         )

-        right_context_utterance = self._apply_conv_module(
+        right_context_utterance = self._apply_conv_module_forward(
             right_context_utterance, right_context_end_idx
         )

@@ -812,6 +829,7 @@ class EmformerLayer(nn.Module):
         right_context: torch.Tensor,
         memory: torch.Tensor,
         state: Optional[List[torch.Tensor]] = None,
+        conv_cache: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
         """Forward pass for inference.

@@ -841,6 +859,8 @@ class EmformerLayer(nn.Module):
           state (List[torch.Tensor], optional):
             List of tensors representing layer internal state generated in
             preceding computation. (default=None)
+          conv_cache (torch.Tensor, optional):
+            Cache tensor of left context for causal convolution.

        Returns:
          (Tensor, Tensor, List[torch.Tensor], Tensor):
@@ -848,6 +868,7 @@ class EmformerLayer(nn.Module):
            - output right_context, with shape (R, B, D);
            - output memory, with shape (1, B, D) or (0, B, D).
            - output state.
+           - updated conv_cache.
        """
        right_context_utterance = torch.cat([right_context, utterance])
        right_context_end_idx = right_context.size(0)
@@ -868,8 +889,10 @@ class EmformerLayer(nn.Module):
             state,
         )

-        right_context_utterance = self._apply_conv_module(
-            right_context_utterance, right_context_end_idx
+        right_context_utterance, conv_cache = self._apply_conv_module_infer(
+            right_context_utterance,
+            right_context_end_idx,
+            conv_cache,
         )

         right_context_utterance = self._apply_feed_forward_module(
@@ -885,6 +908,7 @@ class EmformerLayer(nn.Module):
             output_right_context,
             output_memory,
             output_state,
+            conv_cache,
         )


@@ -1156,7 +1180,10 @@ class EmformerEncoder(nn.Module):
         x: torch.Tensor,
         lengths: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        conv_caches: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
+    ]:
         """Forward pass for streaming inference.

         B: batch size;
@@ -1173,15 +1200,18 @@ class EmformerEncoder(nn.Module):
             right_context at the end.
           states (List[List[torch.Tensor]], optional):
             Cached states from preceding chunk's computation, where each
-            element (List[torch.Tensor]) corresponding to each emformer layer.
+            element (List[torch.Tensor]) corresponds to each emformer layer.
             (default: None)
+          conv_caches (List[torch.Tensor], optional):
+            Cached tensors of left context for causal convolution, where each
+            element (Tensor) corresponds to each convolutional layer.
        Returns:
-          (Tensor, Tensor, List[List[torch.Tensor]]):
+          (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
            - output utterance frames, with shape (U, B, D).
            - output lengths, with shape (B,), without containing the
              right_context at the end.
            - updated states from current chunk's computation.
+           - updated convolution caches from current chunk.
        """
        assert x.size(0) == self.chunk_length + self.right_context_length, (
            "Per configured chunk_length and right_context_length, "
@@ -1199,17 +1229,26 @@ class EmformerEncoder(nn.Module):
         )
         output = utterance
         output_states: List[List[torch.Tensor]] = []
+        output_conv_caches: List[torch.Tensor] = []
         for layer_idx, layer in enumerate(self.emformer_layers):
-            output, right_context, memory, output_state = layer.infer(
+            (
+                output,
+                right_context,
+                memory,
+                output_state,
+                output_conv_cache,
+            ) = layer.infer(
                 output,
                 output_lengths,
                 right_context,
                 memory,
                 None if states is None else states[layer_idx],
+                None if conv_caches is None else conv_caches[layer_idx],
             )
             output_states.append(output_state)
+            output_conv_caches.append(output_conv_cache)

-        return output, output_lengths, output_states
+        return output, output_lengths, output_states, output_conv_caches


 class Emformer(EncoderInterface):
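The loop above keeps one cache per layer and rebuilds the list on every chunk; when no caches are passed in, each layer falls back to zero left-padding for the first chunk. The caller-side shape of that contract, reduced to plain depthwise convolutions standing in for Emformer layers (all names here are illustrative, not from this file):

import torch
import torch.nn as nn
from typing import List, Optional

K, C = 3, 4  # kernel size and channel count of the toy layers
layers = nn.ModuleList(
    [nn.Conv1d(C, C, K, groups=C, bias=False) for _ in range(2)]
)

def infer_chunk(
    x: torch.Tensor,  # (B, C, T)
    caches: Optional[List[torch.Tensor]],  # one (B, C, K - 1) tensor per layer
):
    new_caches: List[torch.Tensor] = []
    for i, layer in enumerate(layers):
        if caches is None:
            # First chunk: zero left context, like cache=None above.
            cache = torch.zeros(x.size(0), C, K - 1)
        else:
            cache = caches[i]
        x_padded = torch.cat([cache, x], dim=2)
        new_caches.append(x_padded[:, :, -(K - 1):])  # keep last K - 1 inputs
        x = layer(x_padded)  # causal: output length equals the chunk length
    return x, new_caches

caches = None
for _ in range(3):  # stream three chunks
    out, caches = infer_chunk(torch.randn(2, C, 5), caches)
    assert out.shape == (2, C, 5)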
@@ -1328,6 +1367,7 @@ class Emformer(EncoderInterface):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
+        conv_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
         """Forward pass for streaming inference.

@@ -1345,8 +1385,11 @@ class Emformer(EncoderInterface):
             right_context at the end.
           states (List[List[torch.Tensor]], optional):
             Cached states from preceding chunk's computation, where each
-            element (List[torch.Tensor]) corresponding to each emformer layer.
+            element (List[torch.Tensor]) corresponds to each emformer layer.
             (default: None)
+          conv_caches (List[torch.Tensor], optional):
+            Cached tensors of left context for causal convolution, where each
+            element (Tensor) corresponds to each convolutional layer.
        Returns:
          (Tensor, Tensor):
            - output logits, with shape (B, T', D), where
@@ -1354,6 +1397,7 @@ class Emformer(EncoderInterface):
            - logits lengths, with shape (B,), without containing the
              right_context at the end.
            - updated states from current chunk's computation.
+           - updated convolution caches from current chunk.
        """
        x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -1364,14 +1408,17 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()

-        output, output_lengths, output_states = self.encoder.infer(
-            x, x_lens, states
-        )  # (T, N, C)
+        (
+            output,
+            output_lengths,
+            output_states,
+            output_conv_caches,
+        ) = self.encoder.infer(x, x_lens, states, conv_caches)

         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

-        return logits, output_lengths, output_states
+        return logits, output_lengths, output_states, output_conv_caches


 class ConvolutionModule(nn.Module):
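The length update above, x_lens = ((x_lens - 1) // 2 - 1) // 2, reflects the encoder_embed frontend halving the time axis twice; with two stride-2, kernel-3 convolutions and no padding (the geometry this formula implies), each stage maps T to (T - 1) // 2. A quick check of that arithmetic with a 1-D stand-in for the frontend (illustrative, not icefall's embedding code):

import torch
import torch.nn as nn

# Each conv maps T -> (T - 3) // 2 + 1 == (T - 1) // 2,
# so two of them give ((T - 1) // 2 - 1) // 2, matching the x_lens update.
frontend = nn.Sequential(
    nn.Conv1d(1, 1, kernel_size=3, stride=2),
    nn.Conv1d(1, 1, kernel_size=3, stride=2),
)
for T in range(7, 40):
    out = frontend(torch.randn(1, 1, T))
    assert out.size(2) == ((T - 1) // 2 - 1) // 2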
@@ -1437,28 +1484,50 @@ class ConvolutionModule(nn.Module):
         )
         self.activation = Swish()

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Compute convolution module.

         Args:
-            x: Input tensor (#time, batch, channels).
+          x (torch.Tensor):
+            Input tensor (#time, batch, channels).
+          cache (torch.Tensor, optional):
+            Cached tensor for left padding (#batch, channels, cache_time).
        Returns:
-            Tensor: Output tensor (#time, batch, channels).
+          A tuple of 2 tensors:
+          - output tensor (#time, batch, channels).
+          - updated cache tensor (#batch, channels, cache_time).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

-        # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
-        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
-
        # 1D Depthwise Conv
        if self.left_padding > 0:
            # manually pad self.left_padding zeros to the left
            # to make depthwise_conv causal
-            x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0)
+            if cache is None:
+                x = nn.functional.pad(
+                    x, (self.left_padding, 0), "constant", 0.0
+                )
+            else:
+                assert cache.size(0) == x.size(0)  # equal batch
+                assert cache.size(1) == x.size(1)  # equal channel
+                assert cache.size(2) == self.left_padding
+                x = torch.cat([cache, x], dim=2)
+            new_cache = x[:, :, x.size(2) - self.left_padding :]  # noqa
+        else:
+            # It would be better to just return None when no cache is
+            # required; however, for JIT export we would have to fake a
+            # tensor instead of None.
+            new_cache = None
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
        x = self.depthwise_conv(x)
        # x is (batch, channels, time)
        x = x.permute(0, 2, 1)
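What the cache buys, end to end: streaming a sequence chunk by chunk through a causal depthwise convolution with this cache reproduces the full-sequence output exactly. A self-contained check under the assumption left_padding = kernel_size - 1 (toy module, not the class above):

import torch
import torch.nn as nn

torch.manual_seed(0)
K, B, C, T = 5, 2, 8, 12
left_padding = K - 1
conv = nn.Conv1d(C, C, K, groups=C, bias=False)

x = torch.randn(B, C, T)

# Full-sequence causal output: zero-pad K - 1 frames on the left.
full = conv(nn.functional.pad(x, (left_padding, 0)))

# Chunked output using the same cache logic as the diff above.
outs, cache = [], None
for chunk in x.split(4, dim=2):  # three chunks of 4 frames
    if cache is None:
        padded = nn.functional.pad(chunk, (left_padding, 0))
    else:
        padded = torch.cat([cache, chunk], dim=2)
    cache = padded[:, :, padded.size(2) - left_padding:]
    outs.append(conv(padded))

assert torch.allclose(torch.cat(outs, dim=2), full, atol=1e-6)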
@@ -1469,7 +1538,7 @@ class ConvolutionModule(nn.Module):

         x = self.pointwise_conv2(x)  # (batch, channel, time)

-        return x.permute(2, 0, 1)
+        return x.permute(2, 0, 1), new_cache


 class Swish(torch.nn.Module):
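On the JIT remark in the comment above: TorchScript does support Optional[Tensor], but every use must sit behind an explicit None check, which is why exported streaming code often prefers a zero-sized placeholder over None. A hedged sketch of this padding step in scriptable form (standalone helper, not the module's code):

import torch
import torch.nn.functional as F
from typing import Optional, Tuple

@torch.jit.script
def causal_pad(
    x: torch.Tensor,                # (B, C, T)
    cache: Optional[torch.Tensor],  # (B, C, left_padding) or None
    left_padding: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Prepend either zeros (first chunk) or the cached tail of the last chunk.
    if cache is None:
        x = F.pad(x, [left_padding, 0], "constant", 0.0)
    else:
        x = torch.cat([cache, x], dim=2)
    new_cache = x[:, :, x.size(2) - left_padding:]
    return x, new_cache

padded, cache = causal_pad(torch.randn(2, 4, 6), None, 2)
assert padded.shape == (2, 4, 8) and cache.shape == (2, 4, 2)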
@@ -133,6 +133,7 @@ def test_emformer_layer_infer():
     R, L = 2, 5
     chunk_length = 2
     U = chunk_length
+    K = 3

     for use_memory in [True, False]:
         if use_memory:
@@ -145,7 +146,7 @@ def test_emformer_layer_infer():
             nhead=8,
             dim_feedforward=1024,
             chunk_length=chunk_length,
-            cnn_module_kernel=3,
+            cnn_module_kernel=K,
             left_context_length=L,
             max_memory_size=M,
             causal=True,
@@ -157,17 +158,15 @@ def test_emformer_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
+        conv_cache = None
         (
             output_utterance,
             output_right_context,
             output_memory,
             output_state,
+            output_conv_cache,
         ) = layer.infer(
-            utterance,
-            lengths,
-            right_context,
-            memory,
-            state,
+            utterance, lengths, right_context, memory, state, conv_cache
         )
         assert output_utterance.shape == (U, B, D)
         assert output_right_context.shape == (R, B, D)
@@ -180,6 +179,7 @@ def test_emformer_layer_infer():
         assert output_state[1].shape == (L, B, D)
         assert output_state[2].shape == (L, B, D)
         assert output_state[3].shape == (1, B)
+        assert output_conv_cache.shape == (B, D, K - 1)


 def test_emformer_encoder_forward():
@@ -226,6 +226,7 @@ def test_emformer_encoder_infer():
     U = chunk_length
     num_chunks = 3
     num_encoder_layers = 2
+    K = 3

     for use_memory in [True, False]:
         if use_memory:
@@ -238,7 +239,7 @@ def test_emformer_encoder_infer():
             d_model=D,
             dim_feedforward=1024,
             num_encoder_layers=num_encoder_layers,
-            cnn_module_kernel=3,
+            cnn_module_kernel=K,
             left_context_length=L,
             right_context_length=R,
             max_memory_size=M,
@@ -246,11 +247,14 @@ def test_emformer_encoder_infer():
         )

         states = None
+        conv_caches = None
         for chunk_idx in range(num_chunks):
             x = torch.randn(U + R, B, D)
             lengths = torch.randint(1, U + R + 1, (B,))
             lengths[0] = U + R
-            output, output_lengths, states = encoder.infer(x, lengths, states)
+            output, output_lengths, states, conv_caches = encoder.infer(
+                x, lengths, states, conv_caches
+            )
             assert output.shape == (U, B, D)
             assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
             assert len(states) == num_encoder_layers
@@ -262,6 +266,8 @@ def test_emformer_encoder_infer():
             assert torch.equal(
                 state[3], (chunk_idx + 1) * U * torch.ones_like(state[3])
             )
+            for conv_cache in conv_caches:
+                assert conv_cache.shape == (B, D, K - 1)


 def test_emformer_forward():
@@ -312,6 +318,7 @@ def test_emformer_infer():
     B, D = 2, 256
     num_chunks = 3
     num_encoder_layers = 2
+    K = 3
     for use_memory in [True, False]:
         if use_memory:
             M = 3
@@ -324,7 +331,7 @@ def test_emformer_infer():
             subsampling_factor=4,
             d_model=D,
             num_encoder_layers=num_encoder_layers,
-            cnn_module_kernel=3,
+            cnn_module_kernel=K,
             left_context_length=L,
             right_context_length=R,
             max_memory_size=M,
@@ -332,11 +339,14 @@ def test_emformer_infer():
             causal=True,
         )
         states = None
+        conv_caches = None
         for chunk_idx in range(num_chunks):
             x = torch.randn(B, U + R + 3, num_features)
             x_lens = torch.randint(1, U + R + 3 + 1, (B,))
             x_lens[0] = U + R + 3
-            logits, output_lengths, states = model.infer(x, x_lens, states)
+            logits, output_lengths, states, conv_caches = model.infer(
+                x, x_lens, states, conv_caches
+            )
             assert logits.shape == (B, U // 4, output_dim)
             assert torch.equal(
                 output_lengths,
@@ -352,6 +362,8 @@ def test_emformer_infer():
                 state[3],
                 U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
             )
+            for conv_cache in conv_caches:
+                assert conv_cache.shape == (B, D, K - 1)


 if __name__ == "__main__":
@@ -139,7 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):

     parser.add_argument(
         "--causal-conv",
-        type=bool,
+        type=str2bool,
         default=True,
         help="Whether to use causal convolution.",
     )
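The type=bool -> str2bool change is a real fix: argparse applies type to the raw command-line string, and bool("False") is True because any non-empty string is truthy, so --causal-conv False would still parse as True. icefall ships a str2bool helper for this; a minimal equivalent looks like the following sketch (not icefall's exact implementation):

import argparse

def str2bool(v: str) -> bool:
    # Map common spellings to a real boolean; argparse passes the raw string.
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--causal-conv", type=str2bool, default=True)
assert parser.parse_args(["--causal-conv", "False"]).causal_conv is False
assert bool("False") is True  # the bug that type=bool would have had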