use average value as memory vector for each chunk

yaozengwei 2022-06-13 22:14:24 +08:00
parent 1c067e7364
commit 193b44ed7a
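
In short: the layers no longer produce memory vectors through the attention output projection (with an optional tanh); each layer now takes the plain per-chunk average of its utterance as the memory bank, and the summary-vector query rows disappear. A minimal sketch of the per-chunk averaging, assuming the layer's summary_op is an nn.AvgPool1d configured like the removed init_memory_op (kernel_size = stride = chunk_length, ceil_mode=True); the sizes are toy values for illustration:

import torch
import torch.nn as nn

# Toy sizes: a 10-frame utterance split into chunks of 4 frames.
U, B, D, chunk_length = 10, 2, 8, 4
utterance = torch.randn(U, B, D)  # time-first layout (U, B, D), as in the diff

# Assumed to match the removed init_memory_op / the layer's summary_op.
summary_op = nn.AvgPool1d(
    kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
)

# (U, B, D) -> (B, D, U) -> one average per chunk -> (num_chunks, B, D)
memory = summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
print(memory.shape)  # torch.Size([3, 2, 8]); ceil(10 / 4) = 3 chunk averages

The slicing applied in the hunks below decides which averages are visible: training keeps all but the last chunk's average ([:-1]), so a chunk only attends to averages of strictly earlier chunks, while streaming inference keeps only the current chunk's average ([:1]).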


@@ -537,7 +537,6 @@ class EmformerAttention(nn.Module):
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
@@ -550,10 +549,8 @@ class EmformerAttention(nn.Module):
         M = memory.size(0)
         scaling = float(self.head_dim) ** -0.5
-        # compute query with [right_context, utterance, summary].
-        query = self.emb_to_query(
-            torch.cat([right_context, utterance, summary])
-        )
+        # compute query with [right_context, utterance].
+        query = self.emb_to_query(torch.cat([right_context, utterance]))
         # compute key and value with [memory, right_context, utterance].
         key, value = self.emb_to_key_value(
             torch.cat([memory, right_context, utterance])
@@ -593,26 +590,18 @@ class EmformerAttention(nn.Module):
         )
         # apply output projection
-        outputs = self.out_proj(attention)
-        output_right_context_utterance = outputs[: R + U]
-        output_memory = outputs[R + U :]
-        if self.tanh_on_mem:
-            output_memory = torch.tanh(output_memory)
-        else:
-            output_memory = torch.clamp(output_memory, min=-10, max=10)
-        return output_right_context_utterance, output_memory, key, value
+        output_right_context_utterance = self.out_proj(attention)
+        return output_right_context_utterance, key, value

     def forward(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         # TODO: Modify docs.
         """Forward pass for training and validation mode.
@@ -620,17 +609,16 @@ class EmformerAttention(nn.Module):
         D: embedding dimension;
         R: length of the hard-copied right contexts;
         U: length of full utterance;
-        S: length of summary vectors;
         M: length of memory vectors.
         It computes a `big` attention matrix on full utterance and
         then utilizes a pre-computed mask to simulate chunk-wise attention.
         It concatenates three blocks: hard-copied right contexts,
-        full utterance, and summary vectors, as a `big` block,
+        full utterance, as a `big` block,
         to compute the query tensor:
-          query = [right_context, utterance, summary],
-        with length Q = R + U + S.
+          query = [right_context, utterance],
+        with length Q = R + U.
         It concatenates the three blocks: memory vectors,
         hard-copied right contexts, and full utterance as another `big` block,
         to compute the key and value tensors:
@ -644,10 +632,8 @@ class EmformerAttention(nn.Module):
r_i: right context that c_i can use; r_i: right context that c_i can use;
l_i: left context that c_i can use; l_i: left context that c_i can use;
m_i: past memory vectors from previous layer that c_i can use; m_i: past memory vectors from previous layer that c_i can use;
s_i: summary vector of c_i;
The target chunk-wise attention is: The target chunk-wise attention is:
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key); c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key)
s_i (in query) -> l_i, c_i, r_i (in key).
Args: Args:
utterance (torch.Tensor): utterance (torch.Tensor):
@@ -655,9 +641,6 @@ class EmformerAttention(nn.Module):
           right_context (torch.Tensor):
             Hard-copied right context frames, with shape (R, B, D),
             where R = num_chunks * right_context_length
-          summary (torch.Tensor):
-            Summary elements with shape (S, B, D), where S = num_chunks.
-            It is an empty tensor without using memory.
           memory (torch.Tensor):
             Memory elements, with shape (M, B, D), where M = num_chunks - 1.
             It is an empty tensor without using memory.
@@ -668,31 +651,22 @@ class EmformerAttention(nn.Module):
             Padding mask of key tensor, with shape (B, KV).

         Returns:
-          A tuple containing 2 tensors:
-            - output of right context and utterance, with shape (R + U, B, D).
-            - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
+          Output of right context and utterance, with shape (R + U, B, D).
         """
-        (
-            output_right_context_utterance,
-            output_memory,
-            _,
-            _,
-        ) = self._forward_impl(
+        output_right_context_utterance, _, _ = self._forward_impl(
             utterance,
             right_context,
-            summary,
             memory,
             attention_mask,
             padding_mask=padding_mask,
         )
-        return output_right_context_utterance, output_memory[:-1]
+        return output_right_context_utterance

     @torch.jit.export
     def infer(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        summary: torch.Tensor,
         memory: torch.Tensor,
         left_context_key: torch.Tensor,
         left_context_val: torch.Tensor,
@@ -705,13 +679,12 @@ class EmformerAttention(nn.Module):
         R: length of right context;
         U: length of utterance, i.e., current chunk;
         L: length of cached left context;
-        S: length of summary vectors, S = 1;
         M: length of cached memory vectors.
-        It concatenates the right context, utterance (i.e., current chunk)
-        and summary vector of current chunk, to compute the query tensor:
-          query = [right_context, utterance, summary],
-        with length Q = R + U + S.
+        It concatenates the right context and utterance (i.e., current chunk)
+        of current chunk, to compute the query tensor:
+          query = [right_context, utterance],
+        with length Q = R + U.
         It concatenates the memory vectors, right context, left context, and
         current chunk, to compute the key and value tensors:
           key & value = [memory, right_context, left_context, utterance],
@@ -719,8 +692,7 @@ class EmformerAttention(nn.Module):
         The chunk-wise attention is:
           chunk, right context (in query) ->
-            left context, chunk, right context, memory vectors (in key);
-          summary (in query) -> left context, chunk, right context (in key).
+            left context, chunk, right context, memory vectors (in key).

         Args:
           utterance (torch.Tensor):
@@ -728,8 +700,6 @@ class EmformerAttention(nn.Module):
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D),
             where R = right_context_length.
-          summary (torch.Tensor):
-            Summary vector with shape (1, B, D), or empty tensor.
           memory (torch.Tensor):
             Memory vectors, with shape (M, B, D), or empty tensor.
           left_context_key (torch.Tensor):
@@ -744,7 +714,6 @@ class EmformerAttention(nn.Module):
         Returns:
           A tuple containing 4 tensors:
             - output of right context and utterance, with shape (R + U, B, D).
-            - memory output, with shape (1, B, D) or (0, B, D).
            - attention key of left context and utterance, which would be cached
              for next computation, with shape (L + U, B, D).
            - attention value of left context and utterance, which would be
@@ -753,28 +722,19 @@ class EmformerAttention(nn.Module):
         U = utterance.size(0)
         R = right_context.size(0)
         L = left_context_key.size(0)
-        S = summary.size(0)
         M = memory.size(0)
-        # TODO: move it outside
-        # query = [right context, utterance, summary]
-        Q = R + U + S
+        # query = [right context, utterance]
+        Q = R + U
         # key, value = [memory, right context, left context, utterance]
         KV = M + R + L + U
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
         )
-        # disallow attention bettween the summary vector with the memory bank
-        attention_mask[-1, :M] = True
-        (
-            output_right_context_utterance,
-            output_memory,
-            key,
-            value,
-        ) = self._forward_impl(
+        output_right_context_utterance, key, value = self._forward_impl(
             utterance,
             right_context,
-            summary,
             memory,
             attention_mask,
             padding_mask=padding_mask,
@@ -783,7 +743,6 @@ class EmformerAttention(nn.Module):
         )
         return (
             output_right_context_utterance,
-            output_memory,
             key[M + R :],
             value[M + R :],
         )
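
With the memory output gone, infer() returns the attention output plus the key/value slices that get cached for the next chunk. A toy shape check of that slicing (lengths are assumed, purely for illustration): the full key/value covers [memory, right_context, left_context, utterance], so dropping the first M + R rows leaves exactly the [left_context, utterance] part of length L + U.

import torch

M, R, L, U, B, D = 3, 2, 8, 4, 1, 16   # toy lengths, assumed for illustration
# key/value stack: [memory, right_context, left_context, utterance]
key = torch.randn(M + R + L + U, B, D)
cached_key = key[M + R :]              # keep only [left_context, utterance]
assert cached_key.shape == (L + U, B, D)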
@@ -938,49 +897,46 @@ class EmformerEncoderLayer(nn.Module):
         self,
         right_context_utterance: torch.Tensor,
         R: int,
-        memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         """Apply attention module in training and validation mode."""
         utterance = right_context_utterance[R:]
         right_context = right_context_utterance[:R]
         if self.use_memory:
-            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+            memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
                 2, 0, 1
-            )
+            )[:-1, :, :]
         else:
-            summary = torch.empty(0).to(
+            memory = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory = self.attention(
+        output_right_context_utterance = self.attention(
             utterance=utterance,
             right_context=right_context,
-            summary=summary,
             memory=memory,
             attention_mask=attention_mask,
             padding_mask=padding_mask,
         )
-        return output_right_context_utterance, output_memory
+        return output_right_context_utterance

     def _apply_attention_module_infer(
         self,
         right_context_utterance: torch.Tensor,
         R: int,
-        memory: torch.Tensor,
         attn_cache: List[torch.Tensor],
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Apply attention module in inference mode.
         1) Unpack cached states including:
-           - memory from previous chunks in the lower layer;
+           - memory from previous chunks;
            - attention key and value of left context from preceding
              chunk's computation;
         2) Apply attention computation;
         3) Update cached attention states including:
-           - output memory of current chunk in the lower layer;
+           - memory of current chunk;
           - attention key and value in current chunk's computation, which would
             be reused in next chunk's computation.
        """
@@ -992,23 +948,20 @@ class EmformerEncoderLayer(nn.Module):
         left_context_val = attn_cache[2]
         if self.use_memory:
-            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+            memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
                 2, 0, 1
-            )
-            summary = summary[:1]
+            )[:1, :, :]
         else:
-            summary = torch.empty(0).to(
+            memory = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
         (
             output_right_context_utterance,
-            output_memory,
             next_key,
             next_val,
         ) = self.attention.infer(
             utterance=utterance,
             right_context=right_context,
-            summary=summary,
             memory=pre_memory,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
@@ -1017,17 +970,16 @@ class EmformerEncoderLayer(nn.Module):
         attn_cache = self._update_attn_cache(
             next_key, next_val, memory, attn_cache
         )
-        return output_right_context_utterance, output_memory, attn_cache
+        return output_right_context_utterance, attn_cache

     def forward(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        memory: torch.Tensor,
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""Forward pass for training and validation mode.

         B: batch size;
@@ -1041,20 +993,16 @@ class EmformerEncoderLayer(nn.Module):
             Utterance frames, with shape (U, B, D).
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D).
-          memory (torch.Tensor):
-            Memory elements, with shape (M, B, D).
-            It is an empty tensor without using memory.
           attention_mask (torch.Tensor):
             Attention mask for underlying attention module,
-            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
+            with shape (Q, KV), where Q = R + U, KV = M + R + U.
           padding_mask (torch.Tensor):
             Padding mask of key tensor, with shape (B, KV).

         Returns:
-          A tuple containing 3 tensors:
+          A tuple containing 2 tensors:
            - output utterance, with shape (U, B, D).
            - output right context, with shape (R, B, D).
-           - output memory, with shape (M, B, D).
        """
        R = right_context.size(0)
        src = torch.cat([right_context, utterance])
@@ -1076,8 +1024,8 @@ class EmformerEncoderLayer(nn.Module):
         src = src + self.dropout(self.feed_forward_macaron(src))

         # emformer attention module
-        src_att, output_memory = self._apply_attention_module_forward(
-            src, R, memory, attention_mask, padding_mask=padding_mask
+        src_att = self._apply_attention_module_forward(
+            src, R, attention_mask, padding_mask=padding_mask
         )
         src = src + self.dropout(src_att)
@@ -1095,24 +1043,17 @@ class EmformerEncoderLayer(nn.Module):
         output_utterance = src[R:]
         output_right_context = src[:R]
-        return output_utterance, output_right_context, output_memory
+        return output_utterance, output_right_context

     @torch.jit.export
     def infer(
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        memory: torch.Tensor,
         attn_cache: List[torch.Tensor],
         conv_cache: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[
-        torch.Tensor,
-        torch.Tensor,
-        torch.Tensor,
-        List[torch.Tensor],
-        torch.Tensor,
-    ]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
         """Forward pass for inference.

         B: batch size;
@@ -1126,8 +1067,6 @@ class EmformerEncoderLayer(nn.Module):
             Utterance frames, with shape (U, B, D).
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D).
-          memory (torch.Tensor):
-            Memory elements, with shape (M, B, D).
           attn_cache (List[torch.Tensor]):
             Cached attention tensors generated in preceding computation,
             including memory, key and value of left context.
@@ -1140,9 +1079,8 @@ class EmformerEncoderLayer(nn.Module):
           (Tensor, Tensor, List[torch.Tensor], Tensor):
             - output utterance, with shape (U, B, D);
             - output right_context, with shape (R, B, D);
-            - output memory, with shape (1, B, D) or (0, B, D).
-            - output state.
-            - updated conv_cache.
+            - output attention cache;
+            - output convolution cache.
        """
        R = right_context.size(0)
        src = torch.cat([right_context, utterance])
@@ -1151,12 +1089,8 @@ class EmformerEncoderLayer(nn.Module):
         src = src + self.dropout(self.feed_forward_macaron(src))

         # emformer attention module
-        (
-            src_att,
-            output_memory,
-            attn_cache,
-        ) = self._apply_attention_module_infer(
-            src, R, memory, attn_cache, padding_mask=padding_mask
+        src_att, attn_cache = self._apply_attention_module_infer(
+            src, R, attn_cache, padding_mask=padding_mask
         )
         src = src + self.dropout(src_att)
@@ -1174,7 +1108,6 @@ class EmformerEncoderLayer(nn.Module):
         return (
             output_utterance,
             output_right_context,
-            output_memory,
             attn_cache,
             conv_cache,
         )
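
During streaming, each layer's infer() now hands back the updated attention cache instead of a separate memory output. Judging from the unpacking earlier in the diff (attn_cache[2] is the cached left-context value) and the attn_cache description above, the per-layer cache is presumably a list of three tensors; the index layout and sizes below are assumptions, shown only to make the cached shapes concrete:

import torch

# Presumed per-layer attention cache (indices inferred, not confirmed by the diff):
#   attn_cache[0]: memory vectors from previous chunks,        (M, B, D)
#   attn_cache[1]: attention key of the cached left context,   (L, B, D)
#   attn_cache[2]: attention value of the cached left context, (L, B, D)
M, L, B, D = 4, 32, 1, 512  # toy sizes, not taken from any recipe config
attn_cache = [
    torch.zeros(M, B, D),
    torch.zeros(L, B, D),
    torch.zeros(L, B, D),
]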
@@ -1253,11 +1186,6 @@ class EmformerEncoder(nn.Module):
         super().__init__()

         self.use_memory = memory_size > 0
-        self.init_memory_op = nn.AvgPool1d(
-            kernel_size=chunk_length,
-            stride=chunk_length,
-            ceil_mode=True,
-        )

         self.emformer_layers = nn.ModuleList(
             [
@@ -1358,16 +1286,15 @@ class EmformerEncoder(nn.Module):
         R: length of hard-copied right contexts;
         U: length of full utterance;
-        S: length of summary vectors;
         M: length of memory vectors;
         Q: length of attention query;
         KV: length of attention key and value.

         The shape of attention mask is (Q, KV).
         If self.use_memory is `True`:
-          query = [right_context, utterance, summary];
+          query = [right_context, utterance];
           key, value = [memory, right_context, utterance];
-          Q = R + U + S, KV = M + R + U.
+          Q = R + U, KV = M + R + U.
         Otherwise:
           query = [right_context, utterance]
           key, value = [right_context, utterance]
@@ -1378,17 +1305,14 @@ class EmformerEncoder(nn.Module):
           r_i: right context that c_i can use;
           l_i: left context that c_i can use;
           m_i: past memory vectors from previous layer that c_i can use;
-          s_i: summary vector of c_i.
         The target chunk-wise attention is:
-          c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
-          s_i (in query) -> l_i, c_i, r_i (in key).
+          c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key).
        """
        U = utterance.size(0)
        num_chunks = math.ceil(U / self.chunk_length)

        right_context_mask = []
        utterance_mask = []
-        summary_mask = []

        if self.use_memory:
            num_cols = 9
@@ -1397,9 +1321,6 @@ class EmformerEncoder(nn.Module):
             right_context_utterance_cols_mask = [
                 idx in [1, 4, 7] for idx in range(num_cols)
             ]
-            # summary attends to right context, utterance
-            summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
-            masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
         else:
             num_cols = 6
             # right context and utterance both attend to right context and
@@ -1407,8 +1328,7 @@ class EmformerEncoder(nn.Module):
             right_context_utterance_cols_mask = [
                 idx in [1, 4] for idx in range(num_cols)
             ]
-            summary_cols_mask = None
-            masks_to_concat = [right_context_mask, utterance_mask]
+        masks_to_concat = [right_context_mask, utterance_mask]

         for chunk_idx in range(num_chunks):
             col_widths = self._gen_attention_mask_col_widths(chunk_idx, U)
@@ -1432,12 +1352,6 @@ class EmformerEncoder(nn.Module):
             )
             utterance_mask.append(utterance_mask_block)
-            if summary_cols_mask is not None:
-                summary_mask_block = _gen_attention_mask_block(
-                    col_widths, summary_cols_mask, 1, utterance.device
-                )
-                summary_mask.append(summary_mask_block)

         attention_mask = (
             1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
         ).to(torch.bool)
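
A toy illustration of the boolean convention used just above (it does not reproduce the real per-chunk column layout): the mask blocks are built with ones at the allowed key positions, and the final (1 - ...).to(torch.bool) flips them so that True means the position is masked out.

import torch

# 2 toy query rows, 4 toy key columns; 1.0 marks an allowed position.
allowed = torch.tensor(
    [
        [1.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 1.0, 0.0],
    ]
)
attention_mask = (1 - allowed).to(torch.bool)  # True = attention disallowed
print(attention_mask)
# tensor([[False, False,  True,  True],
#         [ True, False, False,  True]])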
@@ -1473,23 +1387,15 @@ class EmformerEncoder(nn.Module):
         utterance = x[:U]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
-        memory = (
-            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
-                :-1
-            ]
-            if self.use_memory
-            else torch.empty(0).to(dtype=x.dtype, device=x.device)
-        )
-        padding_mask = make_pad_mask(
-            memory.size(0) + right_context.size(0) + output_lengths
-        )
+        M = right_context.size(0) // self.chunk_length - 1
+        padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)

         output = utterance
         for layer in self.emformer_layers:
-            output, right_context, memory = layer(
+            output, right_context = layer(
                 output,
                 right_context,
-                memory,
                 attention_mask,
                 padding_mask=padding_mask,
                 warmup=warmup,
@@ -1525,7 +1431,6 @@ class EmformerEncoder(nn.Module):
             right_context at the end.
           states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
             Cached states containing:
-            - past_lens: number of past frames for each sample in batch
            - attn_caches: attention states from preceding chunk's computation,
              where each element corresponds to each emformer layer
            - conv_caches: left context for causal convolution, where each
@@ -1571,11 +1476,6 @@ class EmformerEncoder(nn.Module):
         right_context = x[-self.right_context_length :]
         utterance = x[: -self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
-        memory = (
-            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
-            if self.use_memory
-            else torch.empty(0).to(dtype=x.dtype, device=x.device)
-        )

         # calculate padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
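
The padding masks in this file come from icefall's make_pad_mask, which is imported rather than defined in the diff; a hedged sketch of its assumed behavior (a (B, max_len) boolean mask with True at positions beyond each sequence's valid length):

import torch

def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # Assumed semantics: True marks padded (invalid) positions.
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)

print(make_pad_mask_sketch(torch.tensor([3, 1])))
# tensor([[False, False, False],
#         [False,  True,  True]])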
@@ -1611,13 +1511,11 @@ class EmformerEncoder(nn.Module):
             (
                 output,
                 right_context,
-                memory,
                 output_attn_cache,
                 output_conv_cache,
             ) = layer.infer(
                 output,
                 right_context,
-                memory,
                 padding_mask=padding_mask,
                 attn_cache=attn_caches[layer_idx],
                 conv_cache=conv_caches[layer_idx],