use average value as memory vector for each chunk

This commit is contained in:
yaozengwei 2022-06-13 22:14:24 +08:00
parent 1c067e7364
commit 193b44ed7a

View File

@ -537,7 +537,6 @@ class EmformerAttention(nn.Module):
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
summary: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
@ -550,10 +549,8 @@ class EmformerAttention(nn.Module):
M = memory.size(0)
scaling = float(self.head_dim) ** -0.5
# compute query with [right_context, utterance, summary].
query = self.emb_to_query(
torch.cat([right_context, utterance, summary])
)
# compute query with [right_context, utterance].
query = self.emb_to_query(torch.cat([right_context, utterance]))
# compute key and value with [memory, right_context, utterance].
key, value = self.emb_to_key_value(
torch.cat([memory, right_context, utterance])
@ -593,26 +590,18 @@ class EmformerAttention(nn.Module):
)
# apply output projection
outputs = self.out_proj(attention)
output_right_context_utterance = self.out_proj(attention)
output_right_context_utterance = outputs[: R + U]
output_memory = outputs[R + U :]
if self.tanh_on_mem:
output_memory = torch.tanh(output_memory)
else:
output_memory = torch.clamp(output_memory, min=-10, max=10)
return output_right_context_utterance, output_memory, key, value
return output_right_context_utterance, key, value
def forward(
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
summary: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
# TODO: Modify docs.
"""Forward pass for training and validation mode.
@ -620,17 +609,16 @@ class EmformerAttention(nn.Module):
D: embedding dimension;
R: length of the hard-copied right contexts;
U: length of full utterance;
S: length of summary vectors;
M: length of memory vectors.
It computes a `big` attention matrix on full utterance and
then utilizes a pre-computed mask to simulate chunk-wise attention.
It concatenates three blocks: hard-copied right contexts,
full utterance, and summary vectors, as a `big` block,
and full utterance, as a `big` block,
to compute the query tensor:
query = [right_context, utterance, summary],
with length Q = R + U + S.
query = [right_context, utterance],
with length Q = R + U.
It concatenates the three blocks: memory vectors,
hard-copied right contexts, and full utterance as another `big` block,
to compute the key and value tensors:
@ -644,10 +632,8 @@ class EmformerAttention(nn.Module):
r_i: right context that c_i can use;
l_i: left context that c_i can use;
m_i: past memory vectors from previous layer that c_i can use;
s_i: summary vector of c_i;
The target chunk-wise attention is:
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
s_i (in query) -> l_i, c_i, r_i (in key).
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key)
Args:
utterance (torch.Tensor):
@ -655,9 +641,6 @@ class EmformerAttention(nn.Module):
right_context (torch.Tensor):
Hard-copied right context frames, with shape (R, B, D),
where R = num_chunks * right_context_length
summary (torch.Tensor):
Summary elements with shape (S, B, D), where S = num_chunks.
It is an empty tensor without using memory.
memory (torch.Tensor):
Memory elements, with shape (M, B, D), where M = num_chunks - 1.
It is an empty tensor without using memory.
@ -668,31 +651,22 @@ class EmformerAttention(nn.Module):
Padding mask of key tensor, with shape (B, KV).
Returns:
A tuple containing 2 tensors:
- output of right context and utterance, with shape (R + U, B, D).
- memory output, with shape (M, B, D), where M = S - 1 or M = 0.
Output of right context and utterance, with shape (R + U, B, D).
"""
(
output_right_context_utterance,
output_memory,
_,
_,
) = self._forward_impl(
output_right_context_utterance, _, _ = self._forward_impl(
utterance,
right_context,
summary,
memory,
attention_mask,
padding_mask=padding_mask,
)
return output_right_context_utterance, output_memory[:-1]
return output_right_context_utterance
@torch.jit.export
def infer(
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
summary: torch.Tensor,
memory: torch.Tensor,
left_context_key: torch.Tensor,
left_context_val: torch.Tensor,
@ -705,13 +679,12 @@ class EmformerAttention(nn.Module):
R: length of right context;
U: length of utterance, i.e., current chunk;
L: length of cached left context;
S: length of summary vectors, S = 1;
M: length of cached memory vectors.
It concatenates the right context, utterance (i.e., current chunk)
and summary vector of current chunk, to compute the query tensor:
query = [right_context, utterance, summary],
with length Q = R + U + S.
It concatenates the right context and utterance (i.e., current chunk)
of current chunk, to compute the query tensor:
query = [right_context, utterance],
with length Q = R + U.
It concatenates the memory vectors, right context, left context, and
current chunk, to compute the key and value tensors:
key & value = [memory, right_context, left_context, utterance],
@ -719,8 +692,7 @@ class EmformerAttention(nn.Module):
The chunk-wise attention is:
chunk, right context (in query) ->
left context, chunk, right context, memory vectors (in key);
summary (in query) -> left context, chunk, right context (in key).
left context, chunk, right context, memory vectors (in key).
Args:
utterance (torch.Tensor):
@ -728,8 +700,6 @@ class EmformerAttention(nn.Module):
right_context (torch.Tensor):
Right context frames, with shape (R, B, D),
where R = right_context_length.
summary (torch.Tensor):
Summary vector with shape (1, B, D), or empty tensor.
memory (torch.Tensor):
Memory vectors, with shape (M, B, D), or empty tensor.
left_context_key (torch,Tensor):
@ -744,7 +714,6 @@ class EmformerAttention(nn.Module):
Returns:
A tuple containing 4 tensors:
- output of right context and utterance, with shape (R + U, B, D).
- memory output, with shape (1, B, D) or (0, B, D).
- attention key of left context and utterance, which would be cached
for next computation, with shape (L + U, B, D).
- attention value of left context and utterance, which would be
@ -753,28 +722,19 @@ class EmformerAttention(nn.Module):
U = utterance.size(0)
R = right_context.size(0)
L = left_context_key.size(0)
S = summary.size(0)
M = memory.size(0)
# TODO: move it outside
# query = [right context, utterance, summary]
Q = R + U + S
# query = [right context, utterance]
Q = R + U
# key, value = [memory, right context, left context, uttrance]
KV = M + R + L + U
attention_mask = torch.zeros(Q, KV).to(
dtype=torch.bool, device=utterance.device
)
# disallow attention bettween the summary vector with the memory bank
attention_mask[-1, :M] = True
(
output_right_context_utterance,
output_memory,
key,
value,
) = self._forward_impl(
output_right_context_utterance, key, value = self._forward_impl(
utterance,
right_context,
summary,
memory,
attention_mask,
padding_mask=padding_mask,
@ -783,7 +743,6 @@ class EmformerAttention(nn.Module):
)
return (
output_right_context_utterance,
output_memory,
key[M + R :],
value[M + R :],
)
@ -938,49 +897,46 @@ class EmformerEncoderLayer(nn.Module):
self,
right_context_utterance: torch.Tensor,
R: int,
memory: torch.Tensor,
attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""Apply attention module in training and validation mode."""
utterance = right_context_utterance[R:]
right_context = right_context_utterance[:R]
if self.use_memory:
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
2, 0, 1
)
)[:-1, :, :]
else:
summary = torch.empty(0).to(
memory = torch.empty(0).to(
dtype=utterance.dtype, device=utterance.device
)
output_right_context_utterance, output_memory = self.attention(
output_right_context_utterance = self.attention(
utterance=utterance,
right_context=right_context,
summary=summary,
memory=memory,
attention_mask=attention_mask,
padding_mask=padding_mask,
)
return output_right_context_utterance, output_memory
return output_right_context_utterance
def _apply_attention_module_infer(
self,
right_context_utterance: torch.Tensor,
R: int,
memory: torch.Tensor,
attn_cache: List[torch.Tensor],
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Apply attention module in inference mode.
1) Unpack cached states including:
- memory from previous chunks in the lower layer;
- memory from previous chunks;
- attention key and value of left context from preceding
chunk's compuation;
2) Apply attention computation;
3) Update cached attention states including:
- output memory of current chunk in the lower layer;
- memory of current chunk;
- attention key and value in current chunk's computation, which would
be resued in next chunk's computation.
"""
@ -992,23 +948,20 @@ class EmformerEncoderLayer(nn.Module):
left_context_val = attn_cache[2]
if self.use_memory:
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
2, 0, 1
)
summary = summary[:1]
)[:1, :, :]
else:
summary = torch.empty(0).to(
memory = torch.empty(0).to(
dtype=utterance.dtype, device=utterance.device
)
(
output_right_context_utterance,
output_memory,
next_key,
next_val,
) = self.attention.infer(
utterance=utterance,
right_context=right_context,
summary=summary,
memory=pre_memory,
left_context_key=left_context_key,
left_context_val=left_context_val,
@ -1017,17 +970,16 @@ class EmformerEncoderLayer(nn.Module):
attn_cache = self._update_attn_cache(
next_key, next_val, memory, attn_cache
)
return output_right_context_utterance, output_memory, attn_cache
return output_right_context_utterance, attn_cache
def forward(
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass for training and validation mode.
B: batch size;
@ -1041,20 +993,16 @@ class EmformerEncoderLayer(nn.Module):
Utterance frames, with shape (U, B, D).
right_context (torch.Tensor):
Right context frames, with shape (R, B, D).
memory (torch.Tensor):
Memory elements, with shape (M, B, D).
It is an empty tensor without using memory.
attention_mask (torch.Tensor):
Attention mask for underlying attention module,
with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
with shape (Q, KV), where Q = R + U, KV = M + R + U.
padding_mask (torch.Tensor):
Padding mask of ker tensor, with shape (B, KV).
Returns:
A tuple containing 3 tensors:
A tuple containing 2 tensors:
- output utterance, with shape (U, B, D).
- output right context, with shape (R, B, D).
- output memory, with shape (M, B, D).
"""
R = right_context.size(0)
src = torch.cat([right_context, utterance])
@ -1076,8 +1024,8 @@ class EmformerEncoderLayer(nn.Module):
src = src + self.dropout(self.feed_forward_macaron(src))
# emformer attention module
src_att, output_memory = self._apply_attention_module_forward(
src, R, memory, attention_mask, padding_mask=padding_mask
src_att = self._apply_attention_module_forward(
src, R, attention_mask, padding_mask=padding_mask
)
src = src + self.dropout(src_att)
@ -1095,24 +1043,17 @@ class EmformerEncoderLayer(nn.Module):
output_utterance = src[R:]
output_right_context = src[:R]
return output_utterance, output_right_context, output_memory
return output_utterance, output_right_context
@torch.jit.export
def infer(
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
memory: torch.Tensor,
attn_cache: List[torch.Tensor],
conv_cache: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
List[torch.Tensor],
torch.Tensor,
]:
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""Forward pass for inference.
B: batch size;
@ -1126,8 +1067,6 @@ class EmformerEncoderLayer(nn.Module):
Utterance frames, with shape (U, B, D).
right_context (torch.Tensor):
Right context frames, with shape (R, B, D).
memory (torch.Tensor):
Memory elements, with shape (M, B, D).
attn_cache (List[torch.Tensor]):
Cached attention tensors generated in preceding computation,
including memory, key and value of left context.
@ -1140,9 +1079,8 @@ class EmformerEncoderLayer(nn.Module):
(Tensor, Tensor, List[torch.Tensor], Tensor):
- output utterance, with shape (U, B, D);
- output right_context, with shape (R, B, D);
- output memory, with shape (1, B, D) or (0, B, D).
- output state.
- updated conv_cache.
- output attention cache;
- output convolution cache.
"""
R = right_context.size(0)
src = torch.cat([right_context, utterance])
@ -1151,12 +1089,8 @@ class EmformerEncoderLayer(nn.Module):
src = src + self.dropout(self.feed_forward_macaron(src))
# emformer attention module
(
src_att,
output_memory,
attn_cache,
) = self._apply_attention_module_infer(
src, R, memory, attn_cache, padding_mask=padding_mask
src_att, attn_cache = self._apply_attention_module_infer(
src, R, attn_cache, padding_mask=padding_mask
)
src = src + self.dropout(src_att)
@ -1174,7 +1108,6 @@ class EmformerEncoderLayer(nn.Module):
return (
output_utterance,
output_right_context,
output_memory,
attn_cache,
conv_cache,
)
@ -1253,11 +1186,6 @@ class EmformerEncoder(nn.Module):
super().__init__()
self.use_memory = memory_size > 0
self.init_memory_op = nn.AvgPool1d(
kernel_size=chunk_length,
stride=chunk_length,
ceil_mode=True,
)
self.emformer_layers = nn.ModuleList(
[
@ -1358,16 +1286,15 @@ class EmformerEncoder(nn.Module):
R: length of hard-copied right contexts;
U: length of full utterance;
S: length of summary vectors;
M: length of memory vectors;
Q: length of attention query;
KV: length of attention key and value.
The shape of attention mask is (Q, KV).
If self.use_memory is `True`:
query = [right_context, utterance, summary];
query = [right_context, utterance];
key, value = [memory, right_context, utterance];
Q = R + U + S, KV = M + R + U.
Q = R + U, KV = M + R + U.
Otherwise:
query = [right_context, utterance]
key, value = [right_context, utterance]
@ -1378,17 +1305,14 @@ class EmformerEncoder(nn.Module):
r_i: right context that c_i can use;
l_i: left context that c_i can use;
m_i: past memory vectors from previous layer that c_i can use;
s_i: summary vector of c_i.
The target chunk-wise attention is:
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
s_i (in query) -> l_i, c_i, r_i (in key).
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key).
"""
U = utterance.size(0)
num_chunks = math.ceil(U / self.chunk_length)
right_context_mask = []
utterance_mask = []
summary_mask = []
if self.use_memory:
num_cols = 9
@ -1397,9 +1321,6 @@ class EmformerEncoder(nn.Module):
right_context_utterance_cols_mask = [
idx in [1, 4, 7] for idx in range(num_cols)
]
# summary attends to right context, utterance
summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
else:
num_cols = 6
# right context and utterance both attend to right context and
@ -1407,7 +1328,6 @@ class EmformerEncoder(nn.Module):
right_context_utterance_cols_mask = [
idx in [1, 4] for idx in range(num_cols)
]
summary_cols_mask = None
masks_to_concat = [right_context_mask, utterance_mask]
for chunk_idx in range(num_chunks):
@ -1432,12 +1352,6 @@ class EmformerEncoder(nn.Module):
)
utterance_mask.append(utterance_mask_block)
if summary_cols_mask is not None:
summary_mask_block = _gen_attention_mask_block(
col_widths, summary_cols_mask, 1, utterance.device
)
summary_mask.append(summary_mask_block)
attention_mask = (
1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
).to(torch.bool)
@ -1473,23 +1387,15 @@ class EmformerEncoder(nn.Module):
utterance = x[:U]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
attention_mask = self._gen_attention_mask(utterance)
memory = (
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
:-1
]
if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device)
)
padding_mask = make_pad_mask(
memory.size(0) + right_context.size(0) + output_lengths
)
M = right_context.size(0) // self.chunk_length - 1
padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
output = utterance
for layer in self.emformer_layers:
output, right_context, memory = layer(
output, right_context = layer(
output,
right_context,
memory,
attention_mask,
padding_mask=padding_mask,
warmup=warmup,
@ -1525,7 +1431,6 @@ class EmformerEncoder(nn.Module):
right_context at the end.
states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
Cached states containing:
- past_lens: number of past frames for each sample in batch
- attn_caches: attention states from preceding chunk's computation,
where each element corresponds to each emformer layer
- conv_caches: left context for causal convolution, where each
@ -1571,11 +1476,6 @@ class EmformerEncoder(nn.Module):
right_context = x[-self.right_context_length :]
utterance = x[: -self.right_context_length]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
memory = (
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device)
)
# calcualte padding mask to mask out initial zero caches
chunk_mask = make_pad_mask(output_lengths).to(x.device)
@ -1611,13 +1511,11 @@ class EmformerEncoder(nn.Module):
(
output,
right_context,
memory,
output_attn_cache,
output_conv_cache,
) = layer.infer(
output,
right_context,
memory,
padding_mask=padding_mask,
attn_cache=attn_caches[layer_idx],
conv_cache=conv_caches[layer_idx],