mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add EmformerEncoderLayer module
This commit is contained in:
parent
3360dc5afc
commit
b265a5c875
@ -419,7 +419,7 @@ class ConvolutionModule(nn.Module):
|
||||
assert cache.shape == (B, D, self.cache_size), cache.shape
|
||||
x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R)
|
||||
# update cache
|
||||
new_cache = x[:, :, -R - self.cache_size:-R]
|
||||
new_cache = x[:, :, -R - self.cache_size : -R]
|
||||
|
||||
# 1-D depth-wise conv
|
||||
x = self.depthwise_conv(x) # (B, D, U + R)
|
||||
@ -572,7 +572,7 @@ class EmformerAttention(nn.Module):
|
||||
Args:
|
||||
x: Input tensor, of shape (B, nhead, U, PE).
|
||||
U is the length of query vector.
|
||||
For training mode, PE = 2 * U - 1;
|
||||
For training and validation mode, PE = 2 * U - 1;
|
||||
for inference mode, PE = L + 2 * U - 1.
|
||||
|
||||
Returns:
|
||||
@ -666,7 +666,7 @@ class EmformerAttention(nn.Module):
|
||||
L = left_context_key.size(0)
|
||||
assert PE == L + 2 * U - 1
|
||||
else:
|
||||
# training mode
|
||||
# training and validation mode
|
||||
assert PE == 2 * U - 1
|
||||
pos_emb = (
|
||||
self.linear_pos(pos_emb)
|
||||
@ -679,7 +679,7 @@ class EmformerAttention(nn.Module):
|
||||
) # (B, nhead, U, PE)
|
||||
# rel-shift operation
|
||||
matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
|
||||
# (B, nhead, U, U) for training mode;
|
||||
# (B, nhead, U, U) for training and validation mode;
|
||||
# (B, nhead, U, L + U) for inference mode.
|
||||
matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
|
||||
B * self.nhead, U, -1
|
||||
@ -730,7 +730,7 @@ class EmformerAttention(nn.Module):
|
||||
pos_emb: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# TODO: Modify docs.
|
||||
"""Forward pass for training mode.
|
||||
"""Forward pass for training and validation mode.
|
||||
|
||||
B: batch size;
|
||||
D: embedding dimension;
|
||||
@ -922,3 +922,464 @@ class EmformerAttention(nn.Module):
|
||||
key[M + R :],
|
||||
value[M + R :],
|
||||
)
|
||||
|
||||
|
||||
class EmformerEncoderLayer(nn.Module):
|
||||
"""Emformer layer that constitutes Emformer.
|
||||
|
||||
Args:
|
||||
d_model (int):
|
||||
Input dimension.
|
||||
nhead (int):
|
||||
Number of attention heads.
|
||||
dim_feedforward (int):
|
||||
Hidden layer dimension of feedforward network.
|
||||
chunk_length (int):
|
||||
Length of each input segment.
|
||||
dropout (float, optional):
|
||||
Dropout probability. (Default: 0.0)
|
||||
layer_dropout (float, optional):
|
||||
Layer dropout probability. (Default: 0.0)
|
||||
cnn_module_kernel (int):
|
||||
Kernel size of convolution module.
|
||||
left_context_length (int, optional):
|
||||
Length of left context. (Default: 0)
|
||||
right_context_length (int, optional):
|
||||
Length of right context. (Default: 0)
|
||||
max_memory_size (int, optional):
|
||||
Maximum number of memory elements to use. (Default: 0)
|
||||
tanh_on_mem (bool, optional):
|
||||
If ``True``, applies tanh to memory elements. (Default: ``False``)
|
||||
negative_inf (float, optional):
|
||||
Value to use for negative infinity in attention weights. (Default: -1e8)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int,
|
||||
chunk_length: int,
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
left_context_length: int = 0,
|
||||
right_context_length: int = 0,
|
||||
max_memory_size: int = 0,
|
||||
tanh_on_mem: bool = False,
|
||||
negative_inf: float = -1e8,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = EmformerAttention(
|
||||
embed_dim=d_model,
|
||||
nhead=nhead,
|
||||
dropout=dropout,
|
||||
tanh_on_mem=tanh_on_mem,
|
||||
negative_inf=negative_inf,
|
||||
)
|
||||
self.summary_op = nn.AvgPool1d(
|
||||
kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
|
||||
)
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
ScaledLinear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||
)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
ScaledLinear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(
|
||||
chunk_length,
|
||||
right_context_length,
|
||||
d_model,
|
||||
cnn_module_kernel,
|
||||
)
|
||||
|
||||
self.norm_final = BasicNorm(d_model)
|
||||
|
||||
# try to ensure the output is close to zero-mean
|
||||
# (or at least, zero-median).
|
||||
self.balancer = ActivationBalancer(
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.layer_dropout = layer_dropout
|
||||
self.left_context_length = left_context_length
|
||||
self.chunk_length = chunk_length
|
||||
self.max_memory_size = max_memory_size
|
||||
self.d_model = d_model
|
||||
self.use_memory = max_memory_size > 0
|
||||
|
||||
def _init_state(
|
||||
self, batch_size: int, device: Optional[torch.device]
|
||||
) -> List[torch.Tensor]:
|
||||
"""Initialize states with zeros."""
|
||||
empty_memory = torch.zeros(
|
||||
self.max_memory_size, batch_size, self.d_model, device=device
|
||||
)
|
||||
left_context_key = torch.zeros(
|
||||
self.left_context_length, batch_size, self.d_model, device=device
|
||||
)
|
||||
left_context_val = torch.zeros(
|
||||
self.left_context_length, batch_size, self.d_model, device=device
|
||||
)
|
||||
past_length = torch.zeros(
|
||||
1, batch_size, dtype=torch.int32, device=device
|
||||
)
|
||||
return [empty_memory, left_context_key, left_context_val, past_length]
|
||||
|
||||
def _unpack_state(
|
||||
self, state: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Unpack cached states including:
|
||||
1) output memory from previous chunks in the lower layer;
|
||||
2) attention key and value of left context from proceeding chunk's
|
||||
computation.
|
||||
"""
|
||||
past_length = state[3][0][0].item()
|
||||
past_left_context_length = min(self.left_context_length, past_length)
|
||||
past_memory_length = min(
|
||||
self.max_memory_size, math.ceil(past_length / self.chunk_length)
|
||||
)
|
||||
memory_start_idx = self.max_memory_size - past_memory_length
|
||||
pre_memory = state[0][memory_start_idx:]
|
||||
left_context_start_idx = (
|
||||
self.left_context_length - past_left_context_length
|
||||
)
|
||||
left_context_key = state[1][left_context_start_idx:]
|
||||
left_context_val = state[2][left_context_start_idx:]
|
||||
return pre_memory, left_context_key, left_context_val
|
||||
|
||||
def _pack_state(
|
||||
self,
|
||||
next_key: torch.Tensor,
|
||||
next_val: torch.Tensor,
|
||||
update_length: int,
|
||||
memory: torch.Tensor,
|
||||
state: List[torch.Tensor],
|
||||
) -> List[torch.Tensor]:
|
||||
"""Pack updated states including:
|
||||
1) output memory of current chunk in the lower layer;
|
||||
2) attention key and value in current chunk's computation, which would
|
||||
be resued in next chunk's computation.
|
||||
3) length of current chunk.
|
||||
"""
|
||||
new_memory = torch.cat([state[0], memory])
|
||||
new_key = torch.cat([state[1], next_key])
|
||||
new_val = torch.cat([state[2], next_val])
|
||||
memory_start_idx = new_memory.size(0) - self.max_memory_size
|
||||
state[0] = new_memory[memory_start_idx:]
|
||||
key_start_idx = new_key.size(0) - self.left_context_length
|
||||
state[1] = new_key[key_start_idx:]
|
||||
val_start_idx = new_val.size(0) - self.left_context_length
|
||||
state[2] = new_val[val_start_idx:]
|
||||
state[3] = state[3] + update_length
|
||||
return state
|
||||
|
||||
def _apply_conv_module_forward(
|
||||
self,
|
||||
right_context_utterance: torch.Tensor,
|
||||
R: int,
|
||||
) -> torch.Tensor:
|
||||
"""Apply convolution module in training and validation mode."""
|
||||
utterance = right_context_utterance[R:]
|
||||
right_context = right_context_utterance[:R]
|
||||
utterance, right_context, _ = self.conv_module(utterance, right_context)
|
||||
right_context_utterance = torch.cat([right_context, utterance])
|
||||
return right_context_utterance
|
||||
|
||||
def _apply_conv_module_infer(
|
||||
self,
|
||||
right_context_utterance: torch.Tensor,
|
||||
R: int,
|
||||
conv_cache: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Apply convolution module on utterance in inference mode."""
|
||||
utterance = right_context_utterance[R:]
|
||||
right_context = right_context_utterance[:R]
|
||||
utterance, right_context, conv_cache = self.conv_module.infer(
|
||||
utterance, right_context, conv_cache
|
||||
)
|
||||
right_context_utterance = torch.cat([right_context, utterance])
|
||||
return right_context_utterance, conv_cache
|
||||
|
||||
def _apply_attention_module_forward(
|
||||
self,
|
||||
right_context_utterance: torch.Tensor,
|
||||
R: int,
|
||||
lengths: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply attention module in training and validation mode."""
|
||||
if attention_mask is None:
|
||||
raise ValueError(
|
||||
"attention_mask must be not None in training or validation mode." # noqa
|
||||
)
|
||||
utterance = right_context_utterance[R:]
|
||||
right_context = right_context_utterance[:R]
|
||||
|
||||
if self.use_memory:
|
||||
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
|
||||
2, 0, 1
|
||||
)
|
||||
else:
|
||||
summary = torch.empty(0).to(
|
||||
dtype=utterance.dtype, device=utterance.device
|
||||
)
|
||||
output_right_context_utterance, output_memory = self.attention(
|
||||
utterance=utterance,
|
||||
lengths=lengths,
|
||||
right_context=right_context,
|
||||
summary=summary,
|
||||
memory=memory,
|
||||
attention_mask=attention_mask,
|
||||
pos_emb=pos_emb,
|
||||
)
|
||||
|
||||
return output_right_context_utterance, output_memory
|
||||
|
||||
def _apply_attention_module_infer(
|
||||
self,
|
||||
right_context_utterance: torch.Tensor,
|
||||
R: int,
|
||||
lengths: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
state: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
"""Apply attention module in inference mode.
|
||||
1) Unpack cached states including:
|
||||
- memory from previous chunks in the lower layer;
|
||||
- attention key and value of left context from proceeding
|
||||
chunk's compuation;
|
||||
2) Apply attention computation;
|
||||
3) Pack updated states including:
|
||||
- output memory of current chunk in the lower layer;
|
||||
- attention key and value in current chunk's computation, which would
|
||||
be resued in next chunk's computation.
|
||||
- length of current chunk.
|
||||
"""
|
||||
utterance = right_context_utterance[R:]
|
||||
right_context = right_context_utterance[:R]
|
||||
|
||||
if state is None:
|
||||
state = self._init_state(utterance.size(1), device=utterance.device)
|
||||
pre_memory, left_context_key, left_context_val = self._unpack_state(
|
||||
state
|
||||
)
|
||||
if self.use_memory:
|
||||
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
|
||||
2, 0, 1
|
||||
)
|
||||
summary = summary[:1]
|
||||
else:
|
||||
summary = torch.empty(0).to(
|
||||
dtype=utterance.dtype, device=utterance.device
|
||||
)
|
||||
# pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
|
||||
# for query of [utterance] (i), key-value [left_context, utterance] (j),
|
||||
# the max relative distance i - j is L + U - 1
|
||||
# the min relative distance i - j is -(U - 1)
|
||||
L = left_context_key.size(0) # L <= left_context_length
|
||||
U = utterance.size(0)
|
||||
PE = L + 2 * U - 1
|
||||
tot_PE = self.left_context_length + 2 * U - 1
|
||||
assert pos_emb.size(0) == tot_PE
|
||||
pos_emb = pos_emb[tot_PE - PE :]
|
||||
(
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
next_key,
|
||||
next_val,
|
||||
) = self.attention.infer(
|
||||
utterance=utterance,
|
||||
lengths=lengths,
|
||||
right_context=right_context,
|
||||
summary=summary,
|
||||
memory=pre_memory,
|
||||
left_context_key=left_context_key,
|
||||
left_context_val=left_context_val,
|
||||
pos_emb=pos_emb,
|
||||
)
|
||||
state = self._pack_state(
|
||||
next_key, next_val, utterance.size(0), memory, state
|
||||
)
|
||||
return output_right_context_utterance, output_memory, state
|
||||
|
||||
def forward(
|
||||
self,
|
||||
utterance: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
right_context: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
warmup: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
r"""Forward pass for training and validation mode.
|
||||
|
||||
B: batch size;
|
||||
D: embedding dimension;
|
||||
R: length of hard-copied right contexts;
|
||||
U: length of full utterance;
|
||||
M: length of memory vectors.
|
||||
|
||||
Args:
|
||||
utterance (torch.Tensor):
|
||||
Utterance frames, with shape (U, B, D).
|
||||
lengths (torch.Tensor):
|
||||
With shape (B,) and i-th element representing
|
||||
number of valid frames for i-th batch element in utterance.
|
||||
right_context (torch.Tensor):
|
||||
Right context frames, with shape (R, B, D).
|
||||
memory (torch.Tensor):
|
||||
Memory elements, with shape (M, B, D).
|
||||
It is an empty tensor without using memory.
|
||||
attention_mask (torch.Tensor):
|
||||
Attention mask for underlying attention module,
|
||||
with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
|
||||
pos_emb (torch.Tensor):
|
||||
Position encoding embedding, with shape (PE, D).
|
||||
For training mode, P = 2*U-1.
|
||||
|
||||
Returns:
|
||||
A tuple containing 3 tensors:
|
||||
- output utterance, with shape (U, B, D).
|
||||
- output right context, with shape (R, B, D).
|
||||
- output memory, with shape (M, B, D).
|
||||
"""
|
||||
R = right_context.size(0)
|
||||
src = torch.cat([right_context, utterance])
|
||||
src_orig = src
|
||||
|
||||
warmup_scale = min(0.1 + warmup, 1.0)
|
||||
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||
# completely bypass it.
|
||||
if self.training:
|
||||
alpha = (
|
||||
warmup_scale
|
||||
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
|
||||
else 0.1
|
||||
)
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# emformer attention module
|
||||
src_att, output_memory = self._apply_attention_module_forward(
|
||||
src, R, lengths, memory, pos_emb, attention_mask
|
||||
)
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
src_conv = self._apply_conv_module_forward(src, R)
|
||||
src = src + self.dropout(src_conv)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
output_utterance = src[R:]
|
||||
output_right_context = src[:R]
|
||||
return output_utterance, output_right_context, output_memory
|
||||
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
utterance: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
right_context: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
state: Optional[List[torch.Tensor]] = None,
|
||||
conv_cache: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
|
||||
"""Forward pass for inference.
|
||||
|
||||
B: batch size;
|
||||
D: embedding dimension;
|
||||
R: length of right_context;
|
||||
U: length of utterance;
|
||||
M: length of memory.
|
||||
|
||||
Args:
|
||||
utterance (torch.Tensor):
|
||||
Utterance frames, with shape (U, B, D).
|
||||
lengths (torch.Tensor):
|
||||
With shape (B,) and i-th element representing
|
||||
number of valid frames for i-th batch element in utterance.
|
||||
right_context (torch.Tensor):
|
||||
Right context frames, with shape (R, B, D).
|
||||
memory (torch.Tensor):
|
||||
Memory elements, with shape (M, B, D).
|
||||
state (List[torch.Tensor], optional):
|
||||
List of tensors representing layer internal state generated in
|
||||
preceding computation. (default=None)
|
||||
pos_emb (torch.Tensor):
|
||||
Position encoding embedding, with shape (PE, D).
|
||||
For infer mode, PE = L+2*U-1.
|
||||
conv_cache (torch.Tensor, optional):
|
||||
Cache tensor of left context for causal convolution.
|
||||
|
||||
Returns:
|
||||
(Tensor, Tensor, List[torch.Tensor], Tensor):
|
||||
- output utterance, with shape (U, B, D);
|
||||
- output right_context, with shape (R, B, D);
|
||||
- output memory, with shape (1, B, D) or (0, B, D).
|
||||
- output state.
|
||||
- updated conv_cache.
|
||||
"""
|
||||
R = right_context.size(0)
|
||||
src = torch.cat([right_context, utterance])
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# emformer attention module
|
||||
(
|
||||
src_att,
|
||||
output_memory,
|
||||
output_state,
|
||||
) = self._apply_attention_module_infer(
|
||||
src, R, lengths, memory, pos_emb, state
|
||||
)
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
src_conv, conv_cache = self._apply_conv_module_infer(src, R, conv_cache)
|
||||
src = src + self.dropout(src_conv)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
output_utterance = src[R:]
|
||||
output_right_context = src[:R]
|
||||
return (
|
||||
output_utterance,
|
||||
output_right_context,
|
||||
output_memory,
|
||||
output_state,
|
||||
conv_cache,
|
||||
)
|
||||
|
@ -154,9 +154,131 @@ def test_convolution_module_infer():
|
||||
assert new_cache.shape == (B, D, kernel_size - 1)
|
||||
|
||||
|
||||
def test_emformer_encoder_layer_forward():
|
||||
from emformer import EmformerEncoderLayer
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 8
|
||||
right_context_length = 2
|
||||
left_context_length = 8
|
||||
kernel_size = 31
|
||||
num_chunks = 3
|
||||
U = num_chunks * chunk_length
|
||||
R = num_chunks * right_context_length
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
S = num_chunks
|
||||
M = S - 1
|
||||
else:
|
||||
S, M = 0, 0
|
||||
|
||||
layer = EmformerEncoderLayer(
|
||||
d_model=D,
|
||||
nhead=8,
|
||||
dim_feedforward=1024,
|
||||
chunk_length=chunk_length,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=M,
|
||||
)
|
||||
|
||||
Q, KV = R + U + S, M + R + U
|
||||
utterance = torch.randn(U, B, D)
|
||||
lengths = torch.randint(1, U + 1, (B,))
|
||||
lengths[0] = U
|
||||
right_context = torch.randn(R, B, D)
|
||||
memory = torch.randn(M, B, D)
|
||||
attention_mask = torch.rand(Q, KV) >= 0.5
|
||||
PE = 2 * U - 1
|
||||
pos_emb = torch.randn(PE, D)
|
||||
|
||||
output_utterance, output_right_context, output_memory = layer(
|
||||
utterance,
|
||||
lengths,
|
||||
right_context,
|
||||
memory,
|
||||
attention_mask,
|
||||
pos_emb,
|
||||
)
|
||||
assert output_utterance.shape == (U, B, D)
|
||||
assert output_right_context.shape == (R, B, D)
|
||||
assert output_memory.shape == (M, B, D)
|
||||
|
||||
|
||||
def test_emformer_encoder_layer_infer():
|
||||
from emformer import EmformerEncoderLayer
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 8
|
||||
right_context_length = 2
|
||||
left_context_length = 8
|
||||
kernel_size = 31
|
||||
num_chunks = 1
|
||||
U = num_chunks * chunk_length
|
||||
R = num_chunks * right_context_length
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
M = 3
|
||||
else:
|
||||
M = 0
|
||||
|
||||
layer = EmformerEncoderLayer(
|
||||
d_model=D,
|
||||
nhead=8,
|
||||
dim_feedforward=1024,
|
||||
chunk_length=chunk_length,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=M,
|
||||
)
|
||||
|
||||
utterance = torch.randn(U, B, D)
|
||||
lengths = torch.randint(1, U + 1, (B,))
|
||||
lengths[0] = U
|
||||
right_context = torch.randn(R, B, D)
|
||||
memory = torch.randn(M, B, D)
|
||||
state = None
|
||||
PE = left_context_length + 2 * U - 1
|
||||
pos_emb = torch.randn(PE, D)
|
||||
conv_cache = None
|
||||
(
|
||||
output_utterance,
|
||||
output_right_context,
|
||||
output_memory,
|
||||
output_state,
|
||||
conv_cache,
|
||||
) = layer.infer(
|
||||
utterance,
|
||||
lengths,
|
||||
right_context,
|
||||
memory,
|
||||
pos_emb,
|
||||
state,
|
||||
conv_cache,
|
||||
)
|
||||
assert output_utterance.shape == (U, B, D)
|
||||
assert output_right_context.shape == (R, B, D)
|
||||
if use_memory:
|
||||
assert output_memory.shape == (1, B, D)
|
||||
else:
|
||||
assert output_memory.shape == (0, B, D)
|
||||
assert len(output_state) == 4
|
||||
assert output_state[0].shape == (M, B, D)
|
||||
assert output_state[1].shape == (left_context_length, B, D)
|
||||
assert output_state[2].shape == (left_context_length, B, D)
|
||||
assert output_state[3].shape == (1, B)
|
||||
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rel_positional_encoding()
|
||||
test_emformer_attention_forward()
|
||||
test_emformer_attention_infer()
|
||||
test_convolution_module_forward()
|
||||
test_convolution_module_infer()
|
||||
test_emformer_encoder_layer_forward()
|
||||
test_emformer_encoder_layer_infer()
|
||||
|
Loading…
x
Reference in New Issue
Block a user