mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
use average value as memory vector for each chunk
This commit is contained in:
parent
1c067e7364
commit
193b44ed7a
@ -537,7 +537,6 @@ class EmformerAttention(nn.Module):
|
||||
self,
|
||||
utterance: torch.Tensor,
|
||||
right_context: torch.Tensor,
|
||||
summary: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
@ -550,10 +549,8 @@ class EmformerAttention(nn.Module):
|
||||
M = memory.size(0)
|
||||
scaling = float(self.head_dim) ** -0.5
|
||||
|
||||
# compute query with [right_context, utterance, summary].
|
||||
query = self.emb_to_query(
|
||||
torch.cat([right_context, utterance, summary])
|
||||
)
|
||||
# compute query with [right_context, utterance].
|
||||
query = self.emb_to_query(torch.cat([right_context, utterance]))
|
||||
# compute key and value with [memory, right_context, utterance].
|
||||
key, value = self.emb_to_key_value(
|
||||
torch.cat([memory, right_context, utterance])
|
||||
@ -593,26 +590,18 @@ class EmformerAttention(nn.Module):
|
||||
)
|
||||
|
||||
# apply output projection
|
||||
outputs = self.out_proj(attention)
|
||||
output_right_context_utterance = self.out_proj(attention)
|
||||
|
||||
output_right_context_utterance = outputs[: R + U]
|
||||
output_memory = outputs[R + U :]
|
||||
if self.tanh_on_mem:
|
||||
output_memory = torch.tanh(output_memory)
|
||||
else:
|
||||
output_memory = torch.clamp(output_memory, min=-10, max=10)
|
||||
|
||||
return output_right_context_utterance, output_memory, key, value
|
||||
return output_right_context_utterance, key, value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
utterance: torch.Tensor,
|
||||
right_context: torch.Tensor,
|
||||
summary: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
# TODO: Modify docs.
|
||||
"""Forward pass for training and validation mode.
|
||||
|
||||
@ -620,17 +609,16 @@ class EmformerAttention(nn.Module):
|
||||
D: embedding dimension;
|
||||
R: length of the hard-copied right contexts;
|
||||
U: length of full utterance;
|
||||
S: length of summary vectors;
|
||||
M: length of memory vectors.
|
||||
|
||||
It computes a `big` attention matrix on full utterance and
|
||||
then utilizes a pre-computed mask to simulate chunk-wise attention.
|
||||
|
||||
It concatenates three blocks: hard-copied right contexts,
|
||||
full utterance, and summary vectors, as a `big` block,
|
||||
and full utterance, as a `big` block,
|
||||
to compute the query tensor:
|
||||
query = [right_context, utterance, summary],
|
||||
with length Q = R + U + S.
|
||||
query = [right_context, utterance],
|
||||
with length Q = R + U.
|
||||
It concatenates the three blocks: memory vectors,
|
||||
hard-copied right contexts, and full utterance as another `big` block,
|
||||
to compute the key and value tensors:
|
||||
@ -644,10 +632,8 @@ class EmformerAttention(nn.Module):
|
||||
r_i: right context that c_i can use;
|
||||
l_i: left context that c_i can use;
|
||||
m_i: past memory vectors from previous layer that c_i can use;
|
||||
s_i: summary vector of c_i;
|
||||
The target chunk-wise attention is:
|
||||
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
|
||||
s_i (in query) -> l_i, c_i, r_i (in key).
|
||||
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key)
|
||||
|
||||
Args:
|
||||
utterance (torch.Tensor):
|
||||
@ -655,9 +641,6 @@ class EmformerAttention(nn.Module):
|
||||
right_context (torch.Tensor):
|
||||
Hard-copied right context frames, with shape (R, B, D),
|
||||
where R = num_chunks * right_context_length
|
||||
summary (torch.Tensor):
|
||||
Summary elements with shape (S, B, D), where S = num_chunks.
|
||||
It is an empty tensor without using memory.
|
||||
memory (torch.Tensor):
|
||||
Memory elements, with shape (M, B, D), where M = num_chunks - 1.
|
||||
It is an empty tensor without using memory.
|
||||
@ -668,31 +651,22 @@ class EmformerAttention(nn.Module):
|
||||
Padding mask of key tensor, with shape (B, KV).
|
||||
|
||||
Returns:
|
||||
A tuple containing 2 tensors:
|
||||
- output of right context and utterance, with shape (R + U, B, D).
|
||||
- memory output, with shape (M, B, D), where M = S - 1 or M = 0.
|
||||
Output of right context and utterance, with shape (R + U, B, D).
|
||||
"""
|
||||
(
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
_,
|
||||
_,
|
||||
) = self._forward_impl(
|
||||
output_right_context_utterance, _, _ = self._forward_impl(
|
||||
utterance,
|
||||
right_context,
|
||||
summary,
|
||||
memory,
|
||||
attention_mask,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
return output_right_context_utterance, output_memory[:-1]
|
||||
return output_right_context_utterance
|
||||
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
utterance: torch.Tensor,
|
||||
right_context: torch.Tensor,
|
||||
summary: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
left_context_key: torch.Tensor,
|
||||
left_context_val: torch.Tensor,
|
||||
@ -705,13 +679,12 @@ class EmformerAttention(nn.Module):
|
||||
R: length of right context;
|
||||
U: length of utterance, i.e., current chunk;
|
||||
L: length of cached left context;
|
||||
S: length of summary vectors, S = 1;
|
||||
M: length of cached memory vectors.
|
||||
|
||||
It concatenates the right context, utterance (i.e., current chunk)
|
||||
and summary vector of current chunk, to compute the query tensor:
|
||||
query = [right_context, utterance, summary],
|
||||
with length Q = R + U + S.
|
||||
It concatenates the right context and utterance (i.e., current chunk)
|
||||
of current chunk, to compute the query tensor:
|
||||
query = [right_context, utterance],
|
||||
with length Q = R + U.
|
||||
It concatenates the memory vectors, right context, left context, and
|
||||
current chunk, to compute the key and value tensors:
|
||||
key & value = [memory, right_context, left_context, utterance],
|
||||
@ -719,8 +692,7 @@ class EmformerAttention(nn.Module):
|
||||
|
||||
The chunk-wise attention is:
|
||||
chunk, right context (in query) ->
|
||||
left context, chunk, right context, memory vectors (in key);
|
||||
summary (in query) -> left context, chunk, right context (in key).
|
||||
left context, chunk, right context, memory vectors (in key).
|
||||
|
||||
Args:
|
||||
utterance (torch.Tensor):
|
||||
@ -728,8 +700,6 @@ class EmformerAttention(nn.Module):
|
||||
right_context (torch.Tensor):
|
||||
Right context frames, with shape (R, B, D),
|
||||
where R = right_context_length.
|
||||
summary (torch.Tensor):
|
||||
Summary vector with shape (1, B, D), or empty tensor.
|
||||
memory (torch.Tensor):
|
||||
Memory vectors, with shape (M, B, D), or empty tensor.
|
||||
left_context_key (torch,Tensor):
|
||||
@ -744,7 +714,6 @@ class EmformerAttention(nn.Module):
|
||||
Returns:
|
||||
A tuple containing 4 tensors:
|
||||
- output of right context and utterance, with shape (R + U, B, D).
|
||||
- memory output, with shape (1, B, D) or (0, B, D).
|
||||
- attention key of left context and utterance, which would be cached
|
||||
for next computation, with shape (L + U, B, D).
|
||||
- attention value of left context and utterance, which would be
|
||||
@ -753,28 +722,19 @@ class EmformerAttention(nn.Module):
|
||||
U = utterance.size(0)
|
||||
R = right_context.size(0)
|
||||
L = left_context_key.size(0)
|
||||
S = summary.size(0)
|
||||
M = memory.size(0)
|
||||
|
||||
# TODO: move it outside
|
||||
# query = [right context, utterance, summary]
|
||||
Q = R + U + S
|
||||
# query = [right context, utterance]
|
||||
Q = R + U
|
||||
# key, value = [memory, right context, left context, uttrance]
|
||||
KV = M + R + L + U
|
||||
attention_mask = torch.zeros(Q, KV).to(
|
||||
dtype=torch.bool, device=utterance.device
|
||||
)
|
||||
# disallow attention bettween the summary vector with the memory bank
|
||||
attention_mask[-1, :M] = True
|
||||
(
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
key,
|
||||
value,
|
||||
) = self._forward_impl(
|
||||
|
||||
output_right_context_utterance, key, value = self._forward_impl(
|
||||
utterance,
|
||||
right_context,
|
||||
summary,
|
||||
memory,
|
||||
attention_mask,
|
||||
padding_mask=padding_mask,
|
||||
@ -783,7 +743,6 @@ class EmformerAttention(nn.Module):
|
||||
)
|
||||
return (
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
key[M + R :],
|
||||
value[M + R :],
|
||||
)
|
||||
@ -938,49 +897,46 @@ class EmformerEncoderLayer(nn.Module):
|
||||
self,
|
||||
right_context_utterance: torch.Tensor,
|
||||
R: int,
|
||||
memory: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
"""Apply attention module in training and validation mode."""
|
||||
utterance = right_context_utterance[R:]
|
||||
right_context = right_context_utterance[:R]
|
||||
|
||||
if self.use_memory:
|
||||
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
|
||||
memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
|
||||
2, 0, 1
|
||||
)
|
||||
)[:-1, :, :]
|
||||
else:
|
||||
summary = torch.empty(0).to(
|
||||
memory = torch.empty(0).to(
|
||||
dtype=utterance.dtype, device=utterance.device
|
||||
)
|
||||
output_right_context_utterance, output_memory = self.attention(
|
||||
output_right_context_utterance = self.attention(
|
||||
utterance=utterance,
|
||||
right_context=right_context,
|
||||
summary=summary,
|
||||
memory=memory,
|
||||
attention_mask=attention_mask,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
return output_right_context_utterance, output_memory
|
||||
return output_right_context_utterance
|
||||
|
||||
def _apply_attention_module_infer(
|
||||
self,
|
||||
right_context_utterance: torch.Tensor,
|
||||
R: int,
|
||||
memory: torch.Tensor,
|
||||
attn_cache: List[torch.Tensor],
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Apply attention module in inference mode.
|
||||
1) Unpack cached states including:
|
||||
- memory from previous chunks in the lower layer;
|
||||
- memory from previous chunks;
|
||||
- attention key and value of left context from preceding
|
||||
chunk's compuation;
|
||||
2) Apply attention computation;
|
||||
3) Update cached attention states including:
|
||||
- output memory of current chunk in the lower layer;
|
||||
- memory of current chunk;
|
||||
- attention key and value in current chunk's computation, which would
|
||||
be resued in next chunk's computation.
|
||||
"""
|
||||
@ -992,23 +948,20 @@ class EmformerEncoderLayer(nn.Module):
|
||||
left_context_val = attn_cache[2]
|
||||
|
||||
if self.use_memory:
|
||||
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
|
||||
memory = self.summary_op(utterance.permute(1, 2, 0)).permute(
|
||||
2, 0, 1
|
||||
)
|
||||
summary = summary[:1]
|
||||
)[:1, :, :]
|
||||
else:
|
||||
summary = torch.empty(0).to(
|
||||
memory = torch.empty(0).to(
|
||||
dtype=utterance.dtype, device=utterance.device
|
||||
)
|
||||
(
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
next_key,
|
||||
next_val,
|
||||
) = self.attention.infer(
|
||||
utterance=utterance,
|
||||
right_context=right_context,
|
||||
summary=summary,
|
||||
memory=pre_memory,
|
||||
left_context_key=left_context_key,
|
||||
left_context_val=left_context_val,
|
||||
@ -1017,17 +970,16 @@ class EmformerEncoderLayer(nn.Module):
|
||||
attn_cache = self._update_attn_cache(
|
||||
next_key, next_val, memory, attn_cache
|
||||
)
|
||||
return output_right_context_utterance, output_memory, attn_cache
|
||||
return output_right_context_utterance, attn_cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
utterance: torch.Tensor,
|
||||
right_context: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""Forward pass for training and validation mode.
|
||||
|
||||
B: batch size;
|
||||
@ -1041,20 +993,16 @@ class EmformerEncoderLayer(nn.Module):
|
||||
Utterance frames, with shape (U, B, D).
|
||||
right_context (torch.Tensor):
|
||||
Right context frames, with shape (R, B, D).
|
||||
memory (torch.Tensor):
|
||||
Memory elements, with shape (M, B, D).
|
||||
It is an empty tensor without using memory.
|
||||
attention_mask (torch.Tensor):
|
||||
Attention mask for underlying attention module,
|
||||
with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
|
||||
with shape (Q, KV), where Q = R + U, KV = M + R + U.
|
||||
padding_mask (torch.Tensor):
|
||||
Padding mask of ker tensor, with shape (B, KV).
|
||||
|
||||
Returns:
|
||||
A tuple containing 3 tensors:
|
||||
A tuple containing 2 tensors:
|
||||
- output utterance, with shape (U, B, D).
|
||||
- output right context, with shape (R, B, D).
|
||||
- output memory, with shape (M, B, D).
|
||||
"""
|
||||
R = right_context.size(0)
|
||||
src = torch.cat([right_context, utterance])
|
||||
@ -1076,8 +1024,8 @@ class EmformerEncoderLayer(nn.Module):
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# emformer attention module
|
||||
src_att, output_memory = self._apply_attention_module_forward(
|
||||
src, R, memory, attention_mask, padding_mask=padding_mask
|
||||
src_att = self._apply_attention_module_forward(
|
||||
src, R, attention_mask, padding_mask=padding_mask
|
||||
)
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
@ -1095,24 +1043,17 @@ class EmformerEncoderLayer(nn.Module):
|
||||
|
||||
output_utterance = src[R:]
|
||||
output_right_context = src[:R]
|
||||
return output_utterance, output_right_context, output_memory
|
||||
return output_utterance, output_right_context
|
||||
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
utterance: torch.Tensor,
|
||||
right_context: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
attn_cache: List[torch.Tensor],
|
||||
conv_cache: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
List[torch.Tensor],
|
||||
torch.Tensor,
|
||||
]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
|
||||
"""Forward pass for inference.
|
||||
|
||||
B: batch size;
|
||||
@ -1126,8 +1067,6 @@ class EmformerEncoderLayer(nn.Module):
|
||||
Utterance frames, with shape (U, B, D).
|
||||
right_context (torch.Tensor):
|
||||
Right context frames, with shape (R, B, D).
|
||||
memory (torch.Tensor):
|
||||
Memory elements, with shape (M, B, D).
|
||||
attn_cache (List[torch.Tensor]):
|
||||
Cached attention tensors generated in preceding computation,
|
||||
including memory, key and value of left context.
|
||||
@ -1140,9 +1079,8 @@ class EmformerEncoderLayer(nn.Module):
|
||||
(Tensor, Tensor, List[torch.Tensor], Tensor):
|
||||
- output utterance, with shape (U, B, D);
|
||||
- output right_context, with shape (R, B, D);
|
||||
- output memory, with shape (1, B, D) or (0, B, D).
|
||||
- output state.
|
||||
- updated conv_cache.
|
||||
- output attention cache;
|
||||
- output convolution cache.
|
||||
"""
|
||||
R = right_context.size(0)
|
||||
src = torch.cat([right_context, utterance])
|
||||
@ -1151,12 +1089,8 @@ class EmformerEncoderLayer(nn.Module):
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# emformer attention module
|
||||
(
|
||||
src_att,
|
||||
output_memory,
|
||||
attn_cache,
|
||||
) = self._apply_attention_module_infer(
|
||||
src, R, memory, attn_cache, padding_mask=padding_mask
|
||||
src_att, attn_cache = self._apply_attention_module_infer(
|
||||
src, R, attn_cache, padding_mask=padding_mask
|
||||
)
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
@ -1174,7 +1108,6 @@ class EmformerEncoderLayer(nn.Module):
|
||||
return (
|
||||
output_utterance,
|
||||
output_right_context,
|
||||
output_memory,
|
||||
attn_cache,
|
||||
conv_cache,
|
||||
)
|
||||
@ -1253,11 +1186,6 @@ class EmformerEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.use_memory = memory_size > 0
|
||||
self.init_memory_op = nn.AvgPool1d(
|
||||
kernel_size=chunk_length,
|
||||
stride=chunk_length,
|
||||
ceil_mode=True,
|
||||
)
|
||||
|
||||
self.emformer_layers = nn.ModuleList(
|
||||
[
|
||||
@ -1358,16 +1286,15 @@ class EmformerEncoder(nn.Module):
|
||||
|
||||
R: length of hard-copied right contexts;
|
||||
U: length of full utterance;
|
||||
S: length of summary vectors;
|
||||
M: length of memory vectors;
|
||||
Q: length of attention query;
|
||||
KV: length of attention key and value.
|
||||
|
||||
The shape of attention mask is (Q, KV).
|
||||
If self.use_memory is `True`:
|
||||
query = [right_context, utterance, summary];
|
||||
query = [right_context, utterance];
|
||||
key, value = [memory, right_context, utterance];
|
||||
Q = R + U + S, KV = M + R + U.
|
||||
Q = R + U, KV = M + R + U.
|
||||
Otherwise:
|
||||
query = [right_context, utterance]
|
||||
key, value = [right_context, utterance]
|
||||
@ -1378,17 +1305,14 @@ class EmformerEncoder(nn.Module):
|
||||
r_i: right context that c_i can use;
|
||||
l_i: left context that c_i can use;
|
||||
m_i: past memory vectors from previous layer that c_i can use;
|
||||
s_i: summary vector of c_i.
|
||||
The target chunk-wise attention is:
|
||||
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
|
||||
s_i (in query) -> l_i, c_i, r_i (in key).
|
||||
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key).
|
||||
"""
|
||||
U = utterance.size(0)
|
||||
num_chunks = math.ceil(U / self.chunk_length)
|
||||
|
||||
right_context_mask = []
|
||||
utterance_mask = []
|
||||
summary_mask = []
|
||||
|
||||
if self.use_memory:
|
||||
num_cols = 9
|
||||
@ -1397,9 +1321,6 @@ class EmformerEncoder(nn.Module):
|
||||
right_context_utterance_cols_mask = [
|
||||
idx in [1, 4, 7] for idx in range(num_cols)
|
||||
]
|
||||
# summary attends to right context, utterance
|
||||
summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
|
||||
masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
|
||||
else:
|
||||
num_cols = 6
|
||||
# right context and utterance both attend to right context and
|
||||
@ -1407,8 +1328,7 @@ class EmformerEncoder(nn.Module):
|
||||
right_context_utterance_cols_mask = [
|
||||
idx in [1, 4] for idx in range(num_cols)
|
||||
]
|
||||
summary_cols_mask = None
|
||||
masks_to_concat = [right_context_mask, utterance_mask]
|
||||
masks_to_concat = [right_context_mask, utterance_mask]
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
col_widths = self._gen_attention_mask_col_widths(chunk_idx, U)
|
||||
@ -1432,12 +1352,6 @@ class EmformerEncoder(nn.Module):
|
||||
)
|
||||
utterance_mask.append(utterance_mask_block)
|
||||
|
||||
if summary_cols_mask is not None:
|
||||
summary_mask_block = _gen_attention_mask_block(
|
||||
col_widths, summary_cols_mask, 1, utterance.device
|
||||
)
|
||||
summary_mask.append(summary_mask_block)
|
||||
|
||||
attention_mask = (
|
||||
1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
|
||||
).to(torch.bool)
|
||||
@ -1473,23 +1387,15 @@ class EmformerEncoder(nn.Module):
|
||||
utterance = x[:U]
|
||||
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
||||
attention_mask = self._gen_attention_mask(utterance)
|
||||
memory = (
|
||||
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
|
||||
:-1
|
||||
]
|
||||
if self.use_memory
|
||||
else torch.empty(0).to(dtype=x.dtype, device=x.device)
|
||||
)
|
||||
padding_mask = make_pad_mask(
|
||||
memory.size(0) + right_context.size(0) + output_lengths
|
||||
)
|
||||
|
||||
M = right_context.size(0) // self.chunk_length - 1
|
||||
padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
|
||||
|
||||
output = utterance
|
||||
for layer in self.emformer_layers:
|
||||
output, right_context, memory = layer(
|
||||
output, right_context = layer(
|
||||
output,
|
||||
right_context,
|
||||
memory,
|
||||
attention_mask,
|
||||
padding_mask=padding_mask,
|
||||
warmup=warmup,
|
||||
@ -1525,7 +1431,6 @@ class EmformerEncoder(nn.Module):
|
||||
right_context at the end.
|
||||
states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
|
||||
Cached states containing:
|
||||
- past_lens: number of past frames for each sample in batch
|
||||
- attn_caches: attention states from preceding chunk's computation,
|
||||
where each element corresponds to each emformer layer
|
||||
- conv_caches: left context for causal convolution, where each
|
||||
@ -1571,11 +1476,6 @@ class EmformerEncoder(nn.Module):
|
||||
right_context = x[-self.right_context_length :]
|
||||
utterance = x[: -self.right_context_length]
|
||||
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
||||
memory = (
|
||||
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
|
||||
if self.use_memory
|
||||
else torch.empty(0).to(dtype=x.dtype, device=x.device)
|
||||
)
|
||||
|
||||
# calcualte padding mask to mask out initial zero caches
|
||||
chunk_mask = make_pad_mask(output_lengths).to(x.device)
|
||||
@ -1611,13 +1511,11 @@ class EmformerEncoder(nn.Module):
|
||||
(
|
||||
output,
|
||||
right_context,
|
||||
memory,
|
||||
output_attn_cache,
|
||||
output_conv_cache,
|
||||
) = layer.infer(
|
||||
output,
|
||||
right_context,
|
||||
memory,
|
||||
padding_mask=padding_mask,
|
||||
attn_cache=attn_caches[layer_idx],
|
||||
conv_cache=conv_caches[layer_idx],
|
||||
|
Loading…
x
Reference in New Issue
Block a user