refactor, use fixed-length cache for batch decoding

yaozengwei 2022-06-06 21:19:25 +08:00
parent 10998bef69
commit 13899dff51


@ -200,7 +200,6 @@ class ConvolutionModule(nn.Module):
self, self,
utterance: torch.Tensor, utterance: torch.Tensor,
right_context: torch.Tensor, right_context: torch.Tensor,
cache: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Causal convolution module. """Causal convolution module.
@ -209,14 +208,11 @@ class ConvolutionModule(nn.Module):
Utterance tensor of shape (U, B, D). Utterance tensor of shape (U, B, D).
right_context (torch.Tensor): right_context (torch.Tensor):
Right context tensor of shape (R, B, D). Right context tensor of shape (R, B, D).
cache (torch.Tensor, optional):
Cached tensor for left padding of shape (B, D, cache_size).
Returns: Returns:
A tuple of 3 tensors: A tuple of 2 tensors:
- output utterance of shape (U, B, D). - output utterance of shape (U, B, D).
- output right_context of shape (R, B, D). - output right_context of shape (R, B, D).
- updated cache tensor of shape (B, D, cache_size).
""" """
U, B, D = utterance.size() U, B, D = utterance.size()
R, _, _ = right_context.size() R, _, _ = right_context.size()
@ -230,17 +226,13 @@ class ConvolutionModule(nn.Module):
utterance = x[:, :, R:] # (B, D, U) utterance = x[:, :, R:] # (B, D, U)
right_context = x[:, :, :R] # (B, D, R) right_context = x[:, :, :R] # (B, D, R)
if cache is None: # make causal convolution
cache = torch.zeros( cache = torch.zeros(
B, D, self.cache_size, device=x.device, dtype=x.dtype B, D, self.cache_size, device=x.device, dtype=x.dtype
) )
else:
assert cache.shape == (B, D, self.cache_size), cache.shape
pad_utterance = torch.cat( pad_utterance = torch.cat(
[cache, utterance], dim=2 [cache, utterance], dim=2
) # (B, D, cache + U) ) # (B, D, cache + U)
# update cache
new_cache = pad_utterance[:, :, -self.cache_size :]
# depth-wise conv on utterance # depth-wise conv on utterance
utterance = self.depthwise_conv(pad_utterance) # (B, D, U) utterance = self.depthwise_conv(pad_utterance) # (B, D, U)
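A minimal standalone sketch of the fixed-length cache idea behind this change, assuming cache_size = kernel_size - 1 for a causal depthwise Conv1d (all names below are illustrative, not the module's own):

import torch
import torch.nn as nn

B, D, U = 2, 8, 16
kernel_size = 5
cache_size = kernel_size - 1  # assumed left context needed by the causal conv

depthwise_conv = nn.Conv1d(D, D, kernel_size, groups=D)  # no built-in padding

# training / validation: the cache is all zeros, i.e. plain causal left padding
cache = torch.zeros(B, D, cache_size)
utterance = torch.randn(B, D, U)
out = depthwise_conv(torch.cat([cache, utterance], dim=2))  # (B, D, U)

# streaming inference: keep the last cache_size input frames for the next chunk;
# since the cache length never changes, caches of different streams can be batched.
new_cache = torch.cat([cache, utterance], dim=2)[:, :, -cache_size:]
assert new_cache.shape == (B, D, cache_size)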
@ -269,7 +261,6 @@ class ConvolutionModule(nn.Module):
return ( return (
utterance.permute(2, 0, 1), utterance.permute(2, 0, 1),
right_context.permute(2, 0, 1), right_context.permute(2, 0, 1),
new_cache,
) )
def infer( def infer(
@ -304,11 +295,7 @@ class ConvolutionModule(nn.Module):
x = self.deriv_balancer1(x) x = self.deriv_balancer1(x)
x = nn.functional.glu(x, dim=1) # (B, D, U + R) x = nn.functional.glu(x, dim=1) # (B, D, U + R)
if cache is None: # make causal convolution
cache = torch.zeros(
B, D, self.cache_size, device=x.device, dtype=x.dtype
)
else:
assert cache.shape == (B, D, self.cache_size), cache.shape assert cache.shape == (B, D, self.cache_size), cache.shape
x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R) x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R)
# update cache # update cache
@ -383,7 +370,7 @@ class EmformerAttention(nn.Module):
self, self,
attention_weights: torch.Tensor, attention_weights: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor], padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Given the entire attention weights, mask out unecessary connections """Given the entire attention weights, mask out unecessary connections
and optionally with padding positions, to obtain underlying chunk-wise and optionally with padding positions, to obtain underlying chunk-wise
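As a hedged illustration of how such a padding mask can be combined with the chunk-wise attention mask (shapes assumed from the (B * nhead, Q, KV) comment below; this is only a sketch, not the repo's _gen_attention_probs):

import torch

B, nhead, Q, KV = 2, 4, 6, 10
attention_weights = torch.randn(B * nhead, Q, KV)
attention_mask = torch.zeros(Q, KV, dtype=torch.bool)  # True = disallowed connection
padding_mask = torch.zeros(B, KV, dtype=torch.bool)    # True = padded key position
padding_mask[1, -3:] = True                            # e.g. 3 padded frames in sample 1

weights = attention_weights.masked_fill(attention_mask.unsqueeze(0), float("-inf"))
weights = (
    weights.view(B, nhead, Q, KV)
    .masked_fill(padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
    .view(B * nhead, Q, KV)
)
attention_probs = torch.softmax(weights, dim=-1)  # padded keys get zero probability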
@ -438,11 +425,11 @@ class EmformerAttention(nn.Module):
def _forward_impl( def _forward_impl(
self, self,
utterance: torch.Tensor, utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor, right_context: torch.Tensor,
summary: torch.Tensor, summary: torch.Tensor,
memory: torch.Tensor, memory: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
left_context_key: Optional[torch.Tensor] = None, left_context_key: Optional[torch.Tensor] = None,
left_context_val: Optional[torch.Tensor] = None, left_context_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@ -470,7 +457,7 @@ class EmformerAttention(nn.Module):
[value[: M + R], left_context_val, value[M + R :]] [value[: M + R], left_context_val, value[M + R :]]
) )
Q = query.size(0) Q = query.size(0)
KV = key.size(0) # KV = key.size(0)
reshaped_query, reshaped_key, reshaped_value = [ reshaped_query, reshaped_key, reshaped_value = [
tensor.contiguous() tensor.contiguous()
@ -482,12 +469,6 @@ class EmformerAttention(nn.Module):
reshaped_query * scaling, reshaped_key.transpose(1, 2) reshaped_query * scaling, reshaped_key.transpose(1, 2)
) # (B * nhead, Q, KV) ) # (B * nhead, Q, KV)
# compute padding mask
if B == 1:
padding_mask = None
else:
padding_mask = make_pad_mask(KV - U + lengths)
# compute attention probabilities # compute attention probabilities
attention_probs = self._gen_attention_probs( attention_probs = self._gen_attention_probs(
attention_weights, attention_mask, padding_mask attention_weights, attention_mask, padding_mask
@ -515,11 +496,11 @@ class EmformerAttention(nn.Module):
def forward( def forward(
self, self,
utterance: torch.Tensor, utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor, right_context: torch.Tensor,
summary: torch.Tensor, summary: torch.Tensor,
memory: torch.Tensor, memory: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: Modify docs. # TODO: Modify docs.
"""Forward pass for training and validation mode. """Forward pass for training and validation mode.
@ -560,9 +541,6 @@ class EmformerAttention(nn.Module):
Args: Args:
utterance (torch.Tensor): utterance (torch.Tensor):
Full utterance frames, with shape (U, B, D). Full utterance frames, with shape (U, B, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing
number of valid frames for i-th batch element in utterance.
right_context (torch.Tensor): right_context (torch.Tensor):
Hard-copied right context frames, with shape (R, B, D), Hard-copied right context frames, with shape (R, B, D),
where R = num_chunks * right_context_length where R = num_chunks * right_context_length
@ -575,6 +553,8 @@ class EmformerAttention(nn.Module):
attention_mask (torch.Tensor): attention_mask (torch.Tensor):
Pre-computed attention mask to simulate underlying chunk-wise Pre-computed attention mask to simulate underlying chunk-wise
attention, with shape (Q, KV). attention, with shape (Q, KV).
padding_mask (torch.Tensor):
Padding mask of key tensor, with shape (B, KV).
Returns: Returns:
A tuple containing 2 tensors: A tuple containing 2 tensors:
@ -588,23 +568,23 @@ class EmformerAttention(nn.Module):
_, _,
) = self._forward_impl( ) = self._forward_impl(
utterance, utterance,
lengths,
right_context, right_context,
summary, summary,
memory, memory,
attention_mask, attention_mask,
padding_mask=padding_mask,
) )
return output_right_context_utterance, output_memory[:-1] return output_right_context_utterance, output_memory[:-1]
def infer( def infer(
self, self,
utterance: torch.Tensor, utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor, right_context: torch.Tensor,
summary: torch.Tensor, summary: torch.Tensor,
memory: torch.Tensor, memory: torch.Tensor,
left_context_key: torch.Tensor, left_context_key: torch.Tensor,
left_context_val: torch.Tensor, left_context_val: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward pass for inference. """Forward pass for inference.
@ -633,9 +613,6 @@ class EmformerAttention(nn.Module):
Args: Args:
utterance (torch.Tensor): utterance (torch.Tensor):
Current chunk frames, with shape (U, B, D), where U = chunk_length. Current chunk frames, with shape (U, B, D), where U = chunk_length.
lengths (torch.Tensor):
With shape (B,) and i-th element representing
number of valid frames for i-th batch element in utterance.
right_context (torch.Tensor): right_context (torch.Tensor):
Right context frames, with shape (R, B, D), Right context frames, with shape (R, B, D),
where R = right_context_length. where R = right_context_length.
@ -645,10 +622,12 @@ class EmformerAttention(nn.Module):
Memory vectors, with shape (M, B, D), or empty tensor. Memory vectors, with shape (M, B, D), or empty tensor.
left_context_key (torch.Tensor): left_context_key (torch.Tensor):
Cached attention key of left context from preceding computation, Cached attention key of left context from preceding computation,
with shape (L, B, D), where L <= left_context_length. with shape (L, B, D).
left_context_val (torch.Tensor): left_context_val (torch.Tensor):
Cached attention value of left context from preceding computation, Cached attention value of left context from preceding computation,
with shape (L, B, D), where L <= left_context_length. with shape (L, B, D).
padding_mask (torch.Tensor):
Padding mask of key tensor, with shape (B, KV).
Returns: Returns:
A tuple containing 4 tensors: A tuple containing 4 tensors:
@ -665,6 +644,7 @@ class EmformerAttention(nn.Module):
S = summary.size(0) S = summary.size(0)
M = memory.size(0) M = memory.size(0)
# TODO: move it outside
# query = [right context, utterance, summary] # query = [right context, utterance, summary]
Q = R + U + S Q = R + U + S
# key, value = [memory, right context, left context, utterance] # key, value = [memory, right context, left context, utterance]
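A worked size example with made-up numbers (the concrete values are assumptions; only the query/key compositions come from the comments above):

R, U, S = 4, 32, 1        # right-context, chunk and summary frames (assumed)
M, L = 4, 32              # memory slots and cached left-context frames (assumed)
Q = R + U + S             # query   = [right context, utterance, summary]        -> 37
KV = M + R + L + U        # key/val = [memory, right context, left ctx, utterance] -> 72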
@ -681,11 +661,11 @@ class EmformerAttention(nn.Module):
value, value,
) = self._forward_impl( ) = self._forward_impl(
utterance, utterance,
lengths,
right_context, right_context,
summary, summary,
memory, memory,
attention_mask, attention_mask,
padding_mask=padding_mask,
left_context_key=left_context_key, left_context_key=left_context_key,
left_context_val=left_context_val, left_context_val=left_context_val,
) )
@ -719,8 +699,8 @@ class EmformerEncoderLayer(nn.Module):
Length of left context. (Default: 0) Length of left context. (Default: 0)
right_context_length (int, optional): right_context_length (int, optional):
Length of right context. (Default: 0) Length of right context. (Default: 0)
max_memory_size (int, optional): memory_size (int, optional):
Maximum number of memory elements to use. (Default: 0) Number of memory elements to use. (Default: 0)
tanh_on_mem (bool, optional): tanh_on_mem (bool, optional):
If ``True``, applies tanh to memory elements. (Default: ``False``) If ``True``, applies tanh to memory elements. (Default: ``False``)
negative_inf (float, optional): negative_inf (float, optional):
@ -738,7 +718,7 @@ class EmformerEncoderLayer(nn.Module):
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
left_context_length: int = 0, left_context_length: int = 0,
right_context_length: int = 0, right_context_length: int = 0,
max_memory_size: int = 0, memory_size: int = 0,
tanh_on_mem: bool = False, tanh_on_mem: bool = False,
negative_inf: float = -1e8, negative_inf: float = -1e8,
): ):
@ -791,75 +771,29 @@ class EmformerEncoderLayer(nn.Module):
self.layer_dropout = layer_dropout self.layer_dropout = layer_dropout
self.left_context_length = left_context_length self.left_context_length = left_context_length
self.chunk_length = chunk_length self.chunk_length = chunk_length
self.max_memory_size = max_memory_size self.memory_size = memory_size
self.d_model = d_model self.d_model = d_model
self.use_memory = max_memory_size > 0 self.use_memory = memory_size > 0
def _init_state( def _update_attn_cache(
self, batch_size: int, device: Optional[torch.device]
) -> List[torch.Tensor]:
"""Initialize states with zeros."""
empty_memory = torch.zeros(
self.max_memory_size, batch_size, self.d_model, device=device
)
left_context_key = torch.zeros(
self.left_context_length, batch_size, self.d_model, device=device
)
left_context_val = torch.zeros(
self.left_context_length, batch_size, self.d_model, device=device
)
past_length = torch.zeros(
1, batch_size, dtype=torch.int32, device=device
)
return [empty_memory, left_context_key, left_context_val, past_length]
def _unpack_state(
self, state: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Unpack cached states including:
1) output memory from previous chunks in the lower layer;
2) attention key and value of left context from preceding chunk's
computation.
"""
past_length = state[3][0][0].item()
past_left_context_length = min(self.left_context_length, past_length)
past_memory_length = min(
self.max_memory_size, math.ceil(past_length / self.chunk_length)
)
memory_start_idx = self.max_memory_size - past_memory_length
pre_memory = state[0][memory_start_idx:]
left_context_start_idx = (
self.left_context_length - past_left_context_length
)
left_context_key = state[1][left_context_start_idx:]
left_context_val = state[2][left_context_start_idx:]
return pre_memory, left_context_key, left_context_val
def _pack_state(
self, self,
next_key: torch.Tensor, next_key: torch.Tensor,
next_val: torch.Tensor, next_val: torch.Tensor,
update_length: int,
memory: torch.Tensor, memory: torch.Tensor,
state: List[torch.Tensor], attn_cache: List[torch.Tensor],
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
"""Pack updated states including: """Update cached attention state:
1) output memory of current chunk in the lower layer; 1) output memory of current chunk in the lower layer;
2) attention key and value in current chunk's computation, which would 2) attention key and value in current chunk's computation, which would
be reused in next chunk's computation. be reused in next chunk's computation.
3) length of current chunk.
""" """
new_memory = torch.cat([state[0], memory]) new_memory = torch.cat([attn_cache[0], memory])
new_key = torch.cat([state[1], next_key]) new_key = torch.cat([attn_cache[1], next_key])
new_val = torch.cat([state[2], next_val]) new_val = torch.cat([attn_cache[2], next_val])
memory_start_idx = new_memory.size(0) - self.max_memory_size attn_cache[0] = new_memory[new_memory.size(0) - self.memory_size :]
state[0] = new_memory[memory_start_idx:] attn_cache[1] = new_key[new_key.size(0) - self.left_context_length :]
key_start_idx = new_key.size(0) - self.left_context_length attn_cache[2] = new_val[new_val.size(0) - self.left_context_length :]
state[1] = new_key[key_start_idx:] return attn_cache
val_start_idx = new_val.size(0) - self.left_context_length
state[2] = new_val[val_start_idx:]
state[3] = state[3] + update_length
return state
def _apply_conv_module_forward( def _apply_conv_module_forward(
self, self,
@ -869,7 +803,7 @@ class EmformerEncoderLayer(nn.Module):
"""Apply convolution module in training and validation mode.""" """Apply convolution module in training and validation mode."""
utterance = right_context_utterance[R:] utterance = right_context_utterance[R:]
right_context = right_context_utterance[:R] right_context = right_context_utterance[:R]
utterance, right_context, _ = self.conv_module(utterance, right_context) utterance, right_context = self.conv_module(utterance, right_context)
right_context_utterance = torch.cat([right_context, utterance]) right_context_utterance = torch.cat([right_context, utterance])
return right_context_utterance return right_context_utterance
@ -892,15 +826,11 @@ class EmformerEncoderLayer(nn.Module):
self, self,
right_context_utterance: torch.Tensor, right_context_utterance: torch.Tensor,
R: int, R: int,
lengths: torch.Tensor,
memory: torch.Tensor, memory: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply attention module in training and validation mode.""" """Apply attention module in training and validation mode."""
if attention_mask is None:
raise ValueError(
"attention_mask must be not None in training or validation mode." # noqa
)
utterance = right_context_utterance[R:] utterance = right_context_utterance[R:]
right_context = right_context_utterance[:R] right_context = right_context_utterance[:R]
@ -914,11 +844,11 @@ class EmformerEncoderLayer(nn.Module):
) )
output_right_context_utterance, output_memory = self.attention( output_right_context_utterance, output_memory = self.attention(
utterance=utterance, utterance=utterance,
lengths=lengths,
right_context=right_context, right_context=right_context,
summary=summary, summary=summary,
memory=memory, memory=memory,
attention_mask=attention_mask, attention_mask=attention_mask,
padding_mask=padding_mask,
) )
return output_right_context_utterance, output_memory return output_right_context_utterance, output_memory
@ -927,9 +857,9 @@ class EmformerEncoderLayer(nn.Module):
self, self,
right_context_utterance: torch.Tensor, right_context_utterance: torch.Tensor,
R: int, R: int,
lengths: torch.Tensor,
memory: torch.Tensor, memory: torch.Tensor,
state: Optional[List[torch.Tensor]] = None, attn_cache: List[torch.Tensor],
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""Apply attention module in inference mode. """Apply attention module in inference mode.
1) Unpack cached states including: 1) Unpack cached states including:
@ -937,7 +867,7 @@ class EmformerEncoderLayer(nn.Module):
- attention key and value of left context from preceding - attention key and value of left context from preceding
chunk's computation; chunk's computation;
2) Apply attention computation; 2) Apply attention computation;
3) Pack updated states including: 3) Update cached attention states including:
- output memory of current chunk in the lower layer; - output memory of current chunk in the lower layer;
- attention key and value in current chunk's computation, which would - attention key and value in current chunk's computation, which would
be reused in next chunk's computation. be reused in next chunk's computation.
@ -946,11 +876,10 @@ class EmformerEncoderLayer(nn.Module):
utterance = right_context_utterance[R:] utterance = right_context_utterance[R:]
right_context = right_context_utterance[:R] right_context = right_context_utterance[:R]
if state is None: pre_memory = attn_cache[0]
state = self._init_state(utterance.size(1), device=utterance.device) left_context_key = attn_cache[1]
pre_memory, left_context_key, left_context_val = self._unpack_state( left_context_val = attn_cache[2]
state
)
if self.use_memory: if self.use_memory:
summary = self.summary_op(utterance.permute(1, 2, 0)).permute( summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
2, 0, 1 2, 0, 1
@ -967,25 +896,25 @@ class EmformerEncoderLayer(nn.Module):
next_val, next_val,
) = self.attention.infer( ) = self.attention.infer(
utterance=utterance, utterance=utterance,
lengths=lengths,
right_context=right_context, right_context=right_context,
summary=summary, summary=summary,
memory=pre_memory, memory=pre_memory,
left_context_key=left_context_key, left_context_key=left_context_key,
left_context_val=left_context_val, left_context_val=left_context_val,
padding_mask=padding_mask,
) )
state = self._pack_state( attn_cache = self._update_attn_cache(
next_key, next_val, utterance.size(0), memory, state next_key, next_val, memory, attn_cache
) )
return output_right_context_utterance, output_memory, state return output_right_context_utterance, output_memory, attn_cache
def forward( def forward(
self, self,
utterance: torch.Tensor, utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor, right_context: torch.Tensor,
memory: torch.Tensor, memory: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
warmup: float = 1.0, warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Forward pass for training and validation mode. r"""Forward pass for training and validation mode.
@ -999,9 +928,6 @@ class EmformerEncoderLayer(nn.Module):
Args: Args:
utterance (torch.Tensor): utterance (torch.Tensor):
Utterance frames, with shape (U, B, D). Utterance frames, with shape (U, B, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing
number of valid frames for i-th batch element in utterance.
right_context (torch.Tensor): right_context (torch.Tensor):
Right context frames, with shape (R, B, D). Right context frames, with shape (R, B, D).
memory (torch.Tensor): memory (torch.Tensor):
@ -1010,6 +936,8 @@ class EmformerEncoderLayer(nn.Module):
attention_mask (torch.Tensor): attention_mask (torch.Tensor):
Attention mask for underlying attention module, Attention mask for underlying attention module,
with shape (Q, KV), where Q = R + U + S, KV = M + R + U. with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
padding_mask (torch.Tensor):
Padding mask of key tensor, with shape (B, KV).
Returns: Returns:
A tuple containing 3 tensors: A tuple containing 3 tensors:
@ -1038,7 +966,7 @@ class EmformerEncoderLayer(nn.Module):
# emformer attention module # emformer attention module
src_att, output_memory = self._apply_attention_module_forward( src_att, output_memory = self._apply_attention_module_forward(
src, R, lengths, memory, attention_mask src, R, memory, attention_mask, padding_mask=padding_mask
) )
src = src + self.dropout(src_att) src = src + self.dropout(src_att)
@ -1061,11 +989,11 @@ class EmformerEncoderLayer(nn.Module):
def infer( def infer(
self, self,
utterance: torch.Tensor, utterance: torch.Tensor,
lengths: torch.Tensor,
right_context: torch.Tensor, right_context: torch.Tensor,
memory: torch.Tensor, memory: torch.Tensor,
state: Optional[List[torch.Tensor]] = None, attn_cache: List[torch.Tensor],
conv_cache: Optional[torch.Tensor] = None, conv_cache: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""Forward pass for inference. """Forward pass for inference.
@ -1078,18 +1006,17 @@ class EmformerEncoderLayer(nn.Module):
Args: Args:
utterance (torch.Tensor): utterance (torch.Tensor):
Utterance frames, with shape (U, B, D). Utterance frames, with shape (U, B, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing
number of valid frames for i-th batch element in utterance.
right_context (torch.Tensor): right_context (torch.Tensor):
Right context frames, with shape (R, B, D). Right context frames, with shape (R, B, D).
memory (torch.Tensor): memory (torch.Tensor):
Memory elements, with shape (M, B, D). Memory elements, with shape (M, B, D).
state (List[torch.Tensor], optional): attn_cache (List[torch.Tensor]):
List of tensors representing layer internal state generated in Cached attention tensors generated in preceding computation,
preceding computation. (default=None) including memory, key and value of left context.
conv_cache (torch.Tensor, optional): conv_cache (torch.Tensor, optional):
Cache tensor of left context for causal convolution. Cache tensor of left context for causal convolution.
padding_mask (torch.Tensor):
Padding mask of key tensor.
Returns: Returns:
(Tensor, Tensor, List[torch.Tensor], Tensor): (Tensor, Tensor, List[torch.Tensor], Tensor):
@ -1109,8 +1036,10 @@ class EmformerEncoderLayer(nn.Module):
( (
src_att, src_att,
output_memory, output_memory,
output_state, attn_cache,
) = self._apply_attention_module_infer(src, R, lengths, memory, state) ) = self._apply_attention_module_infer(
src, R, memory, attn_cache, padding_mask=padding_mask
)
src = src + self.dropout(src_att) src = src + self.dropout(src_att)
# convolution module # convolution module
@ -1128,7 +1057,7 @@ class EmformerEncoderLayer(nn.Module):
output_utterance, output_utterance,
output_right_context, output_right_context,
output_memory, output_memory,
output_state, attn_cache,
conv_cache, conv_cache,
) )
@ -1179,8 +1108,8 @@ class EmformerEncoder(nn.Module):
Length of left context. (default: 0) Length of left context. (default: 0)
right_context_length (int, optional): right_context_length (int, optional):
Length of right context. (default: 0) Length of right context. (default: 0)
max_memory_size (int, optional): memory_size (int, optional):
Maximum number of memory elements to use. (default: 0) Number of memory elements to use. (default: 0)
tanh_on_mem (bool, optional): tanh_on_mem (bool, optional):
If ``true``, applies tanh to memory elements. (default: ``false``) If ``true``, applies tanh to memory elements. (default: ``false``)
negative_inf (float, optional): negative_inf (float, optional):
@ -1199,13 +1128,13 @@ class EmformerEncoder(nn.Module):
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
left_context_length: int = 0, left_context_length: int = 0,
right_context_length: int = 0, right_context_length: int = 0,
max_memory_size: int = 0, memory_size: int = 0,
tanh_on_mem: bool = False, tanh_on_mem: bool = False,
negative_inf: float = -1e8, negative_inf: float = -1e8,
): ):
super().__init__() super().__init__()
self.use_memory = max_memory_size > 0 self.use_memory = memory_size > 0
self.init_memory_op = nn.AvgPool1d( self.init_memory_op = nn.AvgPool1d(
kernel_size=chunk_length, kernel_size=chunk_length,
stride=chunk_length, stride=chunk_length,
@ -1224,7 +1153,7 @@ class EmformerEncoder(nn.Module):
cnn_module_kernel=cnn_module_kernel, cnn_module_kernel=cnn_module_kernel,
left_context_length=left_context_length, left_context_length=left_context_length,
right_context_length=right_context_length, right_context_length=right_context_length,
max_memory_size=max_memory_size, memory_size=memory_size,
tanh_on_mem=tanh_on_mem, tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf, negative_inf=negative_inf,
) )
@ -1232,10 +1161,13 @@ class EmformerEncoder(nn.Module):
] ]
) )
self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
self.left_context_length = left_context_length self.left_context_length = left_context_length
self.right_context_length = right_context_length self.right_context_length = right_context_length
self.chunk_length = chunk_length self.chunk_length = chunk_length
self.max_memory_size = max_memory_size self.memory_size = memory_size
self.cnn_module_kernel = cnn_module_kernel
def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
"""Hard copy each chunk's right context and concat them.""" """Hard copy each chunk's right context and concat them."""
@ -1276,7 +1208,7 @@ class EmformerEncoder(nn.Module):
R = rc * num_chunks R = rc * num_chunks
if self.use_memory: if self.use_memory:
m_start = max(chunk_idx - self.max_memory_size, 0) m_start = max(chunk_idx - self.memory_size, 0)
M = num_chunks - 1 M = num_chunks - 1
col_widths = [ col_widths = [
m_start, # before memory m_start, # before memory
@ -1430,15 +1362,18 @@ class EmformerEncoder(nn.Module):
if self.use_memory if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device) else torch.empty(0).to(dtype=x.dtype, device=x.device)
) )
padding_mask = make_pad_mask(
memory.size(0) + right_context.size(0) + output_lengths
)
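Here make_pad_mask is the helper already used by this file; assuming the usual convention that True marks padded positions, a minimal hypothetical equivalent would be:

import torch

def pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """True at positions beyond each sample's valid length (assumed semantics)."""
    max_len = int(lengths.max())
    return torch.arange(max_len, device=lengths.device).expand(
        lengths.size(0), max_len
    ) >= lengths.unsqueeze(1)

# keys are [memory, right_context, utterance]; memory and right context are always
# valid in training, so only the utterance tail of shorter samples is masked:
print(pad_mask(torch.tensor([5, 3]) + 2))  # e.g. M + R = 2 always-valid key frames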
output = utterance output = utterance
for layer in self.emformer_layers: for layer in self.emformer_layers:
output, right_context, memory = layer( output, right_context, memory = layer(
output, output,
output_lengths,
right_context, right_context,
memory, memory,
attention_mask, attention_mask,
padding_mask=padding_mask,
warmup=warmup, warmup=warmup,
) )
@ -1448,10 +1383,13 @@ class EmformerEncoder(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
lengths: torch.Tensor, lengths: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None, states: List[
conv_caches: Optional[List[torch.Tensor]] = None, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
],
) -> Tuple[ ) -> Tuple[
torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor] torch.Tensor,
torch.Tensor,
List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]],
]: ]:
"""Forward pass for streaming inference. """Forward pass for streaming inference.
@ -1467,13 +1405,13 @@ class EmformerEncoder(nn.Module):
With shape (B,) and i-th element representing number of valid With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the utterance frames for i-th batch element in x, which contains the
right_context at the end. right_context at the end.
states (List[List[torch.Tensor]], optional): states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]): # noqa
Cached states from preceding chunk's computation, where each Cached states containing:
element (List[torch.Tensor]) corresponds to each emformer layer. - past_lens: number of past frames for each sample in batch
(default: None) - attn_caches: attention states from preceding chunk's computation,
conv_caches (List[torch.Tensor], optional): where each element corresponds to each emformer layer
Cached tensors of left context for causal convolution, where each - conv_caches: left context for causal convolution, where each
element (Tensor) corresponds to each convolutional layer. element corresponds to each layer.
Returns: Returns:
(Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]): (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
@ -1481,8 +1419,38 @@ class EmformerEncoder(nn.Module):
- output lengths, with shape (B,), without containing the - output lengths, with shape (B,), without containing the
right_context at the end. right_context at the end.
- updated states from current chunk's computation. - updated states from current chunk's computation.
- updated convolution caches from current chunk.
""" """
past_lens = states[0]
assert past_lens.shape == (x.size(1),), past_lens.shape
attn_caches = states[1]
assert len(attn_caches) == self.num_encoder_layers, len(attn_caches)
for i in range(len(attn_caches)):
assert attn_caches[i][0].shape == (
self.memory_size,
x.size(1),
self.d_model,
), attn_caches[i][0].shape
assert attn_caches[i][1].shape == (
self.left_context_length,
x.size(1),
self.d_model,
), attn_caches[i][1].shape
assert attn_caches[i][2].shape == (
self.left_context_length,
x.size(1),
self.d_model,
), attn_caches[i][2].shape
conv_caches = states[2]
assert len(conv_caches) == self.num_encoder_layers, len(conv_caches)
for i in range(len(conv_caches)):
assert conv_caches[i].shape == (
x.size(1),
self.d_model,
self.cnn_module_kernel,
), conv_caches[i].shape
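For reference, an initial (empty) states object matching the asserts above could be built as below; all sizes and the past_lens dtype are placeholders:

import torch

B, d_model = 2, 512                         # placeholder sizes
num_encoder_layers, memory_size = 12, 4
left_context_length, cnn_module_kernel = 32, 31

init_states = [
    torch.zeros(B, dtype=torch.int32),      # past_lens: no frames seen yet
    [
        [
            torch.zeros(memory_size, B, d_model),          # cached memory
            torch.zeros(left_context_length, B, d_model),  # cached keys
            torch.zeros(left_context_length, B, d_model),  # cached values
        ]
        for _ in range(num_encoder_layers)
    ],                                      # attn_caches, one entry per layer
    [
        torch.zeros(B, d_model, cnn_module_kernel)         # conv left-context cache
        for _ in range(num_encoder_layers)
    ],                                      # conv_caches, one entry per layer
]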
assert x.size(0) == self.chunk_length + self.right_context_length, ( assert x.size(0) == self.chunk_length + self.right_context_length, (
"Per configured chunk_length and right_context_length, " "Per configured chunk_length and right_context_length, "
f"expected size of {self.chunk_length + self.right_context_length} " f"expected size of {self.chunk_length + self.right_context_length} "
@ -1498,28 +1466,60 @@ class EmformerEncoder(nn.Module):
if self.use_memory if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device) else torch.empty(0).to(dtype=x.dtype, device=x.device)
) )
# calculate padding mask
chunk_mask = make_pad_mask(output_lengths)
memory_mask = (
(past_lens // self.chunk_length).view(x.size(1), 1)
<= torch.arange(self.memory_size, device=x.device).expand(
x.size(1), self.memory_size
)
).flip(1)
left_context_mask = (
past_lens.view(x.size(1), 1)
<= torch.arange(self.left_context_length, device=x.device).expand(
x.size(1), self.left_context_length
)
).flip(1)
right_context_mask = torch.zeros(
x.size(1),
self.right_context_length,
dtype=torch.bool,
device=x.device,
)
padding_mask = torch.cat(
[memory_mask, left_context_mask, right_context_mask, chunk_mask],
dim=1,
)
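A tiny worked example (made-up numbers) of the cache-validity part of this mask; a cached slot is masked (True) when fewer past frames exist than that slot would require:

import torch

chunk_length, memory_size, left_context_length = 4, 3, 4
past_lens = torch.tensor([0, 6])   # sample 0: first chunk; sample 1: 6 past frames

memory_mask = (
    (past_lens // chunk_length).view(2, 1)
    <= torch.arange(memory_size).expand(2, memory_size)
).flip(1)
# sample 0 -> [True, True, True]  (no memory yet)
# sample 1 -> [True, True, False] (only the newest memory slot is valid)

left_context_mask = (
    past_lens.view(2, 1)
    <= torch.arange(left_context_length).expand(2, left_context_length)
).flip(1)
# sample 0 -> all True; sample 1 -> all False (all cached left-context frames valid)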
output = utterance output = utterance
output_states: List[List[torch.Tensor]] = [] output_attn_caches: List[List[torch.Tensor]] = []
output_conv_caches: List[torch.Tensor] = [] output_conv_caches: List[torch.Tensor] = []
for layer_idx, layer in enumerate(self.emformer_layers): for layer_idx, layer in enumerate(self.emformer_layers):
( (
output, output,
right_context, right_context,
memory, memory,
output_state, output_attn_cache,
output_conv_cache, output_conv_cache,
) = layer.infer( ) = layer.infer(
output, output,
output_lengths,
right_context, right_context,
memory, memory,
None if states is None else states[layer_idx], padding_mask=padding_mask,
None if conv_caches is None else conv_caches[layer_idx], attn_cache=attn_caches[layer_idx],
conv_cache=conv_caches[layer_idx],
) )
output_states.append(output_state) output_attn_caches.append(output_attn_cache)
output_conv_caches.append(output_conv_cache) output_conv_caches.append(output_conv_cache)
return output, output_lengths, output_states, output_conv_caches output_past_lens = past_lens + output_lengths
output_states = [
output_past_lens,
output_attn_caches,
output_conv_caches,
]
return output, output_lengths, output_states
class Emformer(EncoderInterface): class Emformer(EncoderInterface):
@ -1537,7 +1537,7 @@ class Emformer(EncoderInterface):
cnn_module_kernel: int = 3, cnn_module_kernel: int = 3,
left_context_length: int = 0, left_context_length: int = 0,
right_context_length: int = 0, right_context_length: int = 0,
max_memory_size: int = 0, memory_size: int = 0,
tanh_on_mem: bool = False, tanh_on_mem: bool = False,
negative_inf: float = -1e8, negative_inf: float = -1e8,
): ):
@ -1576,7 +1576,7 @@ class Emformer(EncoderInterface):
cnn_module_kernel=cnn_module_kernel, cnn_module_kernel=cnn_module_kernel,
left_context_length=left_context_length // 4, left_context_length=left_context_length // 4,
right_context_length=right_context_length // 4, right_context_length=right_context_length // 4,
max_memory_size=max_memory_size, memory_size=memory_size,
tanh_on_mem=tanh_on_mem, tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf, negative_inf=negative_inf,
) )
@ -1633,7 +1633,6 @@ class Emformer(EncoderInterface):
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None, states: Optional[List[List[torch.Tensor]]] = None,
conv_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
"""Forward pass for streaming inference. """Forward pass for streaming inference.
@ -1649,13 +1648,13 @@ class Emformer(EncoderInterface):
With shape (B,) and i-th element representing number of valid With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, containing the utterance frames for i-th batch element in x, containing the
right_context at the end. right_context at the end.
states (List[List[torch.Tensor]], optional): states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]): # noqa
Cached states from preceding chunk's computation, where each Cached states containing:
element (List[torch.Tensor]) corresponds to each emformer layer. - past_lens: number of past frames for each sample in batch
(default: None) - attn_caches: attention states from preceding chunk's computation,
conv_caches (List[torch.Tensor], optional): where each element corresponds to each emformer layer
Cached tensors of left context for causal convolution, where each - conv_caches: left context for causal convolution, where each
element (Tensor) corresponds to each convolutional layer. element corresponds to each layer.
Returns: Returns:
(Tensor, Tensor): (Tensor, Tensor):
- output embedding, with shape (B, T', D), where - output embedding, with shape (B, T', D), where
@ -1663,7 +1662,6 @@ class Emformer(EncoderInterface):
- output lengths, with shape (B,), without containing the - output lengths, with shape (B,), without containing the
right_context at the end. right_context at the end.
- updated states from current chunk's computation. - updated states from current chunk's computation.
- updated convolution caches from current chunk.
""" """
x = self.encoder_embed(x) x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -1674,16 +1672,13 @@ class Emformer(EncoderInterface):
x_lens = ((x_lens - 1) // 2 - 1) // 2 x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item() assert x.size(0) == x_lens.max().item()
( output, output_lengths, output_states = self.encoder.infer(
output, x, x_lens, states
output_lengths, )
output_states,
output_conv_caches,
) = self.encoder.infer(x, x_lens, states, conv_caches)
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return output, output_lengths, output_states, output_conv_caches return output, output_lengths, output_states
class Conv2dSubsampling(nn.Module): class Conv2dSubsampling(nn.Module):