Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

add Emformer module

This commit is contained in: commit 8b60d43ead (parent b265a5c875)
@@ -1303,7 +1303,6 @@ class EmformerEncoderLayer(nn.Module):
        output_right_context = src[:R]
        return output_utterance, output_right_context, output_memory

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
@@ -1383,3 +1382,642 @@ class EmformerEncoderLayer(nn.Module):
            output_state,
            conv_cache,
        )


def _gen_attention_mask_block(
    col_widths: List[int],
    col_mask: List[bool],
    num_rows: int,
    device: torch.device,
) -> torch.Tensor:
    assert len(col_widths) == len(
        col_mask
    ), "Length of col_widths must match that of col_mask"

    mask_block = [
        torch.ones(num_rows, col_width, device=device)
        if is_ones_col
        else torch.zeros(num_rows, col_width, device=device)
        for col_width, is_ones_col in zip(col_widths, col_mask)
    ]
    return torch.cat(mask_block, dim=1)

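As a quick illustration (not part of the commit, assuming `emformer.py` is importable): the helper stitches ones/zeros blocks together column-wise, so `col_widths=[2, 3]` with `col_mask=[False, True]` yields a `(1, 5)` block whose first two columns are zeros and last three are ones.

import torch
from emformer import _gen_attention_mask_block

block = _gen_attention_mask_block(
    col_widths=[2, 3],
    col_mask=[False, True],
    num_rows=1,
    device=torch.device("cpu"),
)
assert block.shape == (1, 5)
assert torch.equal(block, torch.tensor([[0.0, 0.0, 1.0, 1.0, 1.0]]))
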
class EmformerEncoder(nn.Module):
    """Implements the Emformer architecture introduced in
    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency
    Streaming Speech Recognition*
    [:footcite:`shi2021emformer`].

    Args:
      d_model (int):
        Input dimension.
      nhead (int):
        Number of attention heads in each emformer layer.
      dim_feedforward (int):
        Hidden layer dimension of each emformer layer's feedforward network.
      num_encoder_layers (int):
        Number of emformer layers to instantiate.
      chunk_length (int):
        Length of each input segment.
      dropout (float, optional):
        Dropout probability. (default: 0.1)
      layer_dropout (float, optional):
        Layer dropout probability. (default: 0.075)
      cnn_module_kernel (int):
        Kernel size of convolution module.
      left_context_length (int, optional):
        Length of left context. (default: 0)
      right_context_length (int, optional):
        Length of right context. (default: 0)
      max_memory_size (int, optional):
        Maximum number of memory elements to use. (default: 0)
      tanh_on_mem (bool, optional):
        If ``True``, applies tanh to memory elements. (default: ``False``)
      negative_inf (float, optional):
        Value to use for negative infinity in attention weights.
        (default: -1e8)
    """

    def __init__(
        self,
        chunk_length: int,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
        cnn_module_kernel: int = 31,
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        self.use_memory = max_memory_size > 0
        self.init_memory_op = nn.AvgPool1d(
            kernel_size=chunk_length,
            stride=chunk_length,
            ceil_mode=True,
        )

        self.emformer_layers = nn.ModuleList(
            [
                EmformerEncoderLayer(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    chunk_length=chunk_length,
                    dropout=dropout,
                    layer_dropout=layer_dropout,
                    cnn_module_kernel=cnn_module_kernel,
                    left_context_length=left_context_length,
                    right_context_length=right_context_length,
                    max_memory_size=max_memory_size,
                    tanh_on_mem=tanh_on_mem,
                    negative_inf=negative_inf,
                )
                for layer_idx in range(num_encoder_layers)
            ]
        )

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

        self.left_context_length = left_context_length
        self.right_context_length = right_context_length
        self.chunk_length = chunk_length
        self.max_memory_size = max_memory_size

    def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
        """Hard-copy each chunk's right context and concatenate them."""
        T = x.shape[0]
        num_chunks = math.ceil(
            (T - self.right_context_length) / self.chunk_length
        )
        right_context_blocks = []
        for seg_idx in range(num_chunks - 1):
            start = (seg_idx + 1) * self.chunk_length
            end = start + self.right_context_length
            right_context_blocks.append(x[start:end])
        right_context_blocks.append(x[T - self.right_context_length :])
        return torch.cat(right_context_blocks)

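For illustration (a standalone sketch, not part of the commit): with chunk_length=2 and right_context_length=1, an input of three chunks plus one trailing right-context frame has frames 2, 4 and 6 copied out as the right contexts of chunks 0, 1 and 2.

import math

import torch

chunk_length, right_context_length = 2, 1
x = torch.arange(7.0)  # 3 chunks of length 2, plus 1 trailing right-context frame
T = x.shape[0]
num_chunks = math.ceil((T - right_context_length) / chunk_length)  # 3
blocks = []
for seg_idx in range(num_chunks - 1):
    start = (seg_idx + 1) * chunk_length
    blocks.append(x[start : start + right_context_length])
blocks.append(x[T - right_context_length :])
assert torch.equal(torch.cat(blocks), torch.tensor([2.0, 4.0, 6.0]))
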
    def _gen_attention_mask_col_widths(
        self, chunk_idx: int, U: int
    ) -> List[int]:
        """Calculate the column widths (over keys and values) in the attention
        mask for the chunk with index chunk_idx."""
        num_chunks = math.ceil(U / self.chunk_length)
        rc = self.right_context_length
        lc = self.left_context_length
        rc_start = chunk_idx * rc
        rc_end = rc_start + rc
        chunk_start = max(chunk_idx * self.chunk_length - lc, 0)
        chunk_end = min((chunk_idx + 1) * self.chunk_length, U)
        R = rc * num_chunks

        if self.use_memory:
            m_start = max(chunk_idx - self.max_memory_size, 0)
            M = num_chunks - 1
            col_widths = [
                m_start,  # before memory
                chunk_idx - m_start,  # memory
                M - chunk_idx,  # after memory
                rc_start,  # before right context
                rc,  # right context
                R - rc_end,  # after right context
                chunk_start,  # before chunk
                chunk_end - chunk_start,  # chunk
                U - chunk_end,  # after chunk
            ]
        else:
            col_widths = [
                rc_start,  # before right context
                rc,  # right context
                R - rc_end,  # after right context
                chunk_start,  # before chunk
                chunk_end - chunk_start,  # chunk
                U - chunk_end,  # after chunk
            ]

        return col_widths

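To make the bookkeeping concrete (an illustration only, assuming icefall's `emformer.py` is importable as in the tests below): without memory, the six widths partition the key/value axis of length KV = R + U.

import torch
from emformer import EmformerEncoder

encoder = EmformerEncoder(
    chunk_length=2,
    d_model=256,
    num_encoder_layers=1,
    left_context_length=2,
    right_context_length=1,
)
# U = 6 -> num_chunks = 3, R = 3; widths for the middle chunk (chunk_idx=1):
# [rc_start, rc, R - rc_end, chunk_start, chunk_end - chunk_start, U - chunk_end]
assert encoder._gen_attention_mask_col_widths(1, 6) == [1, 1, 1, 0, 4, 2]
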
    def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor:
        """Generate the attention mask that simulates the underlying
        chunk-wise attention computation: allowed chunk-wise connections are
        filled with `False`, and connections beyond a chunk's reach are
        filled with `True`.

        R: length of hard-copied right contexts;
        U: length of full utterance;
        S: length of summary vectors;
        M: length of memory vectors;
        Q: length of attention query;
        KV: length of attention key and value.

        The shape of the attention mask is (Q, KV).
        If self.use_memory is `True`:
          query = [right_context, utterance, summary];
          key, value = [memory, right_context, utterance];
          Q = R + U + S, KV = M + R + U.
        Otherwise:
          query = [right_context, utterance];
          key, value = [right_context, utterance];
          Q = R + U, KV = R + U.

        Suppose:
          c_i: chunk at index i;
          r_i: right context that c_i can use;
          l_i: left context that c_i can use;
          m_i: past memory vectors from previous layer that c_i can use;
          s_i: summary vector of c_i.
        The target chunk-wise attention is:
          c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
          s_i (in query) -> l_i, c_i, r_i (in key).
        """
        U = utterance.size(0)
        num_chunks = math.ceil(U / self.chunk_length)

        right_context_mask = []
        utterance_mask = []
        summary_mask = []

        if self.use_memory:
            num_cols = 9
            # right context and utterance both attend to memory,
            # right context and utterance
            right_context_utterance_cols_mask = [
                idx in [1, 4, 7] for idx in range(num_cols)
            ]
            # summary attends to right context and utterance
            summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
            masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
        else:
            num_cols = 6
            # right context and utterance both attend to right context
            # and utterance
            right_context_utterance_cols_mask = [
                idx in [1, 4] for idx in range(num_cols)
            ]
            summary_cols_mask = None
            masks_to_concat = [right_context_mask, utterance_mask]

        for chunk_idx in range(num_chunks):
            col_widths = self._gen_attention_mask_col_widths(chunk_idx, U)

            right_context_mask_block = _gen_attention_mask_block(
                col_widths,
                right_context_utterance_cols_mask,
                self.right_context_length,
                utterance.device,
            )
            right_context_mask.append(right_context_mask_block)

            utterance_mask_block = _gen_attention_mask_block(
                col_widths,
                right_context_utterance_cols_mask,
                min(
                    self.chunk_length,
                    U - chunk_idx * self.chunk_length,
                ),
                utterance.device,
            )
            utterance_mask.append(utterance_mask_block)

            if summary_cols_mask is not None:
                summary_mask_block = _gen_attention_mask_block(
                    col_widths, summary_cols_mask, 1, utterance.device
                )
                summary_mask.append(summary_mask_block)

        attention_mask = (
            1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
        ).to(torch.bool)
        return attention_mask

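Continuing the toy setup above (illustration only): without memory the mask is square, since Q = KV = R + U.

import torch
from emformer import EmformerEncoder

encoder = EmformerEncoder(
    chunk_length=2,
    d_model=256,
    num_encoder_layers=1,
    left_context_length=2,
    right_context_length=1,
)
utterance = torch.randn(6, 1, 256)  # (U, B, D) with U = 6
mask = encoder._gen_attention_mask(utterance)
assert mask.shape == (3 + 6, 3 + 6)  # (Q, KV) = (R + U, R + U)
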
    def forward(
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training and validation mode.

        B: batch size;
        D: input dimension;
        U: length of utterance.

        Args:
          x (torch.Tensor):
            Utterance frames right-padded with right context frames,
            with shape (U + right_context_length, B, D).
          lengths (torch.Tensor):
            With shape (B,), where the i-th element represents the number of
            valid utterance frames for the i-th batch element in x, including
            the right_context at the end.

        Returns:
          A tuple of 2 tensors:
            - output utterance frames, with shape (U, B, D).
            - output_lengths, with shape (B,), not counting the
              right_context at the end.
        """
        U = x.size(0) - self.right_context_length
        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)

        right_context = self._gen_right_context(x)
        utterance = x[:U]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        attention_mask = self._gen_attention_mask(utterance)
        memory = (
            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
                :-1
            ]
            if self.use_memory
            else torch.empty(0).to(dtype=x.dtype, device=x.device)
        )

        output = utterance
        for layer in self.emformer_layers:
            output, right_context, memory = layer(
                output,
                output_lengths,
                right_context,
                memory,
                attention_mask,
                pos_emb,
            )

        return output, output_lengths

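A minimal non-streaming call mirroring the docstring shapes (a sketch, assuming `emformer.py` is importable):

import torch
from emformer import EmformerEncoder

encoder = EmformerEncoder(
    chunk_length=4,
    d_model=256,
    num_encoder_layers=2,
    right_context_length=2,
)
x = torch.randn(12 + 2, 3, 256)  # (U + right_context_length, B, D)
lengths = torch.full((3,), 14, dtype=torch.int64)  # valid frames, right context included
output, output_lengths = encoder(x, lengths)
assert output.shape == (12, 3, 256)     # right context stripped
assert torch.all(output_lengths == 12)  # lengths minus right context
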
    def infer(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]] = None,
        conv_caches: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
    ]:
        """Forward pass for streaming inference.

        B: batch size;
        D: input dimension;
        U: length of utterance.

        Args:
          x (torch.Tensor):
            Utterance frames right-padded with right context frames,
            with shape (U + right_context_length, B, D).
          lengths (torch.Tensor):
            With shape (B,), where the i-th element represents the number of
            valid utterance frames for the i-th batch element in x, including
            the right_context at the end.
          states (List[List[torch.Tensor]], optional):
            Cached states from the preceding chunk's computation, where each
            element (List[torch.Tensor]) corresponds to one emformer layer.
            (default: None)
          conv_caches (List[torch.Tensor], optional):
            Cached left-context tensors for the causal convolution, where each
            element (Tensor) corresponds to one convolution layer.
            (default: None)

        Returns:
          (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
            - output utterance frames, with shape (U, B, D).
            - output lengths, with shape (B,), not counting the
              right_context at the end.
            - updated states from the current chunk's computation.
            - updated convolution caches from the current chunk.
        """
        assert x.size(0) == self.chunk_length + self.right_context_length, (
            "Per configured chunk_length and right_context_length, "
            f"expected size of {self.chunk_length + self.right_context_length} "
            f"for dimension 0 of x, but got {x.size(0)}."
        )

        pos_len = self.chunk_length + self.left_context_length
        neg_len = self.chunk_length
        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)

        right_context_start_idx = x.size(0) - self.right_context_length
        right_context = x[right_context_start_idx:]
        utterance = x[:right_context_start_idx]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        memory = (
            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            if self.use_memory
            else torch.empty(0).to(dtype=x.dtype, device=x.device)
        )
        output = utterance
        output_states: List[List[torch.Tensor]] = []
        output_conv_caches: List[torch.Tensor] = []
        for layer_idx, layer in enumerate(self.emformer_layers):
            (
                output,
                right_context,
                memory,
                output_state,
                output_conv_cache,
            ) = layer.infer(
                output,
                output_lengths,
                right_context,
                memory,
                pos_emb,
                None if states is None else states[layer_idx],
                None if conv_caches is None else conv_caches[layer_idx],
            )
            output_states.append(output_state)
            output_conv_caches.append(output_conv_cache)

        return output, output_lengths, output_states, output_conv_caches


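And the streaming counterpart (again a sketch, not part of the commit): fixed-size chunks are fed one at a time while `states` and `conv_caches` are threaded through.

import torch
from emformer import EmformerEncoder

encoder = EmformerEncoder(
    chunk_length=4,
    d_model=256,
    num_encoder_layers=2,
    right_context_length=2,
)
states, conv_caches = None, None
for _ in range(3):
    chunk = torch.randn(4 + 2, 1, 256)  # (chunk_length + right_context_length, B, D)
    chunk_lengths = torch.full((1,), 6, dtype=torch.int64)
    output, output_lengths, states, conv_caches = encoder.infer(
        chunk, chunk_lengths, states, conv_caches
    )
    assert output.shape == (4, 1, 256)
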
class Emformer(EncoderInterface):
    def __init__(
        self,
        num_features: int,
        chunk_length: int,
        subsampling_factor: int = 4,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
        cnn_module_kernel: int = 3,
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        self.subsampling_factor = subsampling_factor
        self.right_context_length = right_context_length
        if subsampling_factor != 4:
            raise NotImplementedError("Only 'subsampling_factor=4' is supported.")
        if chunk_length % 4 != 0:
            raise NotImplementedError("chunk_length must be a multiple of 4.")
        if left_context_length != 0 and left_context_length % 4 != 0:
            raise NotImplementedError(
                "left_context_length must be 0 or a multiple of 4."
            )
        if right_context_length != 0 and right_context_length % 4 != 0:
            raise NotImplementedError(
                "right_context_length must be 0 or a multiple of 4."
            )

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, T//subsampling_factor, d_model).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.encoder = EmformerEncoder(
            chunk_length=chunk_length // 4,
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=num_encoder_layers,
            dropout=dropout,
            layer_dropout=layer_dropout,
            cnn_module_kernel=cnn_module_kernel,
            left_context_length=left_context_length // 4,
            right_context_length=right_context_length // 4,
            max_memory_size=max_memory_size,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training and non-streaming inference.

        B: batch size;
        D: feature dimension;
        T: length of utterance.

        Args:
          x (torch.Tensor):
            Utterance frames right-padded with right context frames,
            with shape (B, T, D).
          x_lens (torch.Tensor):
            With shape (B,), where the i-th element represents the number of
            valid utterance frames for the i-th batch element in x, including
            the right_context at the end.

        Returns:
          (Tensor, Tensor):
            - output embedding, with shape (B, T', D), where
              T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
            - output lengths, with shape (B,), not counting the
              right_context at the end.
        """
        x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x_lens = ((x_lens - 1) // 2 - 1) // 2
        assert x.size(0) == x_lens.max().item()

        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)

        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return output, output_lengths

    @torch.jit.export
    def infer(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]] = None,
        conv_caches: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
    ]:
        """Forward pass for streaming inference.

        B: batch size;
        D: feature dimension;
        T: length of utterance.

        Args:
          x (torch.Tensor):
            Utterance frames right-padded with right context frames,
            with shape (B, T, D).
          x_lens (torch.Tensor):
            With shape (B,), where the i-th element represents the number of
            valid utterance frames for the i-th batch element in x, including
            the right_context at the end.
          states (List[List[torch.Tensor]], optional):
            Cached states from the preceding chunk's computation, where each
            element (List[torch.Tensor]) corresponds to one emformer layer.
            (default: None)
          conv_caches (List[torch.Tensor], optional):
            Cached left-context tensors for the causal convolution, where each
            element (Tensor) corresponds to one convolution layer.
            (default: None)

        Returns:
          (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
            - output embedding, with shape (B, T', D), where
              T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
            - output lengths, with shape (B,), not counting the
              right_context at the end.
            - updated states from the current chunk's computation.
            - updated convolution caches from the current chunk.
        """
        x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x_lens = ((x_lens - 1) // 2 - 1) // 2
        assert x.size(0) == x_lens.max().item()

        (
            output,
            output_lengths,
            output_states,
            output_conv_caches,
        ) = self.encoder.infer(x, x_lens, states, conv_caches)

        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return output, output_lengths, output_states, output_conv_caches


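The length formula in the docstrings above is worth a worked example (pure arithmetic, no model required): with chunk_length=16, right_context_length=4, and 3 extra frames for the convolutional front end,

T = 16 + 4 + 3                   # chunk + right context + 3 extra frames
t_sub = ((T - 1) // 2 - 1) // 2  # length after Conv2dSubsampling
t_out = t_sub - 4 // 4           # strip the subsampled right context
assert (t_sub, t_out) == (5, 4)  # 4 == chunk_length // 4
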
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4.

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >= 7, in_channels >= 7.
          out_channels:
            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels).
          layer1_channels:
            Number of channels in layer1.
          layer2_channels:
            Number of channels in layer2.
          layer3_channels:
            Number of channels in layer3.
        """
        assert in_channels >= 7
        super().__init__()

        self.conv = nn.Sequential(
            ScaledConv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=1,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
        )
        self.out = ScaledLinear(
            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
        )
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has a learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        # constrain the median of the output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim).
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim), i.e., (N, C, H, W)
        x = self.conv(x)
        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
        x = self.out_norm(x)
        x = self.out_balancer(x)
        return x

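A shape check for the subsampling module (illustration only, assuming the Scaled* layers it uses are importable alongside `emformer.py`):

import torch
from emformer import Conv2dSubsampling

subsampler = Conv2dSubsampling(in_channels=80, out_channels=256)
x = torch.randn(2, 23, 80)  # (N, T, idim) with T >= 7
y = subsampler(x)
assert y.shape == (2, ((23 - 1) // 2 - 1) // 2, 256)  # (2, 5, 256)
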
@@ -113,7 +113,10 @@ def test_convolution_module_forward():
     R = num_chunks * right_context_length
     kernel_size = 31
     conv_module = ConvolutionModule(
-        chunk_length, right_context_length, D, kernel_size,
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
     )
 
     utterance = torch.randn(U, B, D)
@@ -139,7 +142,10 @@ def test_convolution_module_infer():
     R = num_chunks * right_context_length
     kernel_size = 31
     conv_module = ConvolutionModule(
-        chunk_length, right_context_length, D, kernel_size,
+        chunk_length,
+        right_context_length,
+        D,
+        kernel_size,
    )
 
     utterance = torch.randn(U, B, D)
@@ -274,6 +280,260 @@ def test_emformer_encoder_layer_infer():
        assert conv_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_forward():
    from emformer import EmformerEncoder

    B, D = 2, 256
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
    kernel_size = 31
    num_encoder_layers = 2

    for use_memory in [True, False]:
        if use_memory:
            S = num_chunks
            M = S - 1
        else:
            S, M = 0, 0

        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        x = torch.randn(U + right_context_length, B, D)
        lengths = torch.randint(1, U + right_context_length + 1, (B,))
        lengths[0] = U + right_context_length

        output, output_lengths = encoder(x, lengths)
        assert output.shape == (U, B, D)
        assert torch.equal(
            output_lengths, torch.clamp(lengths - right_context_length, min=0)
        )


def test_emformer_encoder_infer():
    from emformer import EmformerEncoder

    B, D = 2, 256
    num_encoder_layers = 2
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
    num_chunks = 3
    kernel_size = 31

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0

        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )

        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(chunk_length + right_context_length, B, D)
            lengths = torch.randint(
                1, chunk_length + right_context_length + 1, (B,)
            )
            lengths[0] = chunk_length + right_context_length
            output, output_lengths, states, conv_caches = encoder.infer(
                x, lengths, states, conv_caches
            )
            assert output.shape == (chunk_length, B, D)
            assert torch.equal(
                output_lengths,
                torch.clamp(lengths - right_context_length, min=0),
            )
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, B, D)
                assert state[1].shape == (left_context_length, B, D)
                assert state[2].shape == (left_context_length, B, D)
                assert torch.equal(
                    state[3],
                    (chunk_idx + 1) * chunk_length * torch.ones_like(state[3]),
                )
            for conv_cache in conv_caches:
                assert conv_cache.shape == (B, D, kernel_size - 1)


def test_emformer_encoder_forward_infer_consistency():
    from emformer import EmformerEncoder

    chunk_length = 4
    num_chunks = 3
    U = chunk_length * num_chunks
    left_context_length, right_context_length = 1, 2
    D = 256
    num_encoder_layers = 3
    kernel_size = 31
    memory_sizes = [0, 3]

    for M in memory_sizes:
        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )
        encoder.eval()

        x = torch.randn(U + right_context_length, 1, D)
        lengths = torch.tensor([U + right_context_length])

        # training mode with full utterance
        forward_output, forward_output_lengths = encoder(x, lengths)

        # streaming inference mode with individual chunks
        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
            chunk = x[start_idx : end_idx + right_context_length]  # noqa
            # note: use a separate name here; re-binding chunk_length to a
            # tensor would break the indexing on the next loop iteration.
            chunk_lengths = torch.tensor([chunk_length])
            (
                infer_output_chunk,
                infer_output_lengths,
                states,
                conv_caches,
            ) = encoder.infer(chunk, chunk_lengths, states, conv_caches)
            forward_output_chunk = forward_output[start_idx:end_idx]
            assert torch.allclose(
                infer_output_chunk,
                forward_output_chunk,
                atol=1e-4,
                rtol=0.0,
            ), (
                infer_output_chunk - forward_output_chunk
            )


def test_emformer_forward():
    from emformer import Emformer

    num_features = 80
    chunk_length = 16
    right_context_length = 8
    left_context_length = 8
    num_chunks = 3
    U = num_chunks * chunk_length
    B, D = 2, 256
    kernel_size = 31

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0
        model = Emformer(
            num_features=num_features,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=D,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )
        x = torch.randn(B, U + right_context_length + 3, num_features)
        x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
        x_lens[0] = U + right_context_length + 3
        output, output_lengths = model(x, x_lens)
        assert output.shape == (B, U // 4, D)
        assert torch.equal(
            output_lengths,
            torch.clamp(
                ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0
            ),
        )


def test_emformer_infer():
    from emformer import Emformer

    num_features = 80
    chunk_length = 8
    U = chunk_length
    left_context_length, right_context_length = 128, 4
    B, D = 2, 256
    num_chunks = 3
    num_encoder_layers = 2
    kernel_size = 31

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0
        model = Emformer(
            num_features=num_features,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
            cnn_module_kernel=kernel_size,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
        )
        states = None
        conv_caches = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(B, U + right_context_length + 3, num_features)
            x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
            x_lens[0] = U + right_context_length + 3
            output, output_lengths, states, conv_caches = model.infer(
                x, x_lens, states, conv_caches
            )
            assert output.shape == (B, U // 4, D)
            assert torch.equal(
                output_lengths,
                torch.clamp(
                    ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4,
                    min=0,
                ),
            )
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, B, D)
                assert state[1].shape == (left_context_length // 4, B, D)
                assert state[2].shape == (left_context_length // 4, B, D)
                assert torch.equal(
                    state[3],
                    U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
                )
            for conv_cache in conv_caches:
                assert conv_cache.shape == (B, D, kernel_size - 1)


if __name__ == "__main__":
    test_rel_positional_encoding()
    test_emformer_attention_forward()
@@ -282,3 +542,8 @@ if __name__ == "__main__":
     test_convolution_module_infer()
     test_emformer_encoder_layer_forward()
     test_emformer_encoder_layer_infer()
+    test_emformer_encoder_forward()
+    test_emformer_encoder_infer()
+    test_emformer_encoder_forward_infer_consistency()
+    test_emformer_forward()
+    test_emformer_infer()