mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

add EmformerEncoderLayer module

This commit is contained in: parent 3838b84313 · commit 943cb9d5a3
@@ -419,7 +419,7 @@ class ConvolutionModule(nn.Module):
         assert cache.shape == (B, D, self.cache_size), cache.shape
         x = torch.cat([cache, x], dim=2)  # (B, D, cache_size + U + R)
         # update cache
-        new_cache = x[:, :, -R - self.cache_size:-R]
+        new_cache = x[:, :, -R - self.cache_size : -R]

         # 1-D depth-wise conv
         x = self.depthwise_conv(x)  # (B, D, U + R)
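Note (not part of the diff): the cache slice keeps the cache_size frames that immediately precede the R right-context frames, so the next chunk sees the correct left history. A minimal sketch with made-up sizes, checking the arithmetic:

import torch

B, D, cache_size, U, R = 1, 2, 4, 8, 3
x = torch.arange(cache_size + U + R).repeat(B, D, 1)  # (B, D, cache_size + U + R)
new_cache = x[:, :, -R - cache_size : -R]
assert new_cache.shape == (B, D, cache_size)
# the kept frames are the last cache_size frames of [cache, utterance]
assert new_cache[0, 0].tolist() == list(range(U, U + cache_size))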
@@ -572,7 +572,7 @@ class EmformerAttention(nn.Module):
         Args:
           x: Input tensor, of shape (B, nhead, U, PE).
             U is the length of query vector.
-            For training mode, PE = 2 * U - 1;
+            For training and validation mode, PE = 2 * U - 1;
             for inference mode, PE = L + 2 * U - 1.

         Returns:
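Note (not part of the diff): the PE dimension counts the distinct relative distances i - j that the positional encoding must cover. A small sketch of that counting, with made-up U and L:

U, L = 4, 3
# training/validation: query i and key j both range over the U utterance frames
assert len({i - j for i in range(U) for j in range(U)}) == 2 * U - 1
# inference: keys additionally cover L cached left-context frames
assert len({i - j for i in range(U) for j in range(-L, U)}) == L + 2 * U - 1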
@@ -666,7 +666,7 @@ class EmformerAttention(nn.Module):
             L = left_context_key.size(0)
             assert PE == L + 2 * U - 1
         else:
-            # training mode
+            # training and validation mode
             assert PE == 2 * U - 1
         pos_emb = (
             self.linear_pos(pos_emb)
@@ -679,7 +679,7 @@ class EmformerAttention(nn.Module):
         )  # (B, nhead, U, PE)
         # rel-shift operation
         matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
-        # (B, nhead, U, U) for training mode;
+        # (B, nhead, U, U) for training and validation mode;
         # (B, nhead, U, L + U) for inference mode.
         matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
             B * self.nhead, U, -1
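Note (not part of the diff): _rel_shift is defined elsewhere in this file; one way to realize the same mapping is plain per-row slicing, assuming column p of the score matrix corresponds to relative distance (L + U - 1) - p (largest distance first):

import torch

B, H, U, L = 1, 1, 3, 2
PE = L + 2 * U - 1
scores = torch.randn(B, H, U, PE)
out = torch.stack(
    [scores[..., i, U - 1 - i : U - 1 - i + L + U] for i in range(U)], dim=2
)  # (B, H, U, L + U); with L = 0 this is the (B, H, U, U) training case
assert out.shape == (B, H, U, L + U)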
@@ -730,7 +730,7 @@ class EmformerAttention(nn.Module):
         pos_emb: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: Modify docs.
-        """Forward pass for training mode.
+        """Forward pass for training and validation mode.

         B: batch size;
         D: embedding dimension;
@@ -922,3 +922,464 @@ class EmformerAttention(nn.Module):
             key[M + R :],
             value[M + R :],
         )
+
+
+class EmformerEncoderLayer(nn.Module):
+    """Emformer layer that constitutes the Emformer encoder.
+
+    Args:
+      d_model (int):
+        Input dimension.
+      nhead (int):
+        Number of attention heads.
+      dim_feedforward (int):
+        Hidden layer dimension of feedforward network.
+      chunk_length (int):
+        Length of each input segment.
+      dropout (float, optional):
+        Dropout probability. (Default: 0.1)
+      layer_dropout (float, optional):
+        Layer dropout probability. (Default: 0.075)
+      cnn_module_kernel (int):
+        Kernel size of convolution module.
+      left_context_length (int, optional):
+        Length of left context. (Default: 0)
+      right_context_length (int, optional):
+        Length of right context. (Default: 0)
+      max_memory_size (int, optional):
+        Maximum number of memory elements to use. (Default: 0)
+      tanh_on_mem (bool, optional):
+        If ``True``, applies tanh to memory elements. (Default: ``False``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights.
+        (Default: -1e8)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int,
+        chunk_length: int,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        max_memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        self.attention = EmformerAttention(
+            embed_dim=d_model,
+            nhead=nhead,
+            dropout=dropout,
+            tanh_on_mem=tanh_on_mem,
+            negative_inf=negative_inf,
+        )
+        self.summary_op = nn.AvgPool1d(
+            kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.conv_module = ConvolutionModule(
+            chunk_length,
+            right_context_length,
+            d_model,
+            cnn_module_kernel,
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean
+        # (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.layer_dropout = layer_dropout
+        self.left_context_length = left_context_length
+        self.chunk_length = chunk_length
+        self.max_memory_size = max_memory_size
+        self.d_model = d_model
+        self.use_memory = max_memory_size > 0
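Note (not part of the diff): the summary_op defined above averages each chunk_length-frame chunk of the utterance into one vector; those per-chunk means become the candidate memory vectors. A quick shape check with made-up sizes:

import torch
import torch.nn as nn

chunk_length, U, B, D = 4, 8, 2, 16
summary_op = nn.AvgPool1d(kernel_size=chunk_length, stride=chunk_length, ceil_mode=True)
utterance = torch.randn(U, B, D)
summary = summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
assert summary.shape == (U // chunk_length, B, D)  # one vector per chunk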
+
+    def _init_state(
+        self, batch_size: int, device: Optional[torch.device]
+    ) -> List[torch.Tensor]:
+        """Initialize states with zeros."""
+        empty_memory = torch.zeros(
+            self.max_memory_size, batch_size, self.d_model, device=device
+        )
+        left_context_key = torch.zeros(
+            self.left_context_length, batch_size, self.d_model, device=device
+        )
+        left_context_val = torch.zeros(
+            self.left_context_length, batch_size, self.d_model, device=device
+        )
+        past_length = torch.zeros(
+            1, batch_size, dtype=torch.int32, device=device
+        )
+        return [empty_memory, left_context_key, left_context_val, past_length]
+
+    def _unpack_state(
+        self, state: List[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Unpack cached states, including:
+        1) output memory from previous chunks in the lower layer;
+        2) attention key and value of left context from the preceding
+           chunk's computation.
+        """
+        past_length = state[3][0][0].item()
+        past_left_context_length = min(self.left_context_length, past_length)
+        past_memory_length = min(
+            self.max_memory_size, math.ceil(past_length / self.chunk_length)
+        )
+        memory_start_idx = self.max_memory_size - past_memory_length
+        pre_memory = state[0][memory_start_idx:]
+        left_context_start_idx = (
+            self.left_context_length - past_left_context_length
+        )
+        left_context_key = state[1][left_context_start_idx:]
+        left_context_val = state[2][left_context_start_idx:]
+        return pre_memory, left_context_key, left_context_val
+
+    def _pack_state(
+        self,
+        next_key: torch.Tensor,
+        next_val: torch.Tensor,
+        update_length: int,
+        memory: torch.Tensor,
+        state: List[torch.Tensor],
+    ) -> List[torch.Tensor]:
+        """Pack updated states, including:
+        1) output memory of current chunk in the lower layer;
+        2) attention key and value in current chunk's computation, which
+           will be reused in next chunk's computation;
+        3) length of current chunk.
+        """
+        new_memory = torch.cat([state[0], memory])
+        new_key = torch.cat([state[1], next_key])
+        new_val = torch.cat([state[2], next_val])
+        memory_start_idx = new_memory.size(0) - self.max_memory_size
+        state[0] = new_memory[memory_start_idx:]
+        key_start_idx = new_key.size(0) - self.left_context_length
+        state[1] = new_key[key_start_idx:]
+        val_start_idx = new_val.size(0) - self.left_context_length
+        state[2] = new_val[val_start_idx:]
+        state[3] = state[3] + update_length
+        return state
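Note (not part of the diff): the state tensors behave as fixed-size, right-aligned buffers. _pack_state appends the new memory/key/value rows and keeps only the trailing max_memory_size (resp. left_context_length) entries; _unpack_state uses the accumulated past_length to slice off whatever is still zero padding. A worked example with made-up sizes:

import math

max_memory_size, left_context_length, chunk_length = 4, 8, 8
past_length = 8  # one chunk processed so far
past_left_context_length = min(left_context_length, past_length)  # 8
past_memory_length = min(
    max_memory_size, math.ceil(past_length / chunk_length)
)  # 1
# _unpack_state keeps the last 1 of the 4 memory slots and all 8 key/value rows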
+
+    def _apply_conv_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+    ) -> torch.Tensor:
+        """Apply convolution module in training and validation mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+        utterance, right_context, _ = self.conv_module(utterance, right_context)
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance
+
+    def _apply_conv_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        conv_cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply convolution module on utterance in inference mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+        utterance, right_context, conv_cache = self.conv_module.infer(
+            utterance, right_context, conv_cache
+        )
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance, conv_cache
+
+    def _apply_attention_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        lengths: torch.Tensor,
+        memory: torch.Tensor,
+        pos_emb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply attention module in training and validation mode."""
+        if attention_mask is None:
+            raise ValueError(
+                "attention_mask must not be None in training or validation mode."  # noqa
+            )
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        if self.use_memory:
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
+        else:
+            summary = torch.empty(0).to(
+                dtype=utterance.dtype, device=utterance.device
+            )
+        output_right_context_utterance, output_memory = self.attention(
+            utterance=utterance,
+            lengths=lengths,
+            right_context=right_context,
+            summary=summary,
+            memory=memory,
+            attention_mask=attention_mask,
+            pos_emb=pos_emb,
+        )
+
+        return output_right_context_utterance, output_memory
+
+    def _apply_attention_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        lengths: torch.Tensor,
+        memory: torch.Tensor,
+        pos_emb: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Apply attention module in inference mode.
+        1) Unpack cached states, including:
+           - memory from previous chunks in the lower layer;
+           - attention key and value of left context from the preceding
+             chunk's computation;
+        2) Apply attention computation;
+        3) Pack updated states, including:
+           - output memory of current chunk in the lower layer;
+           - attention key and value in current chunk's computation, which
+             will be reused in next chunk's computation;
+           - length of current chunk.
+        """
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        if state is None:
+            state = self._init_state(utterance.size(1), device=utterance.device)
+        pre_memory, left_context_key, left_context_val = self._unpack_state(
+            state
+        )
+        if self.use_memory:
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
+            summary = summary[:1]
+        else:
+            summary = torch.empty(0).to(
+                dtype=utterance.dtype, device=utterance.device
+            )
+        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1.
+        # For query [utterance] (i) and key-value [left_context, utterance] (j),
+        # the max relative distance i - j is L + U - 1,
+        # and the min relative distance i - j is -(U - 1).
+        L = left_context_key.size(0)  # L <= left_context_length
+        U = utterance.size(0)
+        PE = L + 2 * U - 1
+        tot_PE = self.left_context_length + 2 * U - 1
+        assert pos_emb.size(0) == tot_PE
+        pos_emb = pos_emb[tot_PE - PE :]
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = self.attention.infer(
+            utterance=utterance,
+            lengths=lengths,
+            right_context=right_context,
+            summary=summary,
+            memory=pre_memory,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+            pos_emb=pos_emb,
+        )
+        state = self._pack_state(
+            next_key, next_val, utterance.size(0), memory, state
+        )
+        return output_right_context_utterance, output_memory, state
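Note (not part of the diff): pos_emb is always sized for the full left_context_length, but early chunks have fewer cached frames, so the rows for the largest (absent) distances are sliced away. With made-up sizes:

left_context_length, U = 32, 8
L = 8  # e.g. the second chunk: only 8 left-context frames cached so far
tot_PE = left_context_length + 2 * U - 1  # 47
PE = L + 2 * U - 1  # 23
# pos_emb[tot_PE - PE :] keeps the trailing 23 rows, i.e. distances 15 .. -7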
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        lengths: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        r"""Forward pass for training and validation mode.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of hard-copied right contexts;
+        U: length of full utterance;
+        S: length of summary vectors;
+        M: length of memory vectors.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance frames, with shape (U, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing
+            number of valid frames for i-th batch element in utterance.
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D).
+          memory (torch.Tensor):
+            Memory elements, with shape (M, B, D).
+            It is an empty tensor when memory is not used.
+          attention_mask (torch.Tensor):
+            Attention mask for underlying attention module,
+            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For training and validation mode, PE = 2 * U - 1.
+
+        Returns:
+          A tuple containing 3 tensors:
+            - output utterance, with shape (U, B, D);
+            - output right context, with shape (R, B, D);
+            - output memory, with shape (M, B, D).
+        """
+        R = right_context.size(0)
+        src = torch.cat([right_context, utterance])
+        src_orig = src
+
+        warmup_scale = min(0.1 + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else 0.1
+            )
+        else:
+            alpha = 1.0
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # emformer attention module
+        src_att, output_memory = self._apply_attention_module_forward(
+            src, R, lengths, memory, pos_emb, attention_mask
+        )
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src_conv = self._apply_conv_module_forward(src, R)
+        src = src + self.dropout(src_conv)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
+        output_utterance = src[R:]
+        output_right_context = src[:R]
+        return output_utterance, output_right_context, output_memory
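Note (not part of the diff): the warmup/layer-dropout logic blends the layer output with its input during early training. A numeric illustration with assumed values:

warmup, layer_dropout = 0.5, 0.075
warmup_scale = min(0.1 + warmup, 1.0)  # 0.6
# with probability 1 - 0.075 this batch uses alpha = 0.6, otherwise alpha = 0.1;
# the output becomes alpha * src + (1 - alpha) * src_orig, i.e. a partial bypass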
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        lengths: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        pos_emb: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+        conv_cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor
+    ]:
+        """Forward pass for inference.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of right_context;
+        U: length of utterance;
+        M: length of memory.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance frames, with shape (U, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing
+            number of valid frames for i-th batch element in utterance.
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D).
+          memory (torch.Tensor):
+            Memory elements, with shape (M, B, D).
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For inference mode, PE = L + 2 * U - 1.
+          state (List[torch.Tensor], optional):
+            List of tensors representing layer internal state generated in
+            preceding computation. (Default: None)
+          conv_cache (torch.Tensor, optional):
+            Cache tensor of left context for causal convolution.
+
+        Returns:
+          A tuple containing 5 tensors:
+            - output utterance, with shape (U, B, D);
+            - output right_context, with shape (R, B, D);
+            - output memory, with shape (1, B, D) or (0, B, D);
+            - output state;
+            - updated conv_cache.
+        """
+        R = right_context.size(0)
+        src = torch.cat([right_context, utterance])
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # emformer attention module
+        (
+            src_att,
+            output_memory,
+            output_state,
+        ) = self._apply_attention_module_infer(
+            src, R, lengths, memory, pos_emb, state
+        )
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src_conv, conv_cache = self._apply_conv_module_infer(src, R, conv_cache)
+        src = src + self.dropout(src_conv)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        output_utterance = src[R:]
+        output_right_context = src[:R]
+        return (
+            output_utterance,
+            output_right_context,
+            output_memory,
+            output_state,
+            conv_cache,
+        )
@@ -154,9 +154,131 @@ def test_convolution_module_infer():
     assert new_cache.shape == (B, D, kernel_size - 1)


+def test_emformer_encoder_layer_forward():
+    from emformer import EmformerEncoderLayer
+
+    B, D = 2, 256
+    chunk_length = 8
+    right_context_length = 2
+    left_context_length = 8
+    kernel_size = 31
+    num_chunks = 3
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+
+    for use_memory in [True, False]:
+        if use_memory:
+            S = num_chunks
+            M = S - 1
+        else:
+            S, M = 0, 0
+
+        layer = EmformerEncoderLayer(
+            d_model=D,
+            nhead=8,
+            dim_feedforward=1024,
+            chunk_length=chunk_length,
+            cnn_module_kernel=kernel_size,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
+            max_memory_size=M,
+        )
+
+        Q, KV = R + U + S, M + R + U
+        utterance = torch.randn(U, B, D)
+        lengths = torch.randint(1, U + 1, (B,))
+        lengths[0] = U
+        right_context = torch.randn(R, B, D)
+        memory = torch.randn(M, B, D)
+        attention_mask = torch.rand(Q, KV) >= 0.5
+        PE = 2 * U - 1
+        pos_emb = torch.randn(PE, D)
+
+        output_utterance, output_right_context, output_memory = layer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+            pos_emb,
+        )
+        assert output_utterance.shape == (U, B, D)
+        assert output_right_context.shape == (R, B, D)
+        assert output_memory.shape == (M, B, D)
+
+
+def test_emformer_encoder_layer_infer():
+    from emformer import EmformerEncoderLayer
+
+    B, D = 2, 256
+    chunk_length = 8
+    right_context_length = 2
+    left_context_length = 8
+    kernel_size = 31
+    num_chunks = 1
+    U = num_chunks * chunk_length
+    R = num_chunks * right_context_length
+
+    for use_memory in [True, False]:
+        if use_memory:
+            M = 3
+        else:
+            M = 0
+
+        layer = EmformerEncoderLayer(
+            d_model=D,
+            nhead=8,
+            dim_feedforward=1024,
+            chunk_length=chunk_length,
+            cnn_module_kernel=kernel_size,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
+            max_memory_size=M,
+        )
+
+        utterance = torch.randn(U, B, D)
+        lengths = torch.randint(1, U + 1, (B,))
+        lengths[0] = U
+        right_context = torch.randn(R, B, D)
+        memory = torch.randn(M, B, D)
+        state = None
+        PE = left_context_length + 2 * U - 1
+        pos_emb = torch.randn(PE, D)
+        conv_cache = None
+        (
+            output_utterance,
+            output_right_context,
+            output_memory,
+            output_state,
+            conv_cache,
+        ) = layer.infer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            pos_emb,
+            state,
+            conv_cache,
+        )
+        assert output_utterance.shape == (U, B, D)
+        assert output_right_context.shape == (R, B, D)
+        if use_memory:
+            assert output_memory.shape == (1, B, D)
+        else:
+            assert output_memory.shape == (0, B, D)
+        assert len(output_state) == 4
+        assert output_state[0].shape == (M, B, D)
+        assert output_state[1].shape == (left_context_length, B, D)
+        assert output_state[2].shape == (left_context_length, B, D)
+        assert output_state[3].shape == (1, B)
+        assert conv_cache.shape == (B, D, kernel_size - 1)
+
+
 if __name__ == "__main__":
     test_rel_positional_encoding()
     test_emformer_attention_forward()
     test_emformer_attention_infer()
     test_convolution_module_forward()
     test_convolution_module_infer()
+    test_emformer_encoder_layer_forward()
+    test_emformer_encoder_layer_infer()