refactor, use fixed-length cache for batch decoding
commit 13899dff51 (parent 10998bef69)
@@ -200,7 +200,6 @@ class ConvolutionModule(nn.Module):
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
-        cache: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Causal convolution module.
 
@@ -209,14 +208,11 @@ class ConvolutionModule(nn.Module):
            Utterance tensor of shape (U, B, D).
          right_context (torch.Tensor):
            Right context tensor of shape (R, B, D).
-          cache (torch.Tensor, optional):
-            Cached tensor for left padding of shape (B, D, cache_size).
 
        Returns:
-          A tuple of 3 tensors:
+          A tuple of 2 tensors:
          - output utterance of shape (U, B, D).
          - output right_context of shape (R, B, D).
-          - updated cache tensor of shape (B, D, cache_size).
        """
        U, B, D = utterance.size()
        R, _, _ = right_context.size()
@@ -230,17 +226,13 @@ class ConvolutionModule(nn.Module):
         utterance = x[:, :, R:]  # (B, D, U)
         right_context = x[:, :, :R]  # (B, D, R)
 
-        if cache is None:
+        # make causal convolution
         cache = torch.zeros(
             B, D, self.cache_size, device=x.device, dtype=x.dtype
         )
-        else:
-            assert cache.shape == (B, D, self.cache_size), cache.shape
         pad_utterance = torch.cat(
             [cache, utterance], dim=2
         )  # (B, D, cache + U)
-        # update cache
-        new_cache = pad_utterance[:, :, -self.cache_size :]
 
         # depth-wise conv on utterance
         utterance = self.depthwise_conv(pad_utterance)  # (B, D, U)
@@ -269,7 +261,6 @@ class ConvolutionModule(nn.Module):
         return (
             utterance.permute(2, 0, 1),
             right_context.permute(2, 0, 1),
-            new_cache,
         )
 
     def infer(
@@ -304,11 +295,7 @@ class ConvolutionModule(nn.Module):
         x = self.deriv_balancer1(x)
         x = nn.functional.glu(x, dim=1)  # (B, D, U + R)
 
-        if cache is None:
-            cache = torch.zeros(
-                B, D, self.cache_size, device=x.device, dtype=x.dtype
-            )
-        else:
+        # make causal convolution
         assert cache.shape == (B, D, self.cache_size), cache.shape
         x = torch.cat([cache, x], dim=2)  # (B, D, cache_size + U + R)
         # update cache
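
Editor's note: the pattern behind these ConvolutionModule changes is a fixed-length left-padding cache for the causal depthwise convolution: prepend the cached frames, run the conv with no implicit padding, and keep only the last cache_size frames as the next cache. The following is a minimal standalone sketch of that pattern, not the module itself; the dimensions, the conv_chunk helper, and the toy sizes are all chosen for illustration.

    import torch
    import torch.nn as nn

    # Hypothetical toy dimensions, not taken from the commit.
    channels, kernel_size = 4, 5
    cache_size = kernel_size - 1  # left context needed for a causal conv

    # Depthwise (groups=channels) conv with no implicit padding: causality comes
    # from explicitly prepending `cache_size` frames of left context.
    depthwise_conv = nn.Conv1d(
        channels, channels, kernel_size, groups=channels, padding=0
    )

    def conv_chunk(x: torch.Tensor, cache: torch.Tensor):
        """x: (B, D, U) current chunk; cache: (B, D, cache_size) left padding."""
        padded = torch.cat([cache, x], dim=2)      # (B, D, cache_size + U)
        new_cache = padded[:, :, -cache_size:]     # keep a fixed-length tail
        return depthwise_conv(padded), new_cache   # output: (B, D, U)

    B, U = 2, 8
    cache = torch.zeros(B, channels, cache_size)   # first chunk: zero history
    x1, x2 = torch.randn(B, channels, U), torch.randn(B, channels, U)

    y1, cache = conv_chunk(x1, cache)
    y2, cache = conv_chunk(x2, cache)

    # Streaming two chunks matches running the conv over the concatenated input.
    y_full = depthwise_conv(
        torch.cat([torch.zeros(B, channels, cache_size), x1, x2], dim=2)
    )
    assert torch.allclose(torch.cat([y1, y2], dim=2), y_full, atol=1e-6)

Because the cache shape never depends on how much audio a stream has already consumed, several streams can share one batched cache tensor, which is the point of this refactor.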
@@ -383,7 +370,7 @@ class EmformerAttention(nn.Module):
         self,
         attention_weights: torch.Tensor,
         attention_mask: torch.Tensor,
-        padding_mask: Optional[torch.Tensor],
+        padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Given the entire attention weights, mask out unnecessary connections
         and optionally with padding positions, to obtain underlying chunk-wise
@@ -438,11 +425,11 @@ class EmformerAttention(nn.Module):
     def _forward_impl(
         self,
         utterance: torch.Tensor,
-        lengths: torch.Tensor,
         right_context: torch.Tensor,
         summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -470,7 +457,7 @@ class EmformerAttention(nn.Module):
                 [value[: M + R], left_context_val, value[M + R :]]
             )
         Q = query.size(0)
-        KV = key.size(0)
+        # KV = key.size(0)
 
         reshaped_query, reshaped_key, reshaped_value = [
             tensor.contiguous()
@@ -482,12 +469,6 @@ class EmformerAttention(nn.Module):
             reshaped_query * scaling, reshaped_key.transpose(1, 2)
         )  # (B * nhead, Q, KV)
 
-        # compute padding mask
-        if B == 1:
-            padding_mask = None
-        else:
-            padding_mask = make_pad_mask(KV - U + lengths)
-
         # compute attention probabilities
         attention_probs = self._gen_attention_probs(
             attention_weights, attention_mask, padding_mask
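
Editor's note: the padding mask is no longer derived from `lengths` inside `_forward_impl`; callers now pass a precomputed boolean mask of shape (B, KV), True at padded key positions. A hedged sketch of how such a mask is built and folded into attention weights follows; `make_pad_mask` is re-implemented here only to keep the sketch self-contained, and the sizes are illustrative, not the module's.

    import torch

    def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
        """Bool mask of shape (B, max_len), True at padded positions.

        Mirrors the semantics of icefall's make_pad_mask utility; written out
        here so the example runs on its own."""
        max_len = max(max_len, int(lengths.max()))
        idx = torch.arange(max_len, device=lengths.device)
        return idx.unsqueeze(0) >= lengths.unsqueeze(1)  # (B, max_len)

    # Toy sizes (illustrative only): batch of 3, key/value length KV = 6.
    lengths = torch.tensor([6, 4, 2])        # valid key positions per sample
    padding_mask = make_pad_mask(lengths)    # (B, KV), True = padded

    B, nhead, Q, KV = 3, 2, 5, 6
    attention_weights = torch.randn(B * nhead, Q, KV)

    # Padded key positions are pushed to -inf before softmax so they receive
    # zero probability; this is the usual way a (B, KV) padding mask is
    # combined with per-head attention weights.
    weights = attention_weights.view(B, nhead, Q, KV).masked_fill(
        padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
    )
    probs = weights.softmax(dim=-1)          # (B, nhead, Q, KV)
    print(probs[2, 0, 0])                    # sample 2: last 4 key slots are 0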
@@ -515,11 +496,11 @@ class EmformerAttention(nn.Module):
     def forward(
         self,
         utterance: torch.Tensor,
-        lengths: torch.Tensor,
         right_context: torch.Tensor,
         summary: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: Modify docs.
         """Forward pass for training and validation mode.
@@ -560,9 +541,6 @@ class EmformerAttention(nn.Module):
        Args:
          utterance (torch.Tensor):
            Full utterance frames, with shape (U, B, D).
-          lengths (torch.Tensor):
-            With shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in utterance.
          right_context (torch.Tensor):
            Hard-copied right context frames, with shape (R, B, D),
            where R = num_chunks * right_context_length
@@ -575,6 +553,8 @@ class EmformerAttention(nn.Module):
          attention_mask (torch.Tensor):
            Pre-computed attention mask to simulate underlying chunk-wise
            attention, with shape (Q, KV).
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor, with shape (B, KV).
 
        Returns:
          A tuple containing 2 tensors:
@@ -588,23 +568,23 @@ class EmformerAttention(nn.Module):
             _,
         ) = self._forward_impl(
             utterance,
-            lengths,
             right_context,
             summary,
             memory,
             attention_mask,
+            padding_mask=padding_mask,
         )
         return output_right_context_utterance, output_memory[:-1]
 
     def infer(
         self,
         utterance: torch.Tensor,
-        lengths: torch.Tensor,
         right_context: torch.Tensor,
         summary: torch.Tensor,
         memory: torch.Tensor,
         left_context_key: torch.Tensor,
         left_context_val: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass for inference.
 
@@ -633,9 +613,6 @@ class EmformerAttention(nn.Module):
        Args:
          utterance (torch.Tensor):
            Current chunk frames, with shape (U, B, D), where U = chunk_length.
-          lengths (torch.Tensor):
-            With shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in utterance.
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D),
            where R = right_context_length.
@@ -645,10 +622,12 @@ class EmformerAttention(nn.Module):
            Memory vectors, with shape (M, B, D), or empty tensor.
          left_context_key (torch.Tensor):
            Cached attention key of left context from preceding computation,
-            with shape (L, B, D), where L <= left_context_length.
+            with shape (L, B, D).
          left_context_val (torch.Tensor):
            Cached attention value of left context from preceding computation,
-            with shape (L, B, D), where L <= left_context_length.
+            with shape (L, B, D).
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor, with shape (B, KV).
 
        Returns:
          A tuple containing 4 tensors:
@@ -665,6 +644,7 @@ class EmformerAttention(nn.Module):
         S = summary.size(0)
         M = memory.size(0)
 
+        # TODO: move it outside
         # query = [right context, utterance, summary]
         Q = R + U + S
         # key, value = [memory, right context, left context, utterance]
@@ -681,11 +661,11 @@ class EmformerAttention(nn.Module):
             value,
         ) = self._forward_impl(
             utterance,
-            lengths,
             right_context,
             summary,
             memory,
             attention_mask,
+            padding_mask=padding_mask,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
         )
@@ -719,8 +699,8 @@ class EmformerEncoderLayer(nn.Module):
            Length of left context. (Default: 0)
          right_context_length (int, optional):
            Length of right context. (Default: 0)
-          max_memory_size (int, optional):
-            Maximum number of memory elements to use. (Default: 0)
+          memory_size (int, optional):
+            Number of memory elements to use. (Default: 0)
          tanh_on_mem (bool, optional):
            If ``True``, applies tanh to memory elements. (Default: ``False``)
          negative_inf (float, optional):
@@ -738,7 +718,7 @@ class EmformerEncoderLayer(nn.Module):
         cnn_module_kernel: int = 31,
         left_context_length: int = 0,
         right_context_length: int = 0,
-        max_memory_size: int = 0,
+        memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -791,75 +771,29 @@ class EmformerEncoderLayer(nn.Module):
         self.layer_dropout = layer_dropout
         self.left_context_length = left_context_length
         self.chunk_length = chunk_length
-        self.max_memory_size = max_memory_size
+        self.memory_size = memory_size
         self.d_model = d_model
-        self.use_memory = max_memory_size > 0
+        self.use_memory = memory_size > 0
 
-    def _init_state(
-        self, batch_size: int, device: Optional[torch.device]
-    ) -> List[torch.Tensor]:
-        """Initialize states with zeros."""
-        empty_memory = torch.zeros(
-            self.max_memory_size, batch_size, self.d_model, device=device
-        )
-        left_context_key = torch.zeros(
-            self.left_context_length, batch_size, self.d_model, device=device
-        )
-        left_context_val = torch.zeros(
-            self.left_context_length, batch_size, self.d_model, device=device
-        )
-        past_length = torch.zeros(
-            1, batch_size, dtype=torch.int32, device=device
-        )
-        return [empty_memory, left_context_key, left_context_val, past_length]
-
-    def _unpack_state(
-        self, state: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Unpack cached states including:
-        1) output memory from previous chunks in the lower layer;
-        2) attention key and value of left context from preceding chunk's
-        computation.
-        """
-        past_length = state[3][0][0].item()
-        past_left_context_length = min(self.left_context_length, past_length)
-        past_memory_length = min(
-            self.max_memory_size, math.ceil(past_length / self.chunk_length)
-        )
-        memory_start_idx = self.max_memory_size - past_memory_length
-        pre_memory = state[0][memory_start_idx:]
-        left_context_start_idx = (
-            self.left_context_length - past_left_context_length
-        )
-        left_context_key = state[1][left_context_start_idx:]
-        left_context_val = state[2][left_context_start_idx:]
-        return pre_memory, left_context_key, left_context_val
-
-    def _pack_state(
+    def _update_attn_cache(
         self,
         next_key: torch.Tensor,
         next_val: torch.Tensor,
-        update_length: int,
         memory: torch.Tensor,
-        state: List[torch.Tensor],
+        attn_cache: List[torch.Tensor],
     ) -> List[torch.Tensor]:
-        """Pack updated states including:
+        """Update cached attention state:
         1) output memory of current chunk in the lower layer;
         2) attention key and value in current chunk's computation, which would
         be reused in next chunk's computation.
-        3) length of current chunk.
         """
-        new_memory = torch.cat([state[0], memory])
-        new_key = torch.cat([state[1], next_key])
-        new_val = torch.cat([state[2], next_val])
-        memory_start_idx = new_memory.size(0) - self.max_memory_size
-        state[0] = new_memory[memory_start_idx:]
-        key_start_idx = new_key.size(0) - self.left_context_length
-        state[1] = new_key[key_start_idx:]
-        val_start_idx = new_val.size(0) - self.left_context_length
-        state[2] = new_val[val_start_idx:]
-        state[3] = state[3] + update_length
-        return state
+        new_memory = torch.cat([attn_cache[0], memory])
+        new_key = torch.cat([attn_cache[1], next_key])
+        new_val = torch.cat([attn_cache[2], next_val])
+        attn_cache[0] = new_memory[new_memory.size(0) - self.memory_size :]
+        attn_cache[1] = new_key[new_key.size(0) - self.left_context_length :]
+        attn_cache[2] = new_val[new_val.size(0) - self.left_context_length :]
+        return attn_cache
 
     def _apply_conv_module_forward(
         self,
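
Editor's note: `_update_attn_cache` keeps each per-layer cache tensor at a fixed length (memory_size, left_context_length, left_context_length) by concatenating the new entries and slicing the tail, instead of tracking a separate `past_length`. A standalone sketch of that fixed-length update, with toy sizes and helper names chosen only for illustration:

    import torch
    from typing import List

    # Toy sizes; in the layer these would be self.memory_size,
    # self.left_context_length and self.d_model.
    memory_size, left_context_length, batch, d_model = 4, 3, 2, 8

    def init_attn_cache() -> List[torch.Tensor]:
        # Zero-filled caches: the shapes never change, so several streams can
        # be stacked into one batch without per-stream bookkeeping.
        return [
            torch.zeros(memory_size, batch, d_model),          # cached memory
            torch.zeros(left_context_length, batch, d_model),  # cached key
            torch.zeros(left_context_length, batch, d_model),  # cached value
        ]

    def update_attn_cache(
        next_key: torch.Tensor,
        next_val: torch.Tensor,
        memory: torch.Tensor,
        attn_cache: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        # Append the new entries, then keep only the fixed-length tail.
        attn_cache[0] = torch.cat([attn_cache[0], memory])[-memory_size:]
        attn_cache[1] = torch.cat([attn_cache[1], next_key])[-left_context_length:]
        attn_cache[2] = torch.cat([attn_cache[2], next_val])[-left_context_length:]
        return attn_cache

    cache = init_attn_cache()
    for _ in range(3):  # three decoding chunks
        chunk_len = 5
        cache = update_attn_cache(
            next_key=torch.randn(chunk_len, batch, d_model),
            next_val=torch.randn(chunk_len, batch, d_model),
            memory=torch.randn(1, batch, d_model),  # one memory vector per chunk
            attn_cache=cache,
        )
        assert cache[0].shape == (memory_size, batch, d_model)
        assert cache[1].shape == (left_context_length, batch, d_model)

Early chunks inevitably attend to some all-zero cache rows; the padding mask built in EmformerEncoder.infer (see the hunks further down) is what marks those rows as invalid.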
@@ -869,7 +803,7 @@ class EmformerEncoderLayer(nn.Module):
         """Apply convolution module in training and validation mode."""
         utterance = right_context_utterance[R:]
         right_context = right_context_utterance[:R]
-        utterance, right_context, _ = self.conv_module(utterance, right_context)
+        utterance, right_context = self.conv_module(utterance, right_context)
         right_context_utterance = torch.cat([right_context, utterance])
         return right_context_utterance
 
@@ -892,15 +826,11 @@ class EmformerEncoderLayer(nn.Module):
         self,
         right_context_utterance: torch.Tensor,
         R: int,
-        lengths: torch.Tensor,
         memory: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply attention module in training and validation mode."""
-        if attention_mask is None:
-            raise ValueError(
-                "attention_mask must be not None in training or validation mode."  # noqa
-            )
         utterance = right_context_utterance[R:]
         right_context = right_context_utterance[:R]
 
@@ -914,11 +844,11 @@ class EmformerEncoderLayer(nn.Module):
         )
         output_right_context_utterance, output_memory = self.attention(
             utterance=utterance,
-            lengths=lengths,
             right_context=right_context,
             summary=summary,
             memory=memory,
             attention_mask=attention_mask,
+            padding_mask=padding_mask,
         )
 
         return output_right_context_utterance, output_memory
@@ -927,9 +857,9 @@ class EmformerEncoderLayer(nn.Module):
         self,
         right_context_utterance: torch.Tensor,
         R: int,
-        lengths: torch.Tensor,
         memory: torch.Tensor,
-        state: Optional[List[torch.Tensor]] = None,
+        attn_cache: List[torch.Tensor],
+        padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """Apply attention module in inference mode.
         1) Unpack cached states including:
@@ -937,7 +867,7 @@ class EmformerEncoderLayer(nn.Module):
        - attention key and value of left context from preceding
          chunk's computation;
        2) Apply attention computation;
-        3) Pack updated states including:
+        3) Update cached attention states including:
        - output memory of current chunk in the lower layer;
        - attention key and value in current chunk's computation, which would
          be reused in next chunk's computation.
@@ -946,11 +876,10 @@ class EmformerEncoderLayer(nn.Module):
         utterance = right_context_utterance[R:]
         right_context = right_context_utterance[:R]
 
-        if state is None:
-            state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_memory, left_context_key, left_context_val = self._unpack_state(
-            state
-        )
+        pre_memory = attn_cache[0]
+        left_context_key = attn_cache[1]
+        left_context_val = attn_cache[2]
         if self.use_memory:
             summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
                 2, 0, 1
@@ -967,25 +896,25 @@ class EmformerEncoderLayer(nn.Module):
             next_val,
         ) = self.attention.infer(
             utterance=utterance,
-            lengths=lengths,
             right_context=right_context,
             summary=summary,
             memory=pre_memory,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
+            padding_mask=padding_mask,
         )
-        state = self._pack_state(
-            next_key, next_val, utterance.size(0), memory, state
-        )
-        return output_right_context_utterance, output_memory, state
+        attn_cache = self._update_attn_cache(
+            next_key, next_val, memory, attn_cache
+        )
+        return output_right_context_utterance, output_memory, attn_cache
 
     def forward(
         self,
         utterance: torch.Tensor,
-        lengths: torch.Tensor,
         right_context: torch.Tensor,
         memory: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         r"""Forward pass for training and validation mode.
@@ -999,9 +928,6 @@ class EmformerEncoderLayer(nn.Module):
        Args:
          utterance (torch.Tensor):
            Utterance frames, with shape (U, B, D).
-          lengths (torch.Tensor):
-            With shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in utterance.
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          memory (torch.Tensor):
@@ -1010,6 +936,8 @@ class EmformerEncoderLayer(nn.Module):
          attention_mask (torch.Tensor):
            Attention mask for underlying attention module,
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor, with shape (B, KV).
 
        Returns:
          A tuple containing 3 tensors:
@@ -1038,7 +966,7 @@ class EmformerEncoderLayer(nn.Module):
 
         # emformer attention module
         src_att, output_memory = self._apply_attention_module_forward(
-            src, R, lengths, memory, attention_mask
+            src, R, memory, attention_mask, padding_mask=padding_mask
         )
         src = src + self.dropout(src_att)
 
@@ -1061,11 +989,11 @@ class EmformerEncoderLayer(nn.Module):
     def infer(
         self,
         utterance: torch.Tensor,
-        lengths: torch.Tensor,
         right_context: torch.Tensor,
         memory: torch.Tensor,
-        state: Optional[List[torch.Tensor]] = None,
-        conv_cache: Optional[torch.Tensor] = None,
+        attn_cache: List[torch.Tensor],
+        conv_cache: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
         """Forward pass for inference.
 
@@ -1078,18 +1006,17 @@ class EmformerEncoderLayer(nn.Module):
        Args:
          utterance (torch.Tensor):
            Utterance frames, with shape (U, B, D).
-          lengths (torch.Tensor):
-            With shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in utterance.
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
-          state (List[torch.Tensor], optional):
-            List of tensors representing layer internal state generated in
-            preceding computation. (default=None)
+          attn_cache (List[torch.Tensor]):
+            Cached attention tensors generated in preceding computation,
+            including memory, key and value of left context.
          conv_cache (torch.Tensor, optional):
            Cache tensor of left context for causal convolution.
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor.
 
        Returns:
          (Tensor, Tensor, List[torch.Tensor], Tensor):
@@ -1109,8 +1036,10 @@ class EmformerEncoderLayer(nn.Module):
         (
             src_att,
             output_memory,
-            output_state,
-        ) = self._apply_attention_module_infer(src, R, lengths, memory, state)
+            attn_cache,
+        ) = self._apply_attention_module_infer(
+            src, R, memory, attn_cache, padding_mask=padding_mask
+        )
         src = src + self.dropout(src_att)
 
         # convolution module
@@ -1128,7 +1057,7 @@ class EmformerEncoderLayer(nn.Module):
             output_utterance,
             output_right_context,
             output_memory,
-            output_state,
+            attn_cache,
             conv_cache,
         )
 
@@ -1179,8 +1108,8 @@ class EmformerEncoder(nn.Module):
            Length of left context. (default: 0)
          right_context_length (int, optional):
            Length of right context. (default: 0)
-          max_memory_size (int, optional):
-            Maximum number of memory elements to use. (default: 0)
+          memory_size (int, optional):
+            Number of memory elements to use. (default: 0)
          tanh_on_mem (bool, optional):
            If ``true``, applies tanh to memory elements. (default: ``false``)
          negative_inf (float, optional):
@@ -1199,13 +1128,13 @@ class EmformerEncoder(nn.Module):
         cnn_module_kernel: int = 31,
         left_context_length: int = 0,
         right_context_length: int = 0,
-        max_memory_size: int = 0,
+        memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
         super().__init__()
 
-        self.use_memory = max_memory_size > 0
+        self.use_memory = memory_size > 0
         self.init_memory_op = nn.AvgPool1d(
             kernel_size=chunk_length,
             stride=chunk_length,
@@ -1224,7 +1153,7 @@ class EmformerEncoder(nn.Module):
                     cnn_module_kernel=cnn_module_kernel,
                     left_context_length=left_context_length,
                     right_context_length=right_context_length,
-                    max_memory_size=max_memory_size,
+                    memory_size=memory_size,
                     tanh_on_mem=tanh_on_mem,
                     negative_inf=negative_inf,
                 )
@@ -1232,10 +1161,13 @@ class EmformerEncoder(nn.Module):
             ]
         )
 
+        self.num_encoder_layers = num_encoder_layers
+        self.d_model = d_model
         self.left_context_length = left_context_length
         self.right_context_length = right_context_length
         self.chunk_length = chunk_length
-        self.max_memory_size = max_memory_size
+        self.memory_size = memory_size
+        self.cnn_module_kernel = cnn_module_kernel
 
     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
         """Hard copy each chunk's right context and concat them."""
@@ -1276,7 +1208,7 @@ class EmformerEncoder(nn.Module):
         R = rc * num_chunks
 
         if self.use_memory:
-            m_start = max(chunk_idx - self.max_memory_size, 0)
+            m_start = max(chunk_idx - self.memory_size, 0)
             M = num_chunks - 1
             col_widths = [
                 m_start,  # before memory
@@ -1430,15 +1362,18 @@ class EmformerEncoder(nn.Module):
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
+        padding_mask = make_pad_mask(
+            memory.size(0) + right_context.size(0) + output_lengths
+        )
 
         output = utterance
         for layer in self.emformer_layers:
             output, right_context, memory = layer(
                 output,
-                output_lengths,
                 right_context,
                 memory,
                 attention_mask,
+                padding_mask=padding_mask,
                 warmup=warmup,
             )
 
@@ -1448,10 +1383,13 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         lengths: torch.Tensor,
-        states: Optional[List[List[torch.Tensor]]] = None,
-        conv_caches: Optional[List[torch.Tensor]] = None,
+        states: List[
+            torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
+        ],
     ) -> Tuple[
-        torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
+        torch.Tensor,
+        torch.Tensor,
+        List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]],
     ]:
         """Forward pass for streaming inference.
 
@@ -1467,13 +1405,13 @@ class EmformerEncoder(nn.Module):
            With shape (B,) and i-th element representing number of valid
            utterance frames for i-th batch element in x, which contains the
            right_context at the end.
-          states (List[List[torch.Tensor]], optional):
-            Cached states from preceding chunk's computation, where each
-            element (List[torch.Tensor]) corresponds to each emformer layer.
-            (default: None)
-          conv_caches (List[torch.Tensor], optional):
-            Cached tensors of left context for causal convolution, where each
-            element (Tensor) corresponds to each convolutional layer.
+          states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]):  # noqa
+            Cached states containing:
+            - past_lens: number of past frames for each sample in batch
+            - attn_caches: attention states from preceding chunk's computation,
+              where each element corresponds to each emformer layer
+            - conv_caches: left context for causal convolution, where each
+              element corresponds to each layer.
 
        Returns:
          (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
@@ -1481,8 +1419,38 @@ class EmformerEncoder(nn.Module):
          - output lengths, with shape (B,), without containing the
            right_context at the end.
          - updated states from current chunk's computation.
-          - updated convolution caches from current chunk.
        """
+        past_lens = states[0]
+        assert past_lens.shape == (x.size(1),), past_lens.shape
+
+        attn_caches = states[1]
+        assert len(attn_caches) == self.num_encoder_layers, len(attn_caches)
+        for i in range(len(attn_caches)):
+            assert attn_caches[i][0].shape == (
+                self.memory_size,
+                x.size(1),
+                self.d_model,
+            ), attn_caches[i][0].shape
+            assert attn_caches[i][1].shape == (
+                self.left_context_length,
+                x.size(1),
+                self.d_model,
+            ), attn_caches[i][1].shape
+            assert attn_caches[i][2].shape == (
+                self.left_context_length,
+                x.size(1),
+                self.d_model,
+            ), attn_caches[i][2].shape
+
+        conv_caches = states[2]
+        assert len(conv_caches) == self.num_encoder_layers, len(conv_caches)
+        for i in range(len(conv_caches)):
+            assert conv_caches[i].shape == (
+                x.size(1),
+                self.d_model,
+                self.cnn_module_kernel,
+            ), conv_caches[i].shape
+
         assert x.size(0) == self.chunk_length + self.right_context_length, (
             "Per configured chunk_length and right_context_length, "
             f"expected size of {self.chunk_length + self.right_context_length} "
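
Editor's note: after the refactor, a decoding stream's state is the fixed-shape triple [past_lens, attn_caches, conv_caches], which is what makes batch decoding straightforward: states for several streams have identical shapes and can be created up front. The helper below is not part of the commit; it is an illustrative sketch that builds a zero-initialized state matching the shapes asserted in the hunk above (the int32 dtype for past_lens and the use of cnn_module_kernel as the conv cache width are assumptions taken from those asserts).

    import torch
    from typing import List

    def init_states(
        batch_size: int,
        num_encoder_layers: int,
        d_model: int,
        memory_size: int,
        left_context_length: int,
        cnn_module_kernel: int,
        device: torch.device = torch.device("cpu"),
    ) -> List:
        """Illustrative helper that builds the zeroed, fixed-shape `states`
        structure the new EmformerEncoder.infer expects:
        [past_lens, attn_caches, conv_caches]."""
        past_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
        attn_caches = [
            [
                torch.zeros(memory_size, batch_size, d_model, device=device),
                torch.zeros(left_context_length, batch_size, d_model, device=device),
                torch.zeros(left_context_length, batch_size, d_model, device=device),
            ]
            for _ in range(num_encoder_layers)
        ]
        conv_caches = [
            torch.zeros(batch_size, d_model, cnn_module_kernel, device=device)
            for _ in range(num_encoder_layers)
        ]
        return [past_lens, attn_caches, conv_caches]

    # Example: two streams decoded as one batch share the same state layout.
    # The numbers here are typical-looking values chosen for illustration.
    states = init_states(
        batch_size=2,
        num_encoder_layers=12,
        d_model=512,
        memory_size=32,
        left_context_length=32,
        cnn_module_kernel=31,
    )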
@@ -1498,28 +1466,60 @@ class EmformerEncoder(nn.Module):
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
 
+        # calculate padding mask
+        chunk_mask = make_pad_mask(output_lengths)
+        memory_mask = (
+            (past_lens // self.chunk_length).view(x.size(1), 1)
+            <= torch.arange(self.memory_size, device=x.device).expand(
+                x.size(1), self.memory_size
+            )
+        ).flip(1)
+        left_context_mask = (
+            past_lens.view(x.size(1), 1)
+            <= torch.arange(self.left_context_length, device=x.device).expand(
+                x.size(1), self.left_context_length
+            )
+        ).flip(1)
+        right_context_mask = torch.zeros(
+            x.size(1),
+            self.right_context_length,
+            dtype=torch.bool,
+            device=x.device,
+        )
+        padding_mask = torch.cat(
+            [memory_mask, left_context_mask, right_context_mask, chunk_mask],
+            dim=1,
+        )
+
         output = utterance
-        output_states: List[List[torch.Tensor]] = []
+        output_attn_caches: List[List[torch.Tensor]] = []
         output_conv_caches: List[torch.Tensor] = []
         for layer_idx, layer in enumerate(self.emformer_layers):
             (
                 output,
                 right_context,
                 memory,
-                output_state,
+                output_attn_cache,
                 output_conv_cache,
             ) = layer.infer(
                 output,
-                output_lengths,
                 right_context,
                 memory,
-                None if states is None else states[layer_idx],
-                None if conv_caches is None else conv_caches[layer_idx],
+                padding_mask=padding_mask,
+                attn_cache=attn_caches[layer_idx],
+                conv_cache=conv_caches[layer_idx],
             )
-            output_states.append(output_state)
+            output_attn_caches.append(output_attn_cache)
             output_conv_caches.append(output_conv_cache)
 
-        return output, output_lengths, output_states, output_conv_caches
+        output_past_lens = past_lens + output_lengths
+        output_states = [
+            output_past_lens,
+            output_attn_caches,
+            output_conv_caches,
+        ]
+        return output, output_lengths, output_states
 
 
 class Emformer(EncoderInterface):
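
Editor's note: the padding mask above is assembled from four pieces (memory slots, cached left context, right context, current chunk). The `flip(1)` trick marks cache slots as invalid from the left: a stream that has already seen past_lens frames has only the tail of each fixed-length cache filled, so the leading slots must be masked. A small worked example of just the left-context piece (illustrative, standalone):

    import torch

    # Illustrative numbers only: 3 streams in a batch, a fixed cache of
    # left_context_length = 4 slots, and varying amounts of already-seen audio.
    left_context_length = 4
    past_lens = torch.tensor([0, 2, 10])  # frames already processed per stream

    # A cache slot is *invalid* (True in the mask) when the stream has not yet
    # produced enough history to fill it. Comparing against 0..L-1 and flipping
    # puts the valid slots at the tail, matching how the fixed-length caches
    # are filled from the right.
    left_context_mask = (
        past_lens.view(-1, 1)
        <= torch.arange(left_context_length).expand(3, left_context_length)
    ).flip(1)

    print(left_context_mask)
    # tensor([[ True,  True,  True,  True],   # no history yet: all 4 slots padded
    #         [ True,  True, False, False],   # 2 frames of history: last 2 valid
    #         [False, False, False, False]])  # >= 4 frames: whole cache valid

The memory piece works the same way with past_lens // chunk_length, while the right-context piece is all zeros (always valid) and the current-chunk piece comes from make_pad_mask(output_lengths).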
@@ -1537,7 +1537,7 @@ class Emformer(EncoderInterface):
         cnn_module_kernel: int = 3,
         left_context_length: int = 0,
         right_context_length: int = 0,
-        max_memory_size: int = 0,
+        memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -1576,7 +1576,7 @@ class Emformer(EncoderInterface):
             cnn_module_kernel=cnn_module_kernel,
             left_context_length=left_context_length // 4,
             right_context_length=right_context_length // 4,
-            max_memory_size=max_memory_size,
+            memory_size=memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
@@ -1633,7 +1633,6 @@ class Emformer(EncoderInterface):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
-        conv_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
         """Forward pass for streaming inference.
 
@@ -1649,13 +1648,13 @@ class Emformer(EncoderInterface):
            With shape (B,) and i-th element representing number of valid
            utterance frames for i-th batch element in x, containing the
            right_context at the end.
-          states (List[List[torch.Tensor]], optional):
-            Cached states from preceding chunk's computation, where each
-            element (List[torch.Tensor]) corresponds to each emformer layer.
-            (default: None)
-          conv_caches (List[torch.Tensor], optional):
-            Cached tensors of left context for causal convolution, where each
-            element (Tensor) corresponds to each convolutional layer.
+          states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]):  # noqa
+            Cached states containing:
+            - past_lens: number of past frames for each sample in batch
+            - attn_caches: attention states from preceding chunk's computation,
+              where each element corresponds to each emformer layer
+            - conv_caches: left context for causal convolution, where each
+              element corresponds to each layer.
        Returns:
          (Tensor, Tensor):
            - output embedding, with shape (B, T', D), where
@@ -1663,7 +1662,6 @@ class Emformer(EncoderInterface):
            - output lengths, with shape (B,), without containing the
              right_context at the end.
            - updated states from current chunk's computation.
-            - updated convolution caches from current chunk.
        """
        x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -1674,16 +1672,13 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()
 
-        (
-            output,
-            output_lengths,
-            output_states,
-            output_conv_caches,
-        ) = self.encoder.infer(x, x_lens, states, conv_caches)
+        output, output_lengths, output_states = self.encoder.infer(
+            x, x_lens, states
+        )
 
         output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
 
-        return output, output_lengths, output_states, output_conv_caches
+        return output, output_lengths, output_states
 
 
 class Conv2dSubsampling(nn.Module):
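
Editor's note: taken together, the refactor turns streaming decoding into: build the zeroed fixed-shape states once per batch of streams, then feed each incoming chunk (plus its right context) through Emformer.infer and carry the returned states forward. A hedged usage sketch follows; the configuration values, the feature dimension, the decode_stream helper, and the init_states helper referenced from the earlier sketch are all assumptions made for illustration, not part of the commit.

    import torch

    # Assumed, illustrative configuration: 80-dim fbank features, a 32-frame
    # chunk plus 8 frames of right context per call, two parallel streams.
    batch_size, feature_dim = 2, 80
    chunk_length, right_context_length = 32, 8
    frames_per_call = chunk_length + right_context_length

    # `model` is assumed to be an Emformer built with matching arguments
    # (memory_size, left_context_length, ...); `states` is the zero-initialized
    # [past_lens, attn_caches, conv_caches] structure, e.g. from the
    # init_states sketch above.
    def decode_stream(model, states, feature_chunks):
        outputs = []
        for chunk in feature_chunks:
            # chunk: (B, chunk_length + right_context_length, feature_dim)
            x_lens = torch.full(
                (batch_size,), frames_per_call, dtype=torch.int64
            )
            output, output_lengths, states = model.infer(chunk, x_lens, states)
            outputs.append(output)  # (B, T', D) for this chunk
        return torch.cat(outputs, dim=1), states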