mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add Emformer module
This commit is contained in:
parent
943cb9d5a3
commit
648a0b37d5
@ -1303,7 +1303,6 @@ class EmformerEncoderLayer(nn.Module):
|
|||||||
output_right_context = src[:R]
|
output_right_context = src[:R]
|
||||||
return output_utterance, output_right_context, output_memory
|
return output_utterance, output_right_context, output_memory
|
||||||
|
|
||||||
@torch.jit.export
|
|
||||||
def infer(
|
def infer(
|
||||||
self,
|
self,
|
||||||
utterance: torch.Tensor,
|
utterance: torch.Tensor,
|
||||||
@ -1383,3 +1382,642 @@ class EmformerEncoderLayer(nn.Module):
|
|||||||
output_state,
|
output_state,
|
||||||
conv_cache,
|
conv_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_attention_mask_block(
|
||||||
|
col_widths: List[int],
|
||||||
|
col_mask: List[bool],
|
||||||
|
num_rows: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert len(col_widths) == len(
|
||||||
|
col_mask
|
||||||
|
), "Length of col_widths must match that of col_mask"
|
||||||
|
|
||||||
|
mask_block = [
|
||||||
|
torch.ones(num_rows, col_width, device=device)
|
||||||
|
if is_ones_col
|
||||||
|
else torch.zeros(num_rows, col_width, device=device)
|
||||||
|
for col_width, is_ones_col in zip(col_widths, col_mask)
|
||||||
|
]
|
||||||
|
return torch.cat(mask_block, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class EmformerEncoder(nn.Module):
|
||||||
|
"""Implements the Emformer architecture introduced in
|
||||||
|
*Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency
|
||||||
|
Streaming Speech Recognition*
|
||||||
|
[:footcite:`shi2021emformer`].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d_model (int):
|
||||||
|
Input dimension.
|
||||||
|
nhead (int):
|
||||||
|
Number of attention heads in each emformer layer.
|
||||||
|
dim_feedforward (int):
|
||||||
|
Hidden layer dimension of each emformer layer's feedforward network.
|
||||||
|
num_encoder_layers (int):
|
||||||
|
Number of emformer layers to instantiate.
|
||||||
|
chunk_length (int):
|
||||||
|
Length of each input segment.
|
||||||
|
dropout (float, optional):
|
||||||
|
Dropout probability. (default: 0.0)
|
||||||
|
layer_dropout (float, optional):
|
||||||
|
Layer dropout probability. (default: 0.0)
|
||||||
|
cnn_module_kernel (int):
|
||||||
|
Kernel size of convolution module.
|
||||||
|
left_context_length (int, optional):
|
||||||
|
Length of left context. (default: 0)
|
||||||
|
right_context_length (int, optional):
|
||||||
|
Length of right context. (default: 0)
|
||||||
|
max_memory_size (int, optional):
|
||||||
|
Maximum number of memory elements to use. (default: 0)
|
||||||
|
tanh_on_mem (bool, optional):
|
||||||
|
If ``true``, applies tanh to memory elements. (default: ``false``)
|
||||||
|
negative_inf (float, optional):
|
||||||
|
Value to use for negative infinity in attention weights. (default: -1e8)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_length: int,
|
||||||
|
d_model: int = 256,
|
||||||
|
nhead: int = 4,
|
||||||
|
dim_feedforward: int = 2048,
|
||||||
|
num_encoder_layers: int = 12,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
layer_dropout: float = 0.075,
|
||||||
|
cnn_module_kernel: int = 31,
|
||||||
|
left_context_length: int = 0,
|
||||||
|
right_context_length: int = 0,
|
||||||
|
max_memory_size: int = 0,
|
||||||
|
tanh_on_mem: bool = False,
|
||||||
|
negative_inf: float = -1e8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.use_memory = max_memory_size > 0
|
||||||
|
self.init_memory_op = nn.AvgPool1d(
|
||||||
|
kernel_size=chunk_length,
|
||||||
|
stride=chunk_length,
|
||||||
|
ceil_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.emformer_layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
EmformerEncoderLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
nhead=nhead,
|
||||||
|
dim_feedforward=dim_feedforward,
|
||||||
|
chunk_length=chunk_length,
|
||||||
|
dropout=dropout,
|
||||||
|
layer_dropout=layer_dropout,
|
||||||
|
cnn_module_kernel=cnn_module_kernel,
|
||||||
|
left_context_length=left_context_length,
|
||||||
|
right_context_length=right_context_length,
|
||||||
|
max_memory_size=max_memory_size,
|
||||||
|
tanh_on_mem=tanh_on_mem,
|
||||||
|
negative_inf=negative_inf,
|
||||||
|
)
|
||||||
|
for layer_idx in range(num_encoder_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
|
|
||||||
|
self.left_context_length = left_context_length
|
||||||
|
self.right_context_length = right_context_length
|
||||||
|
self.chunk_length = chunk_length
|
||||||
|
self.max_memory_size = max_memory_size
|
||||||
|
|
||||||
|
def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Hard copy each chunk's right context and concat them."""
|
||||||
|
T = x.shape[0]
|
||||||
|
num_chunks = math.ceil(
|
||||||
|
(T - self.right_context_length) / self.chunk_length
|
||||||
|
)
|
||||||
|
right_context_blocks = []
|
||||||
|
for seg_idx in range(num_chunks - 1):
|
||||||
|
start = (seg_idx + 1) * self.chunk_length
|
||||||
|
end = start + self.right_context_length
|
||||||
|
right_context_blocks.append(x[start:end])
|
||||||
|
right_context_blocks.append(x[T - self.right_context_length :])
|
||||||
|
return torch.cat(right_context_blocks)
|
||||||
|
|
||||||
|
def _gen_attention_mask_col_widths(
|
||||||
|
self, chunk_idx: int, U: int
|
||||||
|
) -> List[int]:
|
||||||
|
"""Calculate column widths (key, value) in attention mask for the
|
||||||
|
chunk_idx chunk."""
|
||||||
|
num_chunks = math.ceil(U / self.chunk_length)
|
||||||
|
rc = self.right_context_length
|
||||||
|
lc = self.left_context_length
|
||||||
|
rc_start = chunk_idx * rc
|
||||||
|
rc_end = rc_start + rc
|
||||||
|
chunk_start = max(chunk_idx * self.chunk_length - lc, 0)
|
||||||
|
chunk_end = min((chunk_idx + 1) * self.chunk_length, U)
|
||||||
|
R = rc * num_chunks
|
||||||
|
|
||||||
|
if self.use_memory:
|
||||||
|
m_start = max(chunk_idx - self.max_memory_size, 0)
|
||||||
|
M = num_chunks - 1
|
||||||
|
col_widths = [
|
||||||
|
m_start, # before memory
|
||||||
|
chunk_idx - m_start, # memory
|
||||||
|
M - chunk_idx, # after memory
|
||||||
|
rc_start, # before right context
|
||||||
|
rc, # right context
|
||||||
|
R - rc_end, # after right context
|
||||||
|
chunk_start, # before chunk
|
||||||
|
chunk_end - chunk_start, # chunk
|
||||||
|
U - chunk_end, # after chunk
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
col_widths = [
|
||||||
|
rc_start, # before right context
|
||||||
|
rc, # right context
|
||||||
|
R - rc_end, # after right context
|
||||||
|
chunk_start, # before chunk
|
||||||
|
chunk_end - chunk_start, # chunk
|
||||||
|
U - chunk_end, # after chunk
|
||||||
|
]
|
||||||
|
|
||||||
|
return col_widths
|
||||||
|
|
||||||
|
def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Generate attention mask to simulate underlying chunk-wise attention
|
||||||
|
computation, where chunk-wise connections are filled with `False`,
|
||||||
|
and other unnecessary connections beyond chunk are filled with `True`.
|
||||||
|
|
||||||
|
R: length of hard-copied right contexts;
|
||||||
|
U: length of full utterance;
|
||||||
|
S: length of summary vectors;
|
||||||
|
M: length of memory vectors;
|
||||||
|
Q: length of attention query;
|
||||||
|
KV: length of attention key and value.
|
||||||
|
|
||||||
|
The shape of attention mask is (Q, KV).
|
||||||
|
If self.use_memory is `True`:
|
||||||
|
query = [right_context, utterance, summary];
|
||||||
|
key, value = [memory, right_context, utterance];
|
||||||
|
Q = R + U + S, KV = M + R + U.
|
||||||
|
Otherwise:
|
||||||
|
query = [right_context, utterance]
|
||||||
|
key, value = [right_context, utterance]
|
||||||
|
Q = R + U, KV = R + U.
|
||||||
|
|
||||||
|
Suppose:
|
||||||
|
c_i: chunk at index i;
|
||||||
|
r_i: right context that c_i can use;
|
||||||
|
l_i: left context that c_i can use;
|
||||||
|
m_i: past memory vectors from previous layer that c_i can use;
|
||||||
|
s_i: summary vector of c_i.
|
||||||
|
The target chunk-wise attention is:
|
||||||
|
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
|
||||||
|
s_i (in query) -> l_i, c_i, r_i (in key).
|
||||||
|
"""
|
||||||
|
U = utterance.size(0)
|
||||||
|
num_chunks = math.ceil(U / self.chunk_length)
|
||||||
|
|
||||||
|
right_context_mask = []
|
||||||
|
utterance_mask = []
|
||||||
|
summary_mask = []
|
||||||
|
|
||||||
|
if self.use_memory:
|
||||||
|
num_cols = 9
|
||||||
|
# right context and utterance both attend to memory, right context,
|
||||||
|
# utterance
|
||||||
|
right_context_utterance_cols_mask = [
|
||||||
|
idx in [1, 4, 7] for idx in range(num_cols)
|
||||||
|
]
|
||||||
|
# summary attends to right context, utterance
|
||||||
|
summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
|
||||||
|
masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
|
||||||
|
else:
|
||||||
|
num_cols = 6
|
||||||
|
# right context and utterance both attend to right context and
|
||||||
|
# utterance
|
||||||
|
right_context_utterance_cols_mask = [
|
||||||
|
idx in [1, 4] for idx in range(num_cols)
|
||||||
|
]
|
||||||
|
summary_cols_mask = None
|
||||||
|
masks_to_concat = [right_context_mask, utterance_mask]
|
||||||
|
|
||||||
|
for chunk_idx in range(num_chunks):
|
||||||
|
col_widths = self._gen_attention_mask_col_widths(chunk_idx, U)
|
||||||
|
|
||||||
|
right_context_mask_block = _gen_attention_mask_block(
|
||||||
|
col_widths,
|
||||||
|
right_context_utterance_cols_mask,
|
||||||
|
self.right_context_length,
|
||||||
|
utterance.device,
|
||||||
|
)
|
||||||
|
right_context_mask.append(right_context_mask_block)
|
||||||
|
|
||||||
|
utterance_mask_block = _gen_attention_mask_block(
|
||||||
|
col_widths,
|
||||||
|
right_context_utterance_cols_mask,
|
||||||
|
min(
|
||||||
|
self.chunk_length,
|
||||||
|
U - chunk_idx * self.chunk_length,
|
||||||
|
),
|
||||||
|
utterance.device,
|
||||||
|
)
|
||||||
|
utterance_mask.append(utterance_mask_block)
|
||||||
|
|
||||||
|
if summary_cols_mask is not None:
|
||||||
|
summary_mask_block = _gen_attention_mask_block(
|
||||||
|
col_widths, summary_cols_mask, 1, utterance.device
|
||||||
|
)
|
||||||
|
summary_mask.append(summary_mask_block)
|
||||||
|
|
||||||
|
attention_mask = (
|
||||||
|
1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
|
||||||
|
).to(torch.bool)
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, lengths: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Forward pass for training and validation mode.
|
||||||
|
|
||||||
|
B: batch size;
|
||||||
|
D: input dimension;
|
||||||
|
U: length of utterance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor):
|
||||||
|
Utterance frames right-padded with right context frames,
|
||||||
|
with shape (U + right_context_length, B, D).
|
||||||
|
lengths (torch.Tensor):
|
||||||
|
With shape (B,) and i-th element representing number of valid
|
||||||
|
utterance frames for i-th batch element in x, which contains the
|
||||||
|
right_context at the end.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of 2 tensors:
|
||||||
|
- output utterance frames, with shape (U, B, D).
|
||||||
|
- output_lengths, with shape (B,), without containing the
|
||||||
|
right_context at the end.
|
||||||
|
"""
|
||||||
|
U = x.size(0) - self.right_context_length
|
||||||
|
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
|
||||||
|
|
||||||
|
right_context = self._gen_right_context(x)
|
||||||
|
utterance = x[:U]
|
||||||
|
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
||||||
|
attention_mask = self._gen_attention_mask(utterance)
|
||||||
|
memory = (
|
||||||
|
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
|
||||||
|
:-1
|
||||||
|
]
|
||||||
|
if self.use_memory
|
||||||
|
else torch.empty(0).to(dtype=x.dtype, device=x.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
output = utterance
|
||||||
|
for layer in self.emformer_layers:
|
||||||
|
output, right_context, memory = layer(
|
||||||
|
output,
|
||||||
|
output_lengths,
|
||||||
|
right_context,
|
||||||
|
memory,
|
||||||
|
attention_mask,
|
||||||
|
pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output, output_lengths
|
||||||
|
|
||||||
|
def infer(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lengths: torch.Tensor,
|
||||||
|
states: Optional[List[List[torch.Tensor]]] = None,
|
||||||
|
conv_caches: Optional[List[torch.Tensor]] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
|
||||||
|
]:
|
||||||
|
"""Forward pass for streaming inference.
|
||||||
|
|
||||||
|
B: batch size;
|
||||||
|
D: input dimension;
|
||||||
|
U: length of utterance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor):
|
||||||
|
Utterance frames right-padded with right context frames,
|
||||||
|
with shape (U + right_context_length, B, D).
|
||||||
|
lengths (torch.Tensor):
|
||||||
|
With shape (B,) and i-th element representing number of valid
|
||||||
|
utterance frames for i-th batch element in x, which contains the
|
||||||
|
right_context at the end.
|
||||||
|
states (List[List[torch.Tensor]], optional):
|
||||||
|
Cached states from proceeding chunk's computation, where each
|
||||||
|
element (List[torch.Tensor]) corresponds to each emformer layer.
|
||||||
|
(default: None)
|
||||||
|
conv_caches (List[torch.Tensor], optional):
|
||||||
|
Cached tensors of left context for causal convolution, where each
|
||||||
|
element (Tensor) corresponds to each convolutional layer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
|
||||||
|
- output utterance frames, with shape (U, B, D).
|
||||||
|
- output lengths, with shape (B,), without containing the
|
||||||
|
right_context at the end.
|
||||||
|
- updated states from current chunk's computation.
|
||||||
|
- updated convolution caches from current chunk.
|
||||||
|
"""
|
||||||
|
assert x.size(0) == self.chunk_length + self.right_context_length, (
|
||||||
|
"Per configured chunk_length and right_context_length, "
|
||||||
|
f"expected size of {self.chunk_length + self.right_context_length} "
|
||||||
|
f"for dimension 1 of x, but got {x.size(1)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_len = self.chunk_length + self.left_context_length
|
||||||
|
neg_len = self.chunk_length
|
||||||
|
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
|
||||||
|
|
||||||
|
right_context_start_idx = x.size(0) - self.right_context_length
|
||||||
|
right_context = x[right_context_start_idx:]
|
||||||
|
utterance = x[:right_context_start_idx]
|
||||||
|
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
||||||
|
memory = (
|
||||||
|
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
|
||||||
|
if self.use_memory
|
||||||
|
else torch.empty(0).to(dtype=x.dtype, device=x.device)
|
||||||
|
)
|
||||||
|
output = utterance
|
||||||
|
output_states: List[List[torch.Tensor]] = []
|
||||||
|
output_conv_caches: List[torch.Tensor] = []
|
||||||
|
for layer_idx, layer in enumerate(self.emformer_layers):
|
||||||
|
(
|
||||||
|
output,
|
||||||
|
right_context,
|
||||||
|
memory,
|
||||||
|
output_state,
|
||||||
|
output_conv_cache,
|
||||||
|
) = layer.infer(
|
||||||
|
output,
|
||||||
|
output_lengths,
|
||||||
|
right_context,
|
||||||
|
memory,
|
||||||
|
pos_emb,
|
||||||
|
None if states is None else states[layer_idx],
|
||||||
|
None if conv_caches is None else conv_caches[layer_idx],
|
||||||
|
)
|
||||||
|
output_states.append(output_state)
|
||||||
|
output_conv_caches.append(output_conv_cache)
|
||||||
|
|
||||||
|
return output, output_lengths, output_states, output_conv_caches
|
||||||
|
|
||||||
|
|
||||||
|
class Emformer(EncoderInterface):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_features: int,
|
||||||
|
chunk_length: int,
|
||||||
|
subsampling_factor: int = 4,
|
||||||
|
d_model: int = 256,
|
||||||
|
nhead: int = 4,
|
||||||
|
dim_feedforward: int = 2048,
|
||||||
|
num_encoder_layers: int = 12,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
layer_dropout: float = 0.075,
|
||||||
|
cnn_module_kernel: int = 3,
|
||||||
|
left_context_length: int = 0,
|
||||||
|
right_context_length: int = 0,
|
||||||
|
max_memory_size: int = 0,
|
||||||
|
tanh_on_mem: bool = False,
|
||||||
|
negative_inf: float = -1e8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.subsampling_factor = subsampling_factor
|
||||||
|
self.right_context_length = right_context_length
|
||||||
|
if subsampling_factor != 4:
|
||||||
|
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||||
|
if chunk_length % 4 != 0:
|
||||||
|
raise NotImplementedError("chunk_length must be a mutiple of 4.")
|
||||||
|
if left_context_length != 0 and left_context_length % 4 != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"left_context_length must be 0 or a mutiple of 4."
|
||||||
|
)
|
||||||
|
if right_context_length != 0 and right_context_length % 4 != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"right_context_length must be 0 or a mutiple of 4."
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.encoder_embed converts the input of shape (N, T, num_features)
|
||||||
|
# to the shape (N, T//subsampling_factor, d_model).
|
||||||
|
# That is, it does two things simultaneously:
|
||||||
|
# (1) subsampling: T -> T//subsampling_factor
|
||||||
|
# (2) embedding: num_features -> d_model
|
||||||
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||||
|
|
||||||
|
self.encoder = EmformerEncoder(
|
||||||
|
chunk_length=chunk_length // 4,
|
||||||
|
d_model=d_model,
|
||||||
|
nhead=nhead,
|
||||||
|
dim_feedforward=dim_feedforward,
|
||||||
|
num_encoder_layers=num_encoder_layers,
|
||||||
|
dropout=dropout,
|
||||||
|
layer_dropout=layer_dropout,
|
||||||
|
cnn_module_kernel=cnn_module_kernel,
|
||||||
|
left_context_length=left_context_length // 4,
|
||||||
|
right_context_length=right_context_length // 4,
|
||||||
|
max_memory_size=max_memory_size,
|
||||||
|
tanh_on_mem=tanh_on_mem,
|
||||||
|
negative_inf=negative_inf,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Forward pass for training and non-streaming inference.
|
||||||
|
|
||||||
|
B: batch size;
|
||||||
|
D: feature dimension;
|
||||||
|
T: length of utterance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor):
|
||||||
|
Utterance frames right-padded with right context frames,
|
||||||
|
with shape (B, T, D).
|
||||||
|
x_lens (torch.Tensor):
|
||||||
|
With shape (B,) and i-th element representing number of valid
|
||||||
|
utterance frames for i-th batch element in x, containing the
|
||||||
|
right_context at the end.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Tensor, Tensor):
|
||||||
|
- output embedding, with shape (B, T', D), where
|
||||||
|
T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
|
||||||
|
- output lengths, with shape (B,), without containing the
|
||||||
|
right_context at the end.
|
||||||
|
"""
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
x_lens = ((x_lens - 1) // 2 - 1) // 2
|
||||||
|
assert x.size(0) == x_lens.max().item()
|
||||||
|
|
||||||
|
output, output_lengths = self.encoder(x, x_lens) # (T, N, C)
|
||||||
|
|
||||||
|
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
|
||||||
|
|
||||||
|
return output, output_lengths
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def infer(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
states: Optional[List[List[torch.Tensor]]] = None,
|
||||||
|
conv_caches: Optional[List[torch.Tensor]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
|
||||||
|
"""Forward pass for streaming inference.
|
||||||
|
|
||||||
|
B: batch size;
|
||||||
|
D: feature dimension;
|
||||||
|
T: length of utterance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor):
|
||||||
|
Utterance frames right-padded with right context frames,
|
||||||
|
with shape (B, T, D).
|
||||||
|
lengths (torch.Tensor):
|
||||||
|
With shape (B,) and i-th element representing number of valid
|
||||||
|
utterance frames for i-th batch element in x, containing the
|
||||||
|
right_context at the end.
|
||||||
|
states (List[List[torch.Tensor]], optional):
|
||||||
|
Cached states from proceeding chunk's computation, where each
|
||||||
|
element (List[torch.Tensor]) corresponds to each emformer layer.
|
||||||
|
(default: None)
|
||||||
|
conv_caches (List[torch.Tensor], optional):
|
||||||
|
Cached tensors of left context for causal convolution, where each
|
||||||
|
element (Tensor) corresponds to each convolutional layer.
|
||||||
|
Returns:
|
||||||
|
(Tensor, Tensor):
|
||||||
|
- output embedding, with shape (B, T', D), where
|
||||||
|
T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
|
||||||
|
- output lengths, with shape (B,), without containing the
|
||||||
|
right_context at the end.
|
||||||
|
- updated states from current chunk's computation.
|
||||||
|
- updated convolution caches from current chunk.
|
||||||
|
"""
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
x_lens = ((x_lens - 1) // 2 - 1) // 2
|
||||||
|
assert x.size(0) == x_lens.max().item()
|
||||||
|
|
||||||
|
(
|
||||||
|
output,
|
||||||
|
output_lengths,
|
||||||
|
output_states,
|
||||||
|
output_conv_caches,
|
||||||
|
) = self.encoder.infer(x, x_lens, states, conv_caches)
|
||||||
|
|
||||||
|
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
|
||||||
|
|
||||||
|
return output, output_lengths, output_states, output_conv_caches
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling(nn.Module):
|
||||||
|
"""Convolutional 2D subsampling (to 1/4 length).
|
||||||
|
|
||||||
|
Convert an input of shape (N, T, idim) to an output
|
||||||
|
with shape (N, T', odim), where
|
||||||
|
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||||
|
|
||||||
|
It is based on
|
||||||
|
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
layer1_channels: int = 8,
|
||||||
|
layer2_channels: int = 32,
|
||||||
|
layer3_channels: int = 128,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
in_channels:
|
||||||
|
Number of channels in. The input shape is (N, T, in_channels).
|
||||||
|
Caution: It requires: T >=7, in_channels >=7
|
||||||
|
out_channels
|
||||||
|
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
|
||||||
|
layer1_channels:
|
||||||
|
Number of channels in layer1
|
||||||
|
layer1_channels:
|
||||||
|
Number of channels in layer2
|
||||||
|
"""
|
||||||
|
assert in_channels >= 7
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
ScaledConv2d(
|
||||||
|
in_channels=1,
|
||||||
|
out_channels=layer1_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
ActivationBalancer(channel_dim=1),
|
||||||
|
DoubleSwish(),
|
||||||
|
ScaledConv2d(
|
||||||
|
in_channels=layer1_channels,
|
||||||
|
out_channels=layer2_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
|
ActivationBalancer(channel_dim=1),
|
||||||
|
DoubleSwish(),
|
||||||
|
ScaledConv2d(
|
||||||
|
in_channels=layer2_channels,
|
||||||
|
out_channels=layer3_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
|
ActivationBalancer(channel_dim=1),
|
||||||
|
DoubleSwish(),
|
||||||
|
)
|
||||||
|
self.out = ScaledLinear(
|
||||||
|
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
||||||
|
)
|
||||||
|
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
||||||
|
# itself has learned scale, so the extra degree of freedom is not
|
||||||
|
# needed.
|
||||||
|
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||||
|
# constrain median of output to be close to zero.
|
||||||
|
self.out_balancer = ActivationBalancer(
|
||||||
|
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
Its shape is (N, T, idim).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||||
|
"""
|
||||||
|
# On entry, x is (N, T, idim)
|
||||||
|
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||||
|
x = self.conv(x)
|
||||||
|
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||||
|
b, c, t, f = x.size()
|
||||||
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||||
|
x = self.out_norm(x)
|
||||||
|
x = self.out_balancer(x)
|
||||||
|
return x
|
||||||
|
@ -113,7 +113,10 @@ def test_convolution_module_forward():
|
|||||||
R = num_chunks * right_context_length
|
R = num_chunks * right_context_length
|
||||||
kernel_size = 31
|
kernel_size = 31
|
||||||
conv_module = ConvolutionModule(
|
conv_module = ConvolutionModule(
|
||||||
chunk_length, right_context_length, D, kernel_size,
|
chunk_length,
|
||||||
|
right_context_length,
|
||||||
|
D,
|
||||||
|
kernel_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
utterance = torch.randn(U, B, D)
|
utterance = torch.randn(U, B, D)
|
||||||
@ -139,7 +142,10 @@ def test_convolution_module_infer():
|
|||||||
R = num_chunks * right_context_length
|
R = num_chunks * right_context_length
|
||||||
kernel_size = 31
|
kernel_size = 31
|
||||||
conv_module = ConvolutionModule(
|
conv_module = ConvolutionModule(
|
||||||
chunk_length, right_context_length, D, kernel_size,
|
chunk_length,
|
||||||
|
right_context_length,
|
||||||
|
D,
|
||||||
|
kernel_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
utterance = torch.randn(U, B, D)
|
utterance = torch.randn(U, B, D)
|
||||||
@ -274,6 +280,260 @@ def test_emformer_encoder_layer_infer():
|
|||||||
assert conv_cache.shape == (B, D, kernel_size - 1)
|
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_emformer_encoder_forward():
|
||||||
|
from emformer import EmformerEncoder
|
||||||
|
|
||||||
|
B, D = 2, 256
|
||||||
|
chunk_length = 4
|
||||||
|
right_context_length = 2
|
||||||
|
left_context_length = 2
|
||||||
|
num_chunks = 3
|
||||||
|
U = num_chunks * chunk_length
|
||||||
|
kernel_size = 31
|
||||||
|
num_encoder_layers = 2
|
||||||
|
|
||||||
|
for use_memory in [True, False]:
|
||||||
|
if use_memory:
|
||||||
|
S = num_chunks
|
||||||
|
M = S - 1
|
||||||
|
else:
|
||||||
|
S, M = 0, 0
|
||||||
|
|
||||||
|
encoder = EmformerEncoder(
|
||||||
|
chunk_length=chunk_length,
|
||||||
|
d_model=D,
|
||||||
|
dim_feedforward=1024,
|
||||||
|
num_encoder_layers=num_encoder_layers,
|
||||||
|
cnn_module_kernel=kernel_size,
|
||||||
|
left_context_length=left_context_length,
|
||||||
|
right_context_length=right_context_length,
|
||||||
|
max_memory_size=M,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.randn(U + right_context_length, B, D)
|
||||||
|
lengths = torch.randint(1, U + right_context_length + 1, (B,))
|
||||||
|
lengths[0] = U + right_context_length
|
||||||
|
|
||||||
|
output, output_lengths = encoder(x, lengths)
|
||||||
|
assert output.shape == (U, B, D)
|
||||||
|
assert torch.equal(
|
||||||
|
output_lengths, torch.clamp(lengths - right_context_length, min=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_emformer_encoder_infer():
|
||||||
|
from emformer import EmformerEncoder
|
||||||
|
|
||||||
|
B, D = 2, 256
|
||||||
|
num_encoder_layers = 2
|
||||||
|
chunk_length = 4
|
||||||
|
right_context_length = 2
|
||||||
|
left_context_length = 2
|
||||||
|
num_chunks = 3
|
||||||
|
kernel_size = 31
|
||||||
|
|
||||||
|
for use_memory in [True, False]:
|
||||||
|
if use_memory:
|
||||||
|
M = 3
|
||||||
|
else:
|
||||||
|
M = 0
|
||||||
|
|
||||||
|
encoder = EmformerEncoder(
|
||||||
|
chunk_length=chunk_length,
|
||||||
|
d_model=D,
|
||||||
|
dim_feedforward=1024,
|
||||||
|
num_encoder_layers=num_encoder_layers,
|
||||||
|
cnn_module_kernel=kernel_size,
|
||||||
|
left_context_length=left_context_length,
|
||||||
|
right_context_length=right_context_length,
|
||||||
|
max_memory_size=M,
|
||||||
|
)
|
||||||
|
|
||||||
|
states = None
|
||||||
|
conv_caches = None
|
||||||
|
for chunk_idx in range(num_chunks):
|
||||||
|
x = torch.randn(chunk_length + right_context_length, B, D)
|
||||||
|
lengths = torch.randint(
|
||||||
|
1, chunk_length + right_context_length + 1, (B,)
|
||||||
|
)
|
||||||
|
lengths[0] = chunk_length + right_context_length
|
||||||
|
output, output_lengths, states, conv_caches = encoder.infer(
|
||||||
|
x, lengths, states, conv_caches
|
||||||
|
)
|
||||||
|
assert output.shape == (chunk_length, B, D)
|
||||||
|
assert torch.equal(
|
||||||
|
output_lengths,
|
||||||
|
torch.clamp(lengths - right_context_length, min=0),
|
||||||
|
)
|
||||||
|
assert len(states) == num_encoder_layers
|
||||||
|
for state in states:
|
||||||
|
assert len(state) == 4
|
||||||
|
assert state[0].shape == (M, B, D)
|
||||||
|
assert state[1].shape == (left_context_length, B, D)
|
||||||
|
assert state[2].shape == (left_context_length, B, D)
|
||||||
|
assert torch.equal(
|
||||||
|
state[3],
|
||||||
|
(chunk_idx + 1) * chunk_length * torch.ones_like(state[3]),
|
||||||
|
)
|
||||||
|
for conv_cache in conv_caches:
|
||||||
|
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_emformer_encoder_forward_infer_consistency():
|
||||||
|
from emformer import EmformerEncoder
|
||||||
|
|
||||||
|
chunk_length = 4
|
||||||
|
num_chunks = 3
|
||||||
|
U = chunk_length * num_chunks
|
||||||
|
left_context_length, right_context_length = 1, 2
|
||||||
|
D = 256
|
||||||
|
num_encoder_layers = 3
|
||||||
|
kernel_size = 31
|
||||||
|
memory_sizes = [0, 3]
|
||||||
|
|
||||||
|
for M in memory_sizes:
|
||||||
|
encoder = EmformerEncoder(
|
||||||
|
chunk_length=chunk_length,
|
||||||
|
d_model=D,
|
||||||
|
dim_feedforward=1024,
|
||||||
|
num_encoder_layers=num_encoder_layers,
|
||||||
|
cnn_module_kernel=kernel_size,
|
||||||
|
left_context_length=left_context_length,
|
||||||
|
right_context_length=right_context_length,
|
||||||
|
max_memory_size=M,
|
||||||
|
)
|
||||||
|
encoder.eval()
|
||||||
|
|
||||||
|
x = torch.randn(U + right_context_length, 1, D)
|
||||||
|
lengths = torch.tensor([U + right_context_length])
|
||||||
|
|
||||||
|
# training mode with full utterance
|
||||||
|
forward_output, forward_output_lengths = encoder(x, lengths)
|
||||||
|
|
||||||
|
# streaming inference mode with individual chunks
|
||||||
|
states = None
|
||||||
|
conv_caches = None
|
||||||
|
for chunk_idx in range(num_chunks):
|
||||||
|
start_idx = chunk_idx * chunk_length
|
||||||
|
end_idx = start_idx + chunk_length
|
||||||
|
chunk = x[start_idx : end_idx + right_context_length] # noqa
|
||||||
|
chunk_length = torch.tensor([chunk_length])
|
||||||
|
(
|
||||||
|
infer_output_chunk,
|
||||||
|
infer_output_lengths,
|
||||||
|
states,
|
||||||
|
conv_caches,
|
||||||
|
) = encoder.infer(chunk, chunk_length, states, conv_caches)
|
||||||
|
forward_output_chunk = forward_output[start_idx:end_idx]
|
||||||
|
assert torch.allclose(
|
||||||
|
infer_output_chunk,
|
||||||
|
forward_output_chunk,
|
||||||
|
atol=1e-4,
|
||||||
|
rtol=0.0,
|
||||||
|
), (
|
||||||
|
infer_output_chunk - forward_output_chunk
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_emformer_forward():
|
||||||
|
from emformer import Emformer
|
||||||
|
|
||||||
|
num_features = 80
|
||||||
|
chunk_length = 16
|
||||||
|
right_context_length = 8
|
||||||
|
left_context_length = 8
|
||||||
|
num_chunks = 3
|
||||||
|
U = num_chunks * chunk_length
|
||||||
|
B, D = 2, 256
|
||||||
|
kernel_size = 31
|
||||||
|
|
||||||
|
for use_memory in [True, False]:
|
||||||
|
if use_memory:
|
||||||
|
M = 3
|
||||||
|
else:
|
||||||
|
M = 0
|
||||||
|
model = Emformer(
|
||||||
|
num_features=num_features,
|
||||||
|
chunk_length=chunk_length,
|
||||||
|
subsampling_factor=4,
|
||||||
|
d_model=D,
|
||||||
|
cnn_module_kernel=kernel_size,
|
||||||
|
left_context_length=left_context_length,
|
||||||
|
right_context_length=right_context_length,
|
||||||
|
max_memory_size=M,
|
||||||
|
)
|
||||||
|
x = torch.randn(B, U + right_context_length + 3, num_features)
|
||||||
|
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
|
||||||
|
x_lens[0] = U + right_context_length + 3
|
||||||
|
output, output_lengths = model(x, x_lens)
|
||||||
|
assert output.shape == (B, U // 4, D)
|
||||||
|
assert torch.equal(
|
||||||
|
output_lengths,
|
||||||
|
torch.clamp(
|
||||||
|
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_emformer_infer():
|
||||||
|
from emformer import Emformer
|
||||||
|
|
||||||
|
num_features = 80
|
||||||
|
chunk_length = 8
|
||||||
|
U = chunk_length
|
||||||
|
left_context_length, right_context_length = 128, 4
|
||||||
|
B, D = 2, 256
|
||||||
|
num_chunks = 3
|
||||||
|
num_encoder_layers = 2
|
||||||
|
kernel_size = 31
|
||||||
|
|
||||||
|
for use_memory in [True, False]:
|
||||||
|
if use_memory:
|
||||||
|
M = 3
|
||||||
|
else:
|
||||||
|
M = 0
|
||||||
|
model = Emformer(
|
||||||
|
num_features=num_features,
|
||||||
|
chunk_length=chunk_length,
|
||||||
|
subsampling_factor=4,
|
||||||
|
d_model=D,
|
||||||
|
num_encoder_layers=num_encoder_layers,
|
||||||
|
cnn_module_kernel=kernel_size,
|
||||||
|
left_context_length=left_context_length,
|
||||||
|
right_context_length=right_context_length,
|
||||||
|
max_memory_size=M,
|
||||||
|
)
|
||||||
|
states = None
|
||||||
|
conv_caches = None
|
||||||
|
for chunk_idx in range(num_chunks):
|
||||||
|
x = torch.randn(B, U + right_context_length + 3, num_features)
|
||||||
|
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
|
||||||
|
x_lens[0] = U + right_context_length + 3
|
||||||
|
output, output_lengths, states, conv_caches = model.infer(
|
||||||
|
x, x_lens, states, conv_caches
|
||||||
|
)
|
||||||
|
assert output.shape == (B, U // 4, D)
|
||||||
|
assert torch.equal(
|
||||||
|
output_lengths,
|
||||||
|
torch.clamp(
|
||||||
|
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4,
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert len(states) == num_encoder_layers
|
||||||
|
for state in states:
|
||||||
|
assert len(state) == 4
|
||||||
|
assert state[0].shape == (M, B, D)
|
||||||
|
assert state[1].shape == (left_context_length // 4, B, D)
|
||||||
|
assert state[2].shape == (left_context_length // 4, B, D)
|
||||||
|
assert torch.equal(
|
||||||
|
state[3],
|
||||||
|
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
|
||||||
|
)
|
||||||
|
for conv_cache in conv_caches:
|
||||||
|
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_rel_positional_encoding()
|
test_rel_positional_encoding()
|
||||||
test_emformer_attention_forward()
|
test_emformer_attention_forward()
|
||||||
@ -282,3 +542,8 @@ if __name__ == "__main__":
|
|||||||
test_convolution_module_infer()
|
test_convolution_module_infer()
|
||||||
test_emformer_encoder_layer_forward()
|
test_emformer_encoder_layer_forward()
|
||||||
test_emformer_encoder_layer_infer()
|
test_emformer_encoder_layer_infer()
|
||||||
|
test_emformer_encoder_forward()
|
||||||
|
test_emformer_encoder_infer()
|
||||||
|
test_emformer_encoder_forward_infer_consistency()
|
||||||
|
test_emformer_forward()
|
||||||
|
test_emformer_infer()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user