add emformer attention module
parent a0c7095e42
commit 5dc5f8305a
@@ -151,3 +151,485 @@ class RelPositionalEncoding(torch.nn.Module):
            self.gen_pe_negative()
        pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype)
        return self.dropout(x), self.dropout(pos_emb)


class EmformerAttention(nn.Module):
    r"""Emformer layer attention module.

    Args:
      embed_dim (int):
        Embedding dimension.
      nhead (int):
        Number of attention heads in each Emformer layer.
      dropout (float, optional):
        Dropout probability. (Default: 0.0)
      tanh_on_mem (bool, optional):
        If ``True``, applies tanh to memory elements. (Default: ``False``)
      negative_inf (float, optional):
        Value to use for negative infinity in attention weights. (Default: -1e8)
    """

    def __init__(
        self,
        embed_dim: int,
        nhead: int,
        dropout: float = 0.0,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        if embed_dim % nhead != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) is not a multiple of "
                f"nhead ({nhead})."
            )

        self.embed_dim = embed_dim
        self.nhead = nhead
        self.tanh_on_mem = tanh_on_mem
        self.negative_inf = negative_inf
        self.head_dim = embed_dim // nhead
        self.dropout = dropout

        self.emb_to_key_value = ScaledLinear(
            embed_dim, 2 * embed_dim, bias=True
        )
        self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True)
        self.out_proj = ScaledLinear(
            embed_dim, embed_dim, bias=True, initial_scale=0.25
        )

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
        self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim))
        self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim))
        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())

        self._reset_parameters()

    def _pos_bias_u(self):
        return self.pos_bias_u * self.pos_bias_u_scale.exp()

    def _pos_bias_v(self):
        return self.pos_bias_v * self.pos_bias_v_scale.exp()
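
    # pos_bias_u/v carry separate scalar learnable log-scales
    # (pos_bias_u_scale, pos_bias_v_scale); multiplying by scale.exp() keeps
    # the effective magnitude positive and trainable, starting from
    # exp(0) = 1.0, analogous to the scale factor used in ScaledLinear.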

    def _reset_parameters(self) -> None:
        nn.init.normal_(self.pos_bias_u, std=0.01)
        nn.init.normal_(self.pos_bias_v, std=0.01)

    def _gen_attention_probs(
        self,
        attention_weights: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Given the entire attention weights, mask out unnecessary connections
        and, optionally, the padding positions, to obtain the underlying
        chunk-wise attention probabilities.

        B: batch size;
        Q: length of query;
        KV: length of key and value.

        Args:
          attention_weights (torch.Tensor):
            Attention weights computed on the entire concatenated tensor,
            with shape (B * nhead, Q, KV).
          attention_mask (torch.Tensor):
            Mask tensor where chunk-wise connections are filled with `False`
            and other, unnecessary connections are filled with `True`,
            with shape (Q, KV).
          padding_mask (torch.Tensor, optional):
            Mask tensor where padding positions are filled with `True`
            and other positions are filled with `False`, with shape (B, KV).

        Returns:
          A tensor of shape (B * nhead, Q, KV).
        """
        attention_weights_float = attention_weights.float()
        attention_weights_float = attention_weights_float.masked_fill(
            attention_mask.unsqueeze(0), self.negative_inf
        )
        if padding_mask is not None:
            Q = attention_weights.size(1)
            B = attention_weights.size(0) // self.nhead
            attention_weights_float = attention_weights_float.view(
                B, self.nhead, Q, -1
            )
            attention_weights_float = attention_weights_float.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                self.negative_inf,
            )
            attention_weights_float = attention_weights_float.view(
                B * self.nhead, Q, -1
            )

        attention_probs = nn.functional.softmax(
            attention_weights_float, dim=-1
        ).type_as(attention_weights)

        attention_probs = nn.functional.dropout(
            attention_probs, p=self.dropout, training=self.training
        )
        return attention_probs

    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
          x: Input tensor of shape (B, nhead, U, PE),
            where U is the length of the query vector.
            In training mode, PE = 2 * U - 1;
            in inference mode, PE = L + 2 * U - 1.

        Returns:
          A tensor of shape (B, nhead, U, out_len).
          In training mode, out_len = U;
          in inference mode, out_len = L + U.
        """
        B, nhead, U, PE = x.size()
        B_stride = x.stride(0)
        nhead_stride = x.stride(1)
        U_stride = x.stride(2)
        PE_stride = x.stride(3)
        out_len = PE - (U - 1)
        return x.as_strided(
            size=(B, nhead, U, out_len),
            stride=(B_stride, nhead_stride, U_stride - PE_stride, PE_stride),
            storage_offset=PE_stride * (U - 1),
        )
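
    # A small worked example of the rel-shift above: as_strided re-indexes the
    # input so that out[b, h, i, j] = x[b, h, i, U - 1 - i + j], i.e. each
    # query row i keeps a window of out_len consecutive positional entries,
    # shifted by one step per row.  For U = 2 and PE = 3 (training mode,
    # out_len = 2):
    #   out[..., 0, :] = x[..., 0, 1:3]
    #   out[..., 1, :] = x[..., 1, 0:2]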

    def _forward_impl(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        memory: torch.Tensor,
        attention_mask: torch.Tensor,
        pos_emb: torch.Tensor,
        left_context_key: Optional[torch.Tensor] = None,
        left_context_val: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Underlying chunk-wise attention implementation."""
        U, B, _ = utterance.size()
        R = right_context.size(0)
        M = memory.size(0)
        scaling = float(self.head_dim) ** -0.5

        # compute query with [right_context, utterance, summary].
        query = self.emb_to_query(
            torch.cat([right_context, utterance, summary])
        )
        # compute key and value with [memory, right_context, utterance].
        key, value = self.emb_to_key_value(
            torch.cat([memory, right_context, utterance])
        ).chunk(chunks=2, dim=2)

        if left_context_key is not None and left_context_val is not None:
            # now compute key and value with
            # [memory, right context, left context, utterance];
            # this is used in inference mode.
            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
            value = torch.cat(
                [value[: M + R], left_context_val, value[M + R :]]
            )
        Q = query.size(0)
        KV = key.size(0)

        reshaped_key, reshaped_value = [
            tensor.contiguous()
            .view(KV, B * self.nhead, self.head_dim)
            .transpose(0, 1)
            for tensor in [key, value]
        ]  # both of shape (B * nhead, KV, head_dim)
        reshaped_query = (
            query.contiguous().view(Q, B, self.nhead, self.head_dim) * scaling
        )

        # compute attention scores.
        # first, compute attention matrix a and matrix c,
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
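        # in that notation, the score for query position i and key position j
        # decomposes as
        #   A[i, j] = (q_i + u)^T k_j               (terms (a) + (c): content)
        #           + (q_i + v)^T (W_pos r_{i-j})   (terms (b) + (d): position)
        # matrix_ac below is the content part; matrix_bd, computed further
        # down via the rel-shift, is the position part.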
        query_with_bias_u = (
            (reshaped_query + self._pos_bias_u())
            .view(Q, B * self.nhead, self.head_dim)
            .transpose(0, 1)
        )  # (B * nhead, Q, head_dim)
        matrix_ac = torch.bmm(
            query_with_bias_u, reshaped_key.transpose(1, 2)
        )  # (B * nhead, Q, KV)

        # second, compute attention matrix b and matrix d.
        # relative positional encoding is applied on the part of attention
        # between each chunk (in query) and itself as well as its left context
        # (in key).
        utterance_with_bias_v = (
            reshaped_query[R : R + U] + self._pos_bias_v()
        ).permute(1, 2, 0, 3)
        # (B, nhead, U, head_dim)
        PE = pos_emb.size(0)
        if left_context_key is not None and left_context_val is not None:
            # inference mode
            L = left_context_key.size(0)
            assert PE == L + 2 * U - 1
        else:
            # training mode
            assert PE == 2 * U - 1
        pos_emb = (
            self.linear_pos(pos_emb)
            .view(PE, self.nhead, self.head_dim)
            .transpose(0, 1)
            .unsqueeze(0)
        )  # (1, nhead, PE, head_dim)
        matrix_bd_utterance = torch.matmul(
            utterance_with_bias_v, pos_emb.transpose(-2, -1)
        )  # (B, nhead, U, PE)
        # rel-shift operation
        matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
        # (B, nhead, U, U) for training mode;
        # (B, nhead, U, L + U) for inference mode.
        matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
            B * self.nhead, U, -1
        )
        matrix_bd = torch.zeros_like(matrix_ac)
        matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance

        attention_weights = matrix_ac + matrix_bd

        # compute padding mask
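        # only the utterance part of the key can contain padding frames; the
        # memory, right-context (and, in inference, cached left-context) parts
        # are always valid, so the number of valid key positions for batch
        # element i is (KV - U) + lengths[i].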
        if B == 1:
            padding_mask = None
        else:
            padding_mask = make_pad_mask(KV - U + lengths)

        # compute attention probabilities
        attention_probs = self._gen_attention_probs(
            attention_weights, attention_mask, padding_mask
        )

        # compute attention outputs
        attention = torch.bmm(attention_probs, reshaped_value)
        assert attention.shape == (B * self.nhead, Q, self.head_dim)
        attention = (
            attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
        )

        # apply output projection
        outputs = self.out_proj(attention)

        output_right_context_utterance = outputs[: R + U]
        output_memory = outputs[R + U :]
        if self.tanh_on_mem:
            output_memory = torch.tanh(output_memory)
        else:
            output_memory = torch.clamp(output_memory, min=-10, max=10)

        return output_right_context_utterance, output_memory, key, value

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        memory: torch.Tensor,
        attention_mask: torch.Tensor,
        pos_emb: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # TODO: Modify docs.
        """Forward pass for training mode.

        B: batch size;
        D: embedding dimension;
        R: length of the hard-copied right contexts;
        U: length of full utterance;
        S: length of summary vectors;
        M: length of memory vectors.

        It computes a `big` attention matrix on the full utterance and
        then utilizes a pre-computed mask to simulate chunk-wise attention.

        It concatenates three blocks: hard-copied right contexts,
        full utterance, and summary vectors, as a `big` block,
        to compute the query tensor:
          query = [right_context, utterance, summary],
        with length Q = R + U + S.
        It concatenates another three blocks: memory vectors,
        hard-copied right contexts, and full utterance, as a `big` block,
        to compute the key and value tensors:
          key & value = [memory, right_context, utterance],
        with length KV = M + R + U.
        Attention scores are computed with the above `big` query and key.

        Then the underlying chunk-wise attention is obtained by applying
        the attention mask. Suppose
          c_i: chunk at index i;
          r_i: right context that c_i can use;
          l_i: left context that c_i can use;
          m_i: past memory vectors from previous layer that c_i can use;
          s_i: summary vector of c_i.
        The target chunk-wise attention is:
          c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
          s_i (in query) -> l_i, c_i, r_i (in key).

        Relative positional encoding is applied on the part of attention
        between utterance (in query) and utterance (in key). More precisely,
        it is applied on the part of attention between each chunk (in query)
        and itself as well as its left context (in key), after applying the
        mask:
          c_i -> l_i, c_i.

        Args:
          utterance (torch.Tensor):
            Full utterance frames, with shape (U, B, D).
          lengths (torch.Tensor):
            With shape (B,), where the i-th element represents the
            number of valid frames for the i-th batch element in utterance.
          right_context (torch.Tensor):
            Hard-copied right context frames, with shape (R, B, D),
            where R = num_chunks * right_context_length.
          summary (torch.Tensor):
            Summary elements, with shape (S, B, D), where S = num_chunks.
            It is an empty tensor if memory is not used.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D), where M = num_chunks - 1.
            It is an empty tensor if memory is not used.
          attention_mask (torch.Tensor):
            Pre-computed attention mask to simulate the underlying chunk-wise
            attention, with shape (Q, KV).
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D),
            where PE = 2 * U - 1.

        Returns:
          A tuple containing 2 tensors:
            - output of right context and utterance, with shape (R + U, B, D).
            - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
        """
        (
            output_right_context_utterance,
            output_memory,
            _,
            _,
        ) = self._forward_impl(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            attention_mask,
            pos_emb,
        )
        return output_right_context_utterance, output_memory[:-1]

    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        memory: torch.Tensor,
        left_context_key: torch.Tensor,
        left_context_val: torch.Tensor,
        pos_emb: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass for inference.

        B: batch size;
        D: embedding dimension;
        R: length of right context;
        U: length of utterance, i.e., current chunk;
        L: length of cached left context;
        S: length of summary vectors, S = 1;
        M: length of cached memory vectors.

        It concatenates the right context, utterance (i.e., current chunk),
        and summary vector of the current chunk, to compute the query tensor:
          query = [right_context, utterance, summary],
        with length Q = R + U + S.
        It concatenates the memory vectors, right context, left context, and
        current chunk, to compute the key and value tensors:
          key & value = [memory, right_context, left_context, utterance],
        with length KV = M + R + L + U.

        The chunk-wise attention is:
          chunk, right context (in query) ->
            left context, chunk, right context, memory vectors (in key);
          summary (in query) -> left context, chunk, right context (in key).

        Relative positional encoding is applied on the part of attention
        between the chunk (in query) and the chunk itself as well as its left
        context (in key):
          chunk (in query) -> left context, chunk (in key).

        Args:
          utterance (torch.Tensor):
            Current chunk frames, with shape (U, B, D), where U = chunk_length.
          lengths (torch.Tensor):
            With shape (B,), where the i-th element represents the
            number of valid frames for the i-th batch element in utterance.
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D),
            where R = right_context_length.
          summary (torch.Tensor):
            Summary vector, with shape (1, B, D), or an empty tensor.
          memory (torch.Tensor):
            Memory vectors, with shape (M, B, D), or an empty tensor.
          left_context_key (torch.Tensor):
            Cached attention key of the left context from the preceding
            computation, with shape (L, B, D), where L <= left_context_length.
          left_context_val (torch.Tensor):
            Cached attention value of the left context from the preceding
            computation, with shape (L, B, D), where L <= left_context_length.
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D),
            where PE = L + 2 * U - 1.

        Returns:
          A tuple containing 4 tensors:
            - output of right context and utterance, with shape (R + U, B, D).
            - memory output, with shape (1, B, D) or (0, B, D).
            - attention key of left context and utterance, which would be
              cached for the next computation, with shape (L + U, B, D).
            - attention value of left context and utterance, which would be
              cached for the next computation, with shape (L + U, B, D).
        """
        U = utterance.size(0)
        R = right_context.size(0)
        L = left_context_key.size(0)
        S = summary.size(0)
        M = memory.size(0)

        # query = [right context, utterance, summary]
        Q = R + U + S
        # key, value = [memory, right context, left context, utterance]
        KV = M + R + L + U
        attention_mask = torch.zeros(Q, KV).to(
            dtype=torch.bool, device=utterance.device
        )
        # disallow attention between the summary vector and the memory bank
        attention_mask[-1, :M] = True
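        # (the last query row is the summary vector when S == 1; when memory
        # is not used, S == M == 0 and the slice [:M] is empty, so this is a
        # no-op.)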
        (
            output_right_context_utterance,
            output_memory,
            key,
            value,
        ) = self._forward_impl(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            attention_mask,
            pos_emb,
            left_context_key=left_context_key,
            left_context_val=left_context_val,
        )
        return (
            output_right_context_utterance,
            output_memory,
            key[M + R :],
            value[M + R :],
        )
@@ -13,5 +13,96 @@ def test_rel_positional_encoding():
    assert pos_emb.shape == (pos_len + neg_len - 1, D)


def test_emformer_attention_forward():
    from emformer import EmformerAttention

    B, D = 2, 256
    chunk_length = 4
    right_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
    R = num_chunks * right_context_length
    attention = EmformerAttention(embed_dim=D, nhead=8)

    for use_memory in [True, False]:
        if use_memory:
            S = num_chunks
            M = S - 1
        else:
            S, M = 0, 0

        Q, KV = R + U + S, M + R + U
        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        summary = torch.randn(S, B, D)
        memory = torch.randn(M, B, D)
        attention_mask = torch.rand(Q, KV) >= 0.5
        PE = 2 * U - 1
        pos_emb = torch.randn(PE, D)

        output_right_context_utterance, output_memory = attention(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            attention_mask,
            pos_emb,
        )
        assert output_right_context_utterance.shape == (R + U, B, D)
        assert output_memory.shape == (M, B, D)


def test_emformer_attention_infer():
    from emformer import EmformerAttention

    B, D = 2, 256
    U = 4
    R = 2
    L = 3
    attention = EmformerAttention(embed_dim=D, nhead=8)

    for use_memory in [True, False]:
        if use_memory:
            S, M = 1, 3
        else:
            S, M = 0, 0

        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        summary = torch.randn(S, B, D)
        memory = torch.randn(M, B, D)
        left_context_key = torch.randn(L, B, D)
        left_context_val = torch.randn(L, B, D)
        PE = L + 2 * U - 1
        pos_emb = torch.randn(PE, D)

        (
            output_right_context_utterance,
            output_memory,
            next_key,
            next_val,
        ) = attention.infer(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            left_context_key,
            left_context_val,
            pos_emb,
        )
        assert output_right_context_utterance.shape == (R + U, B, D)
        assert output_memory.shape == (S, B, D)
        assert next_key.shape == (L + U, B, D)
        assert next_val.shape == (L + U, B, D)


if __name__ == "__main__":
    test_rel_positional_encoding()
    test_emformer_attention_forward()
    test_emformer_attention_infer()