add Emformer module

This commit is contained in:
yaozengwei 2022-05-14 23:18:16 +08:00
parent b265a5c875
commit 8b60d43ead
2 changed files with 906 additions and 3 deletions

View File

@ -1303,7 +1303,6 @@ class EmformerEncoderLayer(nn.Module):
output_right_context = src[:R]
return output_utterance, output_right_context, output_memory
@torch.jit.export
def infer(
self,
utterance: torch.Tensor,
@ -1383,3 +1382,642 @@ class EmformerEncoderLayer(nn.Module):
output_state,
conv_cache,
)
def _gen_attention_mask_block(
col_widths: List[int],
col_mask: List[bool],
num_rows: int,
device: torch.device,
) -> torch.Tensor:
assert len(col_widths) == len(
col_mask
), "Length of col_widths must match that of col_mask"
mask_block = [
torch.ones(num_rows, col_width, device=device)
if is_ones_col
else torch.zeros(num_rows, col_width, device=device)
for col_width, is_ones_col in zip(col_widths, col_mask)
]
return torch.cat(mask_block, dim=1)
class EmformerEncoder(nn.Module):
"""Implements the Emformer architecture introduced in
*Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency
Streaming Speech Recognition*
[:footcite:`shi2021emformer`].
Args:
d_model (int):
Input dimension.
nhead (int):
Number of attention heads in each emformer layer.
dim_feedforward (int):
Hidden layer dimension of each emformer layer's feedforward network.
num_encoder_layers (int):
Number of emformer layers to instantiate.
chunk_length (int):
Length of each input segment.
dropout (float, optional):
Dropout probability. (default: 0.0)
layer_dropout (float, optional):
Layer dropout probability. (default: 0.0)
cnn_module_kernel (int):
Kernel size of convolution module.
left_context_length (int, optional):
Length of left context. (default: 0)
right_context_length (int, optional):
Length of right context. (default: 0)
max_memory_size (int, optional):
Maximum number of memory elements to use. (default: 0)
tanh_on_mem (bool, optional):
If ``true``, applies tanh to memory elements. (default: ``false``)
negative_inf (float, optional):
Value to use for negative infinity in attention weights. (default: -1e8)
"""
def __init__(
self,
chunk_length: int,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
left_context_length: int = 0,
right_context_length: int = 0,
max_memory_size: int = 0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
):
super().__init__()
self.use_memory = max_memory_size > 0
self.init_memory_op = nn.AvgPool1d(
kernel_size=chunk_length,
stride=chunk_length,
ceil_mode=True,
)
self.emformer_layers = nn.ModuleList(
[
EmformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
chunk_length=chunk_length,
dropout=dropout,
layer_dropout=layer_dropout,
cnn_module_kernel=cnn_module_kernel,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
)
for layer_idx in range(num_encoder_layers)
]
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.left_context_length = left_context_length
self.right_context_length = right_context_length
self.chunk_length = chunk_length
self.max_memory_size = max_memory_size
def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
"""Hard copy each chunk's right context and concat them."""
T = x.shape[0]
num_chunks = math.ceil(
(T - self.right_context_length) / self.chunk_length
)
right_context_blocks = []
for seg_idx in range(num_chunks - 1):
start = (seg_idx + 1) * self.chunk_length
end = start + self.right_context_length
right_context_blocks.append(x[start:end])
right_context_blocks.append(x[T - self.right_context_length :])
return torch.cat(right_context_blocks)
def _gen_attention_mask_col_widths(
self, chunk_idx: int, U: int
) -> List[int]:
"""Calculate column widths (key, value) in attention mask for the
chunk_idx chunk."""
num_chunks = math.ceil(U / self.chunk_length)
rc = self.right_context_length
lc = self.left_context_length
rc_start = chunk_idx * rc
rc_end = rc_start + rc
chunk_start = max(chunk_idx * self.chunk_length - lc, 0)
chunk_end = min((chunk_idx + 1) * self.chunk_length, U)
R = rc * num_chunks
if self.use_memory:
m_start = max(chunk_idx - self.max_memory_size, 0)
M = num_chunks - 1
col_widths = [
m_start, # before memory
chunk_idx - m_start, # memory
M - chunk_idx, # after memory
rc_start, # before right context
rc, # right context
R - rc_end, # after right context
chunk_start, # before chunk
chunk_end - chunk_start, # chunk
U - chunk_end, # after chunk
]
else:
col_widths = [
rc_start, # before right context
rc, # right context
R - rc_end, # after right context
chunk_start, # before chunk
chunk_end - chunk_start, # chunk
U - chunk_end, # after chunk
]
return col_widths
def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor:
"""Generate attention mask to simulate underlying chunk-wise attention
computation, where chunk-wise connections are filled with `False`,
and other unnecessary connections beyond chunk are filled with `True`.
R: length of hard-copied right contexts;
U: length of full utterance;
S: length of summary vectors;
M: length of memory vectors;
Q: length of attention query;
KV: length of attention key and value.
The shape of attention mask is (Q, KV).
If self.use_memory is `True`:
query = [right_context, utterance, summary];
key, value = [memory, right_context, utterance];
Q = R + U + S, KV = M + R + U.
Otherwise:
query = [right_context, utterance]
key, value = [right_context, utterance]
Q = R + U, KV = R + U.
Suppose:
c_i: chunk at index i;
r_i: right context that c_i can use;
l_i: left context that c_i can use;
m_i: past memory vectors from previous layer that c_i can use;
s_i: summary vector of c_i.
The target chunk-wise attention is:
c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key);
s_i (in query) -> l_i, c_i, r_i (in key).
"""
U = utterance.size(0)
num_chunks = math.ceil(U / self.chunk_length)
right_context_mask = []
utterance_mask = []
summary_mask = []
if self.use_memory:
num_cols = 9
# right context and utterance both attend to memory, right context,
# utterance
right_context_utterance_cols_mask = [
idx in [1, 4, 7] for idx in range(num_cols)
]
# summary attends to right context, utterance
summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
else:
num_cols = 6
# right context and utterance both attend to right context and
# utterance
right_context_utterance_cols_mask = [
idx in [1, 4] for idx in range(num_cols)
]
summary_cols_mask = None
masks_to_concat = [right_context_mask, utterance_mask]
for chunk_idx in range(num_chunks):
col_widths = self._gen_attention_mask_col_widths(chunk_idx, U)
right_context_mask_block = _gen_attention_mask_block(
col_widths,
right_context_utterance_cols_mask,
self.right_context_length,
utterance.device,
)
right_context_mask.append(right_context_mask_block)
utterance_mask_block = _gen_attention_mask_block(
col_widths,
right_context_utterance_cols_mask,
min(
self.chunk_length,
U - chunk_idx * self.chunk_length,
),
utterance.device,
)
utterance_mask.append(utterance_mask_block)
if summary_cols_mask is not None:
summary_mask_block = _gen_attention_mask_block(
col_widths, summary_cols_mask, 1, utterance.device
)
summary_mask.append(summary_mask_block)
attention_mask = (
1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
).to(torch.bool)
return attention_mask
def forward(
self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training and validation mode.
B: batch size;
D: input dimension;
U: length of utterance.
Args:
x (torch.Tensor):
Utterance frames right-padded with right context frames,
with shape (U + right_context_length, B, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the
right_context at the end.
Returns:
A tuple of 2 tensors:
- output utterance frames, with shape (U, B, D).
- output_lengths, with shape (B,), without containing the
right_context at the end.
"""
U = x.size(0) - self.right_context_length
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
right_context = self._gen_right_context(x)
utterance = x[:U]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
attention_mask = self._gen_attention_mask(utterance)
memory = (
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
:-1
]
if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device)
)
output = utterance
for layer in self.emformer_layers:
output, right_context, memory = layer(
output,
output_lengths,
right_context,
memory,
attention_mask,
pos_emb,
)
return output, output_lengths
def infer(
self,
x: torch.Tensor,
lengths: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
conv_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[
torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
]:
"""Forward pass for streaming inference.
B: batch size;
D: input dimension;
U: length of utterance.
Args:
x (torch.Tensor):
Utterance frames right-padded with right context frames,
with shape (U + right_context_length, B, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the
right_context at the end.
states (List[List[torch.Tensor]], optional):
Cached states from proceeding chunk's computation, where each
element (List[torch.Tensor]) corresponds to each emformer layer.
(default: None)
conv_caches (List[torch.Tensor], optional):
Cached tensors of left context for causal convolution, where each
element (Tensor) corresponds to each convolutional layer.
Returns:
(Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
- output utterance frames, with shape (U, B, D).
- output lengths, with shape (B,), without containing the
right_context at the end.
- updated states from current chunk's computation.
- updated convolution caches from current chunk.
"""
assert x.size(0) == self.chunk_length + self.right_context_length, (
"Per configured chunk_length and right_context_length, "
f"expected size of {self.chunk_length + self.right_context_length} "
f"for dimension 1 of x, but got {x.size(1)}."
)
pos_len = self.chunk_length + self.left_context_length
neg_len = self.chunk_length
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
right_context_start_idx = x.size(0) - self.right_context_length
right_context = x[right_context_start_idx:]
utterance = x[:right_context_start_idx]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
memory = (
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device)
)
output = utterance
output_states: List[List[torch.Tensor]] = []
output_conv_caches: List[torch.Tensor] = []
for layer_idx, layer in enumerate(self.emformer_layers):
(
output,
right_context,
memory,
output_state,
output_conv_cache,
) = layer.infer(
output,
output_lengths,
right_context,
memory,
pos_emb,
None if states is None else states[layer_idx],
None if conv_caches is None else conv_caches[layer_idx],
)
output_states.append(output_state)
output_conv_caches.append(output_conv_cache)
return output, output_lengths, output_states, output_conv_caches
class Emformer(EncoderInterface):
def __init__(
self,
num_features: int,
chunk_length: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 3,
left_context_length: int = 0,
right_context_length: int = 0,
max_memory_size: int = 0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
):
super().__init__()
self.subsampling_factor = subsampling_factor
self.right_context_length = right_context_length
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
if chunk_length % 4 != 0:
raise NotImplementedError("chunk_length must be a mutiple of 4.")
if left_context_length != 0 and left_context_length % 4 != 0:
raise NotImplementedError(
"left_context_length must be 0 or a mutiple of 4."
)
if right_context_length != 0 and right_context_length % 4 != 0:
raise NotImplementedError(
"right_context_length must be 0 or a mutiple of 4."
)
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder = EmformerEncoder(
chunk_length=chunk_length // 4,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
dropout=dropout,
layer_dropout=layer_dropout,
cnn_module_kernel=cnn_module_kernel,
left_context_length=left_context_length // 4,
right_context_length=right_context_length // 4,
max_memory_size=max_memory_size,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training and non-streaming inference.
B: batch size;
D: feature dimension;
T: length of utterance.
Args:
x (torch.Tensor):
Utterance frames right-padded with right context frames,
with shape (B, T, D).
x_lens (torch.Tensor):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, containing the
right_context at the end.
Returns:
(Tensor, Tensor):
- output embedding, with shape (B, T', D), where
T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
- output lengths, with shape (B,), without containing the
right_context at the end.
"""
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item()
output, output_lengths = self.encoder(x, x_lens) # (T, N, C)
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return output, output_lengths
@torch.jit.export
def infer(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
conv_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
"""Forward pass for streaming inference.
B: batch size;
D: feature dimension;
T: length of utterance.
Args:
x (torch.Tensor):
Utterance frames right-padded with right context frames,
with shape (B, T, D).
lengths (torch.Tensor):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, containing the
right_context at the end.
states (List[List[torch.Tensor]], optional):
Cached states from proceeding chunk's computation, where each
element (List[torch.Tensor]) corresponds to each emformer layer.
(default: None)
conv_caches (List[torch.Tensor], optional):
Cached tensors of left context for causal convolution, where each
element (Tensor) corresponds to each convolutional layer.
Returns:
(Tensor, Tensor):
- output embedding, with shape (B, T', D), where
T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
- output lengths, with shape (B,), without containing the
right_context at the end.
- updated states from current chunk's computation.
- updated convolution caches from current chunk.
"""
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item()
(
output,
output_lengths,
output_states,
output_conv_caches,
) = self.encoder.infer(x, x_lens, states, conv_caches)
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return output, output_lengths, output_states, output_conv_caches
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >=7, in_channels >=7
out_channels
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer1_channels:
Number of channels in layer2
"""
assert in_channels >= 7
super().__init__()
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=1,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
return x

View File

@ -113,7 +113,10 @@ def test_convolution_module_forward():
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length, right_context_length, D, kernel_size,
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
@ -139,7 +142,10 @@ def test_convolution_module_infer():
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length, right_context_length, D, kernel_size,
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
@ -274,6 +280,260 @@ def test_emformer_encoder_layer_infer():
assert conv_cache.shape == (B, D, kernel_size - 1)
def test_emformer_encoder_forward():
from emformer import EmformerEncoder
B, D = 2, 256
chunk_length = 4
right_context_length = 2
left_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
kernel_size = 31
num_encoder_layers = 2
for use_memory in [True, False]:
if use_memory:
S = num_chunks
M = S - 1
else:
S, M = 0, 0
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
)
x = torch.randn(U + right_context_length, B, D)
lengths = torch.randint(1, U + right_context_length + 1, (B,))
lengths[0] = U + right_context_length
output, output_lengths = encoder(x, lengths)
assert output.shape == (U, B, D)
assert torch.equal(
output_lengths, torch.clamp(lengths - right_context_length, min=0)
)
def test_emformer_encoder_infer():
from emformer import EmformerEncoder
B, D = 2, 256
num_encoder_layers = 2
chunk_length = 4
right_context_length = 2
left_context_length = 2
num_chunks = 3
kernel_size = 31
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
)
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
x = torch.randn(chunk_length + right_context_length, B, D)
lengths = torch.randint(
1, chunk_length + right_context_length + 1, (B,)
)
lengths[0] = chunk_length + right_context_length
output, output_lengths, states, conv_caches = encoder.infer(
x, lengths, states, conv_caches
)
assert output.shape == (chunk_length, B, D)
assert torch.equal(
output_lengths,
torch.clamp(lengths - right_context_length, min=0),
)
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (M, B, D)
assert state[1].shape == (left_context_length, B, D)
assert state[2].shape == (left_context_length, B, D)
assert torch.equal(
state[3],
(chunk_idx + 1) * chunk_length * torch.ones_like(state[3]),
)
for conv_cache in conv_caches:
assert conv_cache.shape == (B, D, kernel_size - 1)
def test_emformer_encoder_forward_infer_consistency():
from emformer import EmformerEncoder
chunk_length = 4
num_chunks = 3
U = chunk_length * num_chunks
left_context_length, right_context_length = 1, 2
D = 256
num_encoder_layers = 3
kernel_size = 31
memory_sizes = [0, 3]
for M in memory_sizes:
encoder = EmformerEncoder(
chunk_length=chunk_length,
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
)
encoder.eval()
x = torch.randn(U + right_context_length, 1, D)
lengths = torch.tensor([U + right_context_length])
# training mode with full utterance
forward_output, forward_output_lengths = encoder(x, lengths)
# streaming inference mode with individual chunks
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
start_idx = chunk_idx * chunk_length
end_idx = start_idx + chunk_length
chunk = x[start_idx : end_idx + right_context_length] # noqa
chunk_length = torch.tensor([chunk_length])
(
infer_output_chunk,
infer_output_lengths,
states,
conv_caches,
) = encoder.infer(chunk, chunk_length, states, conv_caches)
forward_output_chunk = forward_output[start_idx:end_idx]
assert torch.allclose(
infer_output_chunk,
forward_output_chunk,
atol=1e-4,
rtol=0.0,
), (
infer_output_chunk - forward_output_chunk
)
def test_emformer_forward():
from emformer import Emformer
num_features = 80
chunk_length = 16
right_context_length = 8
left_context_length = 8
num_chunks = 3
U = num_chunks * chunk_length
B, D = 2, 256
kernel_size = 31
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=D,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
)
x = torch.randn(B, U + right_context_length + 3, num_features)
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
x_lens[0] = U + right_context_length + 3
output, output_lengths = model(x, x_lens)
assert output.shape == (B, U // 4, D)
assert torch.equal(
output_lengths,
torch.clamp(
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0
),
)
def test_emformer_infer():
from emformer import Emformer
num_features = 80
chunk_length = 8
U = chunk_length
left_context_length, right_context_length = 128, 4
B, D = 2, 256
num_chunks = 3
num_encoder_layers = 2
kernel_size = 31
for use_memory in [True, False]:
if use_memory:
M = 3
else:
M = 0
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=D,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
)
states = None
conv_caches = None
for chunk_idx in range(num_chunks):
x = torch.randn(B, U + right_context_length + 3, num_features)
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
x_lens[0] = U + right_context_length + 3
output, output_lengths, states, conv_caches = model.infer(
x, x_lens, states, conv_caches
)
assert output.shape == (B, U // 4, D)
assert torch.equal(
output_lengths,
torch.clamp(
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4,
min=0,
),
)
assert len(states) == num_encoder_layers
for state in states:
assert len(state) == 4
assert state[0].shape == (M, B, D)
assert state[1].shape == (left_context_length // 4, B, D)
assert state[2].shape == (left_context_length // 4, B, D)
assert torch.equal(
state[3],
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
)
for conv_cache in conv_caches:
assert conv_cache.shape == (B, D, kernel_size - 1)
if __name__ == "__main__":
test_rel_positional_encoding()
test_emformer_attention_forward()
@ -282,3 +542,8 @@ if __name__ == "__main__":
test_convolution_module_infer()
test_emformer_encoder_layer_forward()
test_emformer_encoder_layer_infer()
test_emformer_encoder_forward()
test_emformer_encoder_infer()
test_emformer_encoder_forward_infer_consistency()
test_emformer_forward()
test_emformer_infer()