Mirror of https://github.com/k2-fsa/icefall.git

commit 193b44ed7a (parent 1c067e7364)

    use average value as memory vector for each chunk
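At a glance, the change replaces the learned summary/memory path of the Emformer attention with a simpler scheme: each encoder layer now builds its memory bank directly as the per-chunk average of its own input, reusing the existing `summary_op` pooling, while the separate `summary` inputs, the encoder-level `init_memory_op`, and the memory output of the attention module are dropped. The sketch below only illustrates the pooling pattern the diff relies on; it assumes `summary_op` is an `nn.AvgPool1d` with `kernel_size = stride = chunk_length` and `ceil_mode=True`, matching the `init_memory_op` that this commit removes from `EmformerEncoder`.

# --- illustrative sketch, not part of the diff ---
import torch
import torch.nn as nn

chunk_length, U, B, D = 4, 12, 2, 8               # hypothetical sizes
summary_op = nn.AvgPool1d(
    kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
)

utterance = torch.randn(U, B, D)                  # (U, B, D), time-first
# AvgPool1d pools over the last dimension, so move time to the end first.
pooled = summary_op(utterance.permute(1, 2, 0))   # (B, D, num_chunks)
chunk_averages = pooled.permute(2, 0, 1)          # (num_chunks, B, D)
print(chunk_averages.shape)                       # torch.Size([3, 2, 8])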
@@ -537,7 +537,6 @@ class EmformerAttention(nn.Module):
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
@@ -550,10 +549,8 @@ class EmformerAttention(nn.Module):
         M = memory.size(0)
         scaling = float(self.head_dim) ** -0.5

-        # compute query with [right_context, utterance, summary].
-        query = self.emb_to_query(
-            torch.cat([right_context, utterance, summary])
-        )
+        # compute query with [right_context, utterance].
+        query = self.emb_to_query(torch.cat([right_context, utterance]))
         # compute key and value with [memory, right_context, utterance].
         key, value = self.emb_to_key_value(
             torch.cat([memory, right_context, utterance])
@@ -593,26 +590,18 @@ class EmformerAttention(nn.Module):
         )

         # apply output projection
-        outputs = self.out_proj(attention)
+        output_right_context_utterance = self.out_proj(attention)

-        output_right_context_utterance = outputs[: R + U]
-        output_memory = outputs[R + U :]
-        if self.tanh_on_mem:
-            output_memory = torch.tanh(output_memory)
-        else:
-            output_memory = torch.clamp(output_memory, min=-10, max=10)
-
-        return output_right_context_utterance, output_memory, key, value
+        return output_right_context_utterance, key, value

     def forward(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         # TODO: Modify docs.
         """Forward pass for training and validation mode.

@@ -620,17 +609,16 @@ class EmformerAttention(nn.Module):
         D: embedding dimension;
         R: length of the hard-copied right contexts;
         U: length of full utterance;
-        S: length of summary vectors;
         M: length of memory vectors.

         It computes a `big` attention matrix on full utterance and
         then utilizes a pre-computed mask to simulate chunk-wise attention.

         It concatenates three blocks: hard-copied right contexts,
-        full utterance, and summary vectors, as a `big` block,
+        and full utterance, as a `big` block,
         to compute the query tensor:
-        query = [right_context, utterance, summary],
-        with length Q = R + U + S.
+        query = [right_context, utterance],
+        with length Q = R + U.
         It concatenates the three blocks: memory vectors,
         hard-copied right contexts, and full utterance as another `big` block,
         to compute the key and value tensors:
@@ -644,10 +632,8 @@ class EmformerAttention(nn.Module):
         r_i: right context that c_i can use;
         l_i: left context that c_i can use;
         m_i: past memory vectors from previous layer that c_i can use;
-        s_i: summary vector of c_i;
         The target chunk-wise attention is:
-        c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
-        s_i (in query) -> l_i, c_i, r_i (in key).
+        c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key)

         Args:
           utterance (torch.Tensor):
@@ -655,9 +641,6 @@ class EmformerAttention(nn.Module):
           right_context (torch.Tensor):
             Hard-copied right context frames, with shape (R, B, D),
             where R = num_chunks * right_context_length
-          summary (torch.Tensor):
-            Summary elements with shape (S, B, D), where S = num_chunks.
-            It is an empty tensor without using memory.
           memory (torch.Tensor):
             Memory elements, with shape (M, B, D), where M = num_chunks - 1.
             It is an empty tensor without using memory.
@@ -668,31 +651,22 @@ class EmformerAttention(nn.Module):
             Padding mask of key tensor, with shape (B, KV).

         Returns:
-          A tuple containing 2 tensors:
-            - output of right context and utterance, with shape (R + U, B, D).
-            - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
+          Output of right context and utterance, with shape (R + U, B, D).
         """
-        (
-            output_right_context_utterance,
-            output_memory,
-            _,
-            _,
-        ) = self._forward_impl(
+        output_right_context_utterance, _, _ = self._forward_impl(
             utterance,
             right_context,
-            summary,
             memory,
             attention_mask,
             padding_mask=padding_mask,
         )
-        return output_right_context_utterance, output_memory[:-1]
+        return output_right_context_utterance

     @torch.jit.export
     def infer(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        summary: torch.Tensor,
         memory: torch.Tensor,
         left_context_key: torch.Tensor,
         left_context_val: torch.Tensor,
@@ -705,13 +679,12 @@ class EmformerAttention(nn.Module):
         R: length of right context;
         U: length of utterance, i.e., current chunk;
         L: length of cached left context;
-        S: length of summary vectors, S = 1;
         M: length of cached memory vectors.

-        It concatenates the right context, utterance (i.e., current chunk)
-        and summary vector of current chunk, to compute the query tensor:
-        query = [right_context, utterance, summary],
-        with length Q = R + U + S.
+        It concatenates the right context and utterance (i.e., current chunk)
+        of current chunk, to compute the query tensor:
+        query = [right_context, utterance],
+        with length Q = R + U.
         It concatenates the memory vectors, right context, left context, and
         current chunk, to compute the key and value tensors:
         key & value = [memory, right_context, left_context, utterance],
@@ -719,8 +692,7 @@ class EmformerAttention(nn.Module):

         The chunk-wise attention is:
         chunk, right context (in query) ->
-        left context, chunk, right context, memory vectors (in key);
-        summary (in query) -> left context, chunk, right context (in key).
+        left context, chunk, right context, memory vectors (in key).

         Args:
           utterance (torch.Tensor):
@@ -728,8 +700,6 @@ class EmformerAttention(nn.Module):
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D),
             where R = right_context_length.
-          summary (torch.Tensor):
-            Summary vector with shape (1, B, D), or empty tensor.
          memory (torch.Tensor):
            Memory vectors, with shape (M, B, D), or empty tensor.
          left_context_key (torch,Tensor):
@@ -744,7 +714,6 @@ class EmformerAttention(nn.Module):
         Returns:
           A tuple containing 4 tensors:
             - output of right context and utterance, with shape (R + U, B, D).
-            - memory output, with shape (1, B, D) or (0, B, D).
             - attention key of left context and utterance, which would be cached
               for next computation, with shape (L + U, B, D).
             - attention value of left context and utterance, which would be
@@ -753,28 +722,19 @@ class EmformerAttention(nn.Module):
         U = utterance.size(0)
         R = right_context.size(0)
         L = left_context_key.size(0)
-        S = summary.size(0)
         M = memory.size(0)

-        # TODO: move it outside
-        # query = [right context, utterance, summary]
-        Q = R + U + S
+        # query = [right context, utterance]
+        Q = R + U
         # key, value = [memory, right context, left context, uttrance]
         KV = M + R + L + U
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
         )
-        # disallow attention bettween the summary vector with the memory bank
-        attention_mask[-1, :M] = True
-        (
-            output_right_context_utterance,
-            output_memory,
-            key,
-            value,
-        ) = self._forward_impl(
+        output_right_context_utterance, key, value = self._forward_impl(
             utterance,
             right_context,
-            summary,
             memory,
             attention_mask,
             padding_mask=padding_mask,
@@ -783,7 +743,6 @@ class EmformerAttention(nn.Module):
         )
         return (
             output_right_context_utterance,
-            output_memory,
             key[M + R :],
             value[M + R :],
         )
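In the streaming `infer` path above, the query is now just `[right_context, utterance]`, so the extra mask row that used to block the summary query from the memory bank is gone and the precomputed attention mask stays all False (nothing is masked in the code shown). A minimal sketch of the length bookkeeping with made-up sizes, including the cached key/value slices that `infer` returns for the next chunk:

# --- illustrative sketch, not part of the diff ---
import torch

R, U, L, M, B, D = 2, 4, 8, 3, 1, 16       # hypothetical per-chunk sizes

Q = R + U                                  # query = [right_context, utterance]
KV = M + R + L + U                         # key/value = [memory, rc, lc, utterance]

# With the summary query removed, nothing needs to be masked out here.
attention_mask = torch.zeros(Q, KV, dtype=torch.bool)

# The returned key/value drop the memory and right-context prefix, keeping
# only left context + utterance for the next chunk's cache.
key = torch.randn(KV, B, D)
cached_key = key[M + R :]                  # shape (L + U, B, D)
print(Q, KV, cached_key.shape)             # 6 17 torch.Size([12, 1, 16])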
@@ -938,49 +897,46 @@ class EmformerEncoderLayer(nn.Module):
         self,
         right_context_utterance: torch.Tensor,
         R: int,
-        memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         """Apply attention module in training and validation mode."""
         utterance = right_context_utterance[R:]
         right_context = right_context_utterance[:R]

         if self.use_memory:
-            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+            memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
                 2, 0, 1
-            )
+            )[:-1, :, :]
         else:
-            summary = torch.empty(0).to(
+            memory = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory = self.attention(
+        output_right_context_utterance = self.attention(
             utterance=utterance,
             right_context=right_context,
-            summary=summary,
             memory=memory,
             attention_mask=attention_mask,
             padding_mask=padding_mask,
         )

-        return output_right_context_utterance, output_memory
+        return output_right_context_utterance

     def _apply_attention_module_infer(
         self,
         right_context_utterance: torch.Tensor,
         R: int,
-        memory: torch.Tensor,
         attn_cache: List[torch.Tensor],
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Apply attention module in inference mode.
         1) Unpack cached states including:
-           - memory from previous chunks in the lower layer;
+           - memory from previous chunks;
            - attention key and value of left context from preceding
              chunk's compuation;
         2) Apply attention computation;
         3) Update cached attention states including:
-           - output memory of current chunk in the lower layer;
+           - memory of current chunk;
            - attention key and value in current chunk's computation, which would
             be resued in next chunk's computation.
         """
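A note on the two slicings used above: in training/validation the whole utterance is visible, so the layer averages every chunk and keeps all but the last average (`[:-1, :, :]`), which matches the documented M = num_chunks - 1 memory vectors that later chunks may attend to; in streaming inference only the current chunk is visible, so its single average (`[:1, :, :]`) is what gets written into the attention cache for use by later chunks. A rough sketch under assumed sizes, reusing the pooling assumption from the sketch near the top:

# --- illustrative sketch, not part of the diff ---
import torch
import torch.nn as nn

chunk_length, B, D = 4, 2, 8               # hypothetical sizes
summary_op = nn.AvgPool1d(chunk_length, stride=chunk_length, ceil_mode=True)

def chunk_averages(utterance: torch.Tensor) -> torch.Tensor:
    # (U, B, D) -> (num_chunks, B, D)
    return summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)

# Training/validation: averages of all chunks except the last one.
full_utterance = torch.randn(3 * chunk_length, B, D)
training_memory = chunk_averages(full_utterance)[:-1, :, :]   # (2, B, D)

# Streaming inference: just the current chunk's average, to be cached.
current_chunk = torch.randn(chunk_length, B, D)
infer_memory = chunk_averages(current_chunk)[:1, :, :]        # (1, B, D)
print(training_memory.shape, infer_memory.shape)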
@@ -992,23 +948,20 @@ class EmformerEncoderLayer(nn.Module):
         left_context_val = attn_cache[2]

         if self.use_memory:
-            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+            memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
                 2, 0, 1
-            )
-            summary = summary[:1]
+            )[:1, :, :]
         else:
-            summary = torch.empty(0).to(
+            memory = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
         (
             output_right_context_utterance,
-            output_memory,
             next_key,
             next_val,
         ) = self.attention.infer(
             utterance=utterance,
             right_context=right_context,
-            summary=summary,
             memory=pre_memory,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
@@ -1017,17 +970,16 @@ class EmformerEncoderLayer(nn.Module):
         attn_cache = self._update_attn_cache(
             next_key, next_val, memory, attn_cache
         )
-        return output_right_context_utterance, output_memory, attn_cache
+        return output_right_context_utterance, attn_cache

     def forward(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""Forward pass for training and validation mode.

         B: batch size;
@@ -1041,20 +993,16 @@ class EmformerEncoderLayer(nn.Module):
             Utterance frames, with shape (U, B, D).
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D).
-          memory (torch.Tensor):
-            Memory elements, with shape (M, B, D).
-            It is an empty tensor without using memory.
           attention_mask (torch.Tensor):
             Attention mask for underlying attention module,
-            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
+            with shape (Q, KV), where Q = R + U, KV = M + R + U.
           padding_mask (torch.Tensor):
             Padding mask of ker tensor, with shape (B, KV).

         Returns:
-          A tuple containing 3 tensors:
+          A tuple containing 2 tensors:
             - output utterance, with shape (U, B, D).
             - output right context, with shape (R, B, D).
-            - output memory, with shape (M, B, D).
         """
         R = right_context.size(0)
         src = torch.cat([right_context, utterance])
@@ -1076,8 +1024,8 @@ class EmformerEncoderLayer(nn.Module):
         src = src + self.dropout(self.feed_forward_macaron(src))

         # emformer attention module
-        src_att, output_memory = self._apply_attention_module_forward(
-            src, R, memory, attention_mask, padding_mask=padding_mask
+        src_att = self._apply_attention_module_forward(
+            src, R, attention_mask, padding_mask=padding_mask
         )
         src = src + self.dropout(src_att)

@@ -1095,24 +1043,17 @@ class EmformerEncoderLayer(nn.Module):

         output_utterance = src[R:]
         output_right_context = src[:R]
-        return output_utterance, output_right_context, output_memory
+        return output_utterance, output_right_context

     @torch.jit.export
     def infer(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        memory: torch.Tensor,
         attn_cache: List[torch.Tensor],
         conv_cache: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[
-        torch.Tensor,
-        torch.Tensor,
-        torch.Tensor,
-        List[torch.Tensor],
-        torch.Tensor,
-    ]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
         """Forward pass for inference.

         B: batch size;
@@ -1126,8 +1067,6 @@ class EmformerEncoderLayer(nn.Module):
             Utterance frames, with shape (U, B, D).
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D).
-          memory (torch.Tensor):
-            Memory elements, with shape (M, B, D).
           attn_cache (List[torch.Tensor]):
             Cached attention tensors generated in preceding computation,
             including memory, key and value of left context.
@@ -1140,9 +1079,8 @@ class EmformerEncoderLayer(nn.Module):
           (Tensor, Tensor, List[torch.Tensor], Tensor):
             - output utterance, with shape (U, B, D);
             - output right_context, with shape (R, B, D);
-            - output memory, with shape (1, B, D) or (0, B, D).
-            - output state.
-            - updated conv_cache.
+            - output attention cache;
+            - output convolution cache.
         """
         R = right_context.size(0)
         src = torch.cat([right_context, utterance])
@@ -1151,12 +1089,8 @@ class EmformerEncoderLayer(nn.Module):
         src = src + self.dropout(self.feed_forward_macaron(src))

         # emformer attention module
-        (
-            src_att,
-            output_memory,
-            attn_cache,
-        ) = self._apply_attention_module_infer(
-            src, R, memory, attn_cache, padding_mask=padding_mask
+        src_att, attn_cache = self._apply_attention_module_infer(
+            src, R, attn_cache, padding_mask=padding_mask
         )
         src = src + self.dropout(src_att)

@@ -1174,7 +1108,6 @@ class EmformerEncoderLayer(nn.Module):
         return (
             output_utterance,
             output_right_context,
-            output_memory,
             attn_cache,
             conv_cache,
         )
@@ -1253,11 +1186,6 @@ class EmformerEncoder(nn.Module):
         super().__init__()

         self.use_memory = memory_size > 0
-        self.init_memory_op = nn.AvgPool1d(
-            kernel_size=chunk_length,
-            stride=chunk_length,
-            ceil_mode=True,
-        )

         self.emformer_layers = nn.ModuleList(
             [
@@ -1358,16 +1286,15 @@ class EmformerEncoder(nn.Module):

         R: length of hard-copied right contexts;
         U: length of full utterance;
-        S: length of summary vectors;
         M: length of memory vectors;
         Q: length of attention query;
         KV: length of attention key and value.

         The shape of attention mask is (Q, KV).
         If self.use_memory is `True`:
-          query = [right_context, utterance, summary];
+          query = [right_context, utterance];
           key, value = [memory, right_context, utterance];
-          Q = R + U + S, KV = M + R + U.
+          Q = R + U, KV = M + R + U.
         Otherwise:
           query = [right_context, utterance]
           key, value = [right_context, utterance]
@@ -1378,17 +1305,14 @@ class EmformerEncoder(nn.Module):
         r_i: right context that c_i can use;
         l_i: left context that c_i can use;
         m_i: past memory vectors from previous layer that c_i can use;
-        s_i: summary vector of c_i.
         The target chunk-wise attention is:
-        c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
-        s_i (in query) -> l_i, c_i, r_i (in key).
+        c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key).
         """
         U = utterance.size(0)
         num_chunks = math.ceil(U / self.chunk_length)

         right_context_mask = []
         utterance_mask = []
-        summary_mask = []

         if self.use_memory:
             num_cols = 9
@@ -1397,9 +1321,6 @@ class EmformerEncoder(nn.Module):
             right_context_utterance_cols_mask = [
                 idx in [1, 4, 7] for idx in range(num_cols)
             ]
-            # summary attends to right context, utterance
-            summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
-            masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
         else:
             num_cols = 6
             # right context and utterance both attend to right context and
@@ -1407,7 +1328,6 @@ class EmformerEncoder(nn.Module):
             right_context_utterance_cols_mask = [
                 idx in [1, 4] for idx in range(num_cols)
             ]
-            summary_cols_mask = None
             masks_to_concat = [right_context_mask, utterance_mask]

         for chunk_idx in range(num_chunks):
@@ -1432,12 +1352,6 @@ class EmformerEncoder(nn.Module):
             )
             utterance_mask.append(utterance_mask_block)

-            if summary_cols_mask is not None:
-                summary_mask_block = _gen_attention_mask_block(
-                    col_widths, summary_cols_mask, 1, utterance.device
-                )
-                summary_mask.append(summary_mask_block)
-
         attention_mask = (
             1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
         ).to(torch.bool)
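With the summary rows removed, the training-time mask assembled above stacks only the right-context block and the utterance block, so its shape follows the updated docstring: Q = R + U rows and KV = M + R + U columns. A rough shape check under assumed sizes (the `_gen_attention_mask_block` details are elided here):

# --- illustrative sketch, not part of the diff ---
import math

chunk_length, right_context_length, U = 4, 2, 12   # hypothetical sizes
num_chunks = math.ceil(U / chunk_length)

R = num_chunks * right_context_length   # hard-copied right contexts
M = num_chunks - 1                      # per-chunk averages of past chunks

Q = R + U                               # query = [right_context, utterance]
KV = M + R + U                          # key/value = [memory, rc, utterance]
print(Q, KV)                            # 18 20 -> attention_mask is (18, 20)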
@@ -1473,23 +1387,15 @@ class EmformerEncoder(nn.Module):
         utterance = x[:U]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
-        memory = (
-            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
-                :-1
-            ]
-            if self.use_memory
-            else torch.empty(0).to(dtype=x.dtype, device=x.device)
-        )
-        padding_mask = make_pad_mask(
-            memory.size(0) + right_context.size(0) + output_lengths
-        )
+        M = right_context.size(0) // self.chunk_length - 1
+        padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)

         output = utterance
         for layer in self.emformer_layers:
-            output, right_context, memory = layer(
+            output, right_context = layer(
                 output,
                 right_context,
-                memory,
                 attention_mask,
                 padding_mask=padding_mask,
                 warmup=warmup,
@@ -1525,7 +1431,6 @@ class EmformerEncoder(nn.Module):
             right_context at the end.
           states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
             Cached states containing:
-            - past_lens: number of past frames for each sample in batch
             - attn_caches: attention states from preceding chunk's computation,
               where each element corresponds to each emformer layer
             - conv_caches: left context for causal convolution, where each
@@ -1571,11 +1476,6 @@ class EmformerEncoder(nn.Module):
         right_context = x[-self.right_context_length :]
         utterance = x[: -self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
-        memory = (
-            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
-            if self.use_memory
-            else torch.empty(0).to(dtype=x.dtype, device=x.device)
-        )

         # calcualte padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
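With the encoder-level memory precompute above removed, each layer now derives its memory from its current chunk inside `_apply_attention_module_infer` and keeps it in its own `attn_cache`. The sketch below spells out the per-layer cache layout this implies; the index assignments are an assumption inferred from `left_context_val = attn_cache[2]` and the `_update_attn_cache(next_key, next_val, memory, attn_cache)` call shown earlier, not something the diff states directly:

# --- illustrative sketch, not part of the diff; cache indices are assumed ---
from typing import List, Tuple

import torch

def unpack_attn_cache(
    attn_cache: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    pre_memory = attn_cache[0]        # averages of earlier chunks (assumed index)
    left_context_key = attn_cache[1]  # cached key of left context (assumed index)
    left_context_val = attn_cache[2]  # matches attn_cache[2] in the diff
    return pre_memory, left_context_key, left_context_val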
@@ -1611,13 +1511,11 @@ class EmformerEncoder(nn.Module):
             (
                 output,
                 right_context,
-                memory,
                 output_attn_cache,
                 output_conv_cache,
             ) = layer.infer(
                 output,
                 right_context,
-                memory,
                 padding_mask=padding_mask,
                 attn_cache=attn_caches[layer_idx],
                 conv_cache=conv_caches[layer_idx],