Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-06 15:44:17 +00:00
Update emformer_pruned_transducer_stateless/emformer.py and upload emformer_pruned_transducer_stateless/test_emformer.py.
This commit is contained in: parent fe43c1349e, commit 9423b3899f
emformer_pruned_transducer_stateless/emformer.py

@@ -9,48 +9,6 @@ from encoder_interface import EncoderInterface
 from subsampling import Conv2dSubsampling, VggSubsampling
 
 
-def _gen_padding_mask(
-    utterance: torch.Tensor,
-    right_context: torch.Tensor,
-    lengths: torch.Tensor,
-    mems: torch.Tensor,
-    left_context_key: Optional[torch.Tensor] = None,
-) -> Optional[torch.Tensor]:
-    """Generate padding mask according to the length of the tensors
-    contained in the key.
-
-    Args:
-      utterance: (U, B, D)
-      right_context: (R, B, D)
-      lengths: (B,)
-      mems: (M, B, D)
-      left_context_key: (L, B, D)
-      B is the batch size, D is the feature dimension,
-      U is the length of the utterance,
-      R is the length of the right context block,
-      M is the length of the memory block,
-      L is the length of the left context block
-
-    Returns:
-      padding_mask:
-        Padding mask for the concatenated key tensor
-        [mems, right_context, left_context, utterance],
-        sharing for all queries, with shape of (M + R + L + U, B)
-    """
-    assert utterance.size(0) == torch.max(lengths)
-    B = utterance.size(1)
-    M = mems.size(0)
-    R = right_context.size(0)
-    L = left_context_key.size(0) if left_context_key is not None else 0
-    if B == 1:
-        # TODO: for infer mode?
-        padding_mask = None
-    else:
-        lengths_concat = M + R + L + lengths
-        padding_mask = make_pad_mask(lengths_concat)
-    return padding_mask
-
-
 def _get_activation_module(activation: str) -> nn.Module:
     if activation == "relu":
         return nn.ReLU()

@@ -96,11 +54,6 @@ def _gen_attention_mask_block(
     return torch.cat(mask_block, dim=1)
 
 
-def length_down_sampling(length):
-    # Caution: We assume the subsampling factor is 4!
-    return ((length - 1) // 2 - 1) // 2
-
-
 class EmformerAttention(nn.Module):
     r"""Emformer layer attention module.
 

@@ -239,7 +192,7 @@ class EmformerAttention(nn.Module):
            and compute query tensor with length Q = R + U + S.
         2) Concat memory, right_context, utterance,
            and compute key, value tensors with length KV = M + R + U;
-           optionally with left_context_key and left_context_val (inference mode)
+           optionally with left_context_key and left_context_val (inference mode),
            then KV = M + R + L + U.
         3) Compute entire attention scores with query, key, and value,
            then apply attention_mask to get underlying chunk-wise attention scores.
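To make the length bookkeeping described above concrete, here is a small worked example using the sizes that the test_emformer.py added in this commit also uses: with U = 12 utterance frames, right context R = 2, chunk_length = 2 and memory enabled, the summary has S = U / chunk_length = 6 vectors and the memory has M = S - 1 = 5, so the query length is Q = R + U + S = 20 and the key/value length is KV = M + R + U = 19; in inference mode with a cached left context of length L = 2, KV becomes M + R + L + U = 21.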
@@ -284,7 +237,7 @@ class EmformerAttention(nn.Module):
         ).chunk(chunks=2, dim=2)
 
         if left_context_key is not None and left_context_val is not None:
-            # Now compute key and value with
+            # This is for inference mode. Now compute key and value with
             # [mems, right context, left context, uttrance]
             M = memory.size(0)
             R = right_context.size(0)

@@ -328,8 +281,8 @@ class EmformerAttention(nn.Module):
         outputs = self.out_proj(attention)
 
         S = summary.size(0)
-        output_right_context_utterance = outputs[:-S]
-        output_memory = outputs[-S:]
+        output_right_context_utterance = outputs[:Q - S]
+        output_memory = outputs[Q - S:]
         if self.tanh_on_mem:
             output_memory = torch.tanh(output_memory)
         else:
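The switch from negative slicing to explicit Q - S indices is not explained in the commit message; a plausible motivation, illustrated by the hypothetical sketch below, is that negative slicing misbehaves when S == 0, i.e. when no summary/memory vectors are used: outputs[:-0] is empty and outputs[-0:] is the whole tensor, the opposite of what is intended. The same pattern (explicit start indices computed from sizes) is applied to the state-cache slicing further down in this diff.

# Hypothetical sketch, not part of the commit: why an explicit `Q - S` index
# is safer than negative slicing when S can be zero.
import torch

outputs = torch.arange(5)  # stand-in for the projected attention output, so Q = 5
Q = outputs.size(0)
S = 0                      # no summary vectors, hence no memory output expected

print(outputs[:-S])        # empty tensor -- everything is dropped
print(outputs[-S:])        # tensor([0, 1, 2, 3, 4]) -- everything treated as memory
print(outputs[:Q - S])     # tensor([0, 1, 2, 3, 4]) -- intended right_context/utterance output
print(outputs[Q - S:])     # empty tensor -- intended (empty) memory output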
@@ -370,12 +323,12 @@ class EmformerAttention(nn.Module):
             Memory elements, with shape (M, B, D).
           attention_mask (torch.Tensor):
             Attention mask for underlying chunk-wise attention,
-            with shape (Q, KV).
+            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
 
         Returns:
           A tuple containing 2 tensors:
           - output of right context and utterance, with shape (R + U, B, D).
-          - memory output, with shape (M, B, D), where M = S - 1.
+          - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
         """
         output_right_context_utterance, output_memory, _, _ = \
             self._forward_impl(

@@ -418,7 +371,7 @@ class EmformerAttention(nn.Module):
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D).
           summary (torch.Tensor):
-            Summary elements, with shape (S, B, D).
+            Summary element, with shape (1, B, D), or empty.
           memory (torch.Tensor):
             Memory elements, with shape (M, B, D).
           left_context_key (torch,Tensor):

@@ -431,7 +384,7 @@ class EmformerAttention(nn.Module):
         Returns:
           A tuple containing 4 tensors:
           - output of right context and utterance, with shape (R + U, B, D).
-          - memory output, with shape (S, B, D).
+          - memory output, with shape (1, B, D) or (0, B, D).
           - attention key of left context and utterance, which would be cached
             for next computation, with shape (L + U, B, D).
           - attention value of left context and utterance, which would be

@@ -476,7 +429,7 @@ class EmformerLayer(nn.Module):
         Number of attention heads.
       dim_feedforward (int):
         Hidden layer dimension of feedforward network.
-      segment_length (int):
+      chunk_length (int):
         Length of each input segment.
       dropout (float, optional):
         Dropout probability. (Default: 0.0)

@@ -501,7 +454,7 @@ class EmformerLayer(nn.Module):
         d_model: int,
         nhead: int,
         dim_feedforward: int,
-        segment_length: int,
+        chunk_length: int,
         dropout: float = 0.0,
         activation: str = "relu",
         left_context_length: int = 0,

@@ -513,7 +466,7 @@ class EmformerLayer(nn.Module):
         super().__init__()
 
         self.attention = EmformerAttention(
-            d_model=d_model,
+            embed_dim=d_model,
             nhead=nhead,
             dropout=dropout,
             weight_init_gain=weight_init_gain,

@@ -522,7 +475,7 @@ class EmformerLayer(nn.Module):
         )
         self.dropout = nn.Dropout(dropout)
         self.summary_op = nn.AvgPool1d(
-            kernel_size=segment_length, stride=segment_length, ceil_mode=True
+            kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
         )
 
         activation_module = _get_activation_module(activation)

@@ -538,7 +491,7 @@ class EmformerLayer(nn.Module):
         self.layer_norm_output = nn.LayerNorm(d_model)
 
         self.left_context_length = left_context_length
-        self.segment_length = segment_length
+        self.chunk_length = chunk_length
         self.max_memory_size = max_memory_size
         self.d_model = d_model
 

@@ -576,11 +529,13 @@ class EmformerLayer(nn.Module):
         past_length = state[3][0][0].item()
         past_left_context_length = min(self.left_context_length, past_length)
         past_memory_length = min(
-            self.max_memory_size, math.ceil(past_length / self.segment_length)
+            self.max_memory_size, math.ceil(past_length / self.chunk_length)
         )
-        pre_memory = state[0][-past_memory_length:]
-        left_context_key = state[1][-past_left_context_length:]
-        left_context_val = state[2][-past_left_context_length:]
+        pre_memory = state[0][self.max_memory_size - past_memory_length:]
+        left_context_key = \
+            state[1][self.left_context_length - past_left_context_length:]
+        left_context_val = \
+            state[2][self.left_context_length - past_left_context_length:]
         return pre_memory, left_context_key, left_context_val
 
     def _pack_state(

@@ -600,9 +555,9 @@ class EmformerLayer(nn.Module):
         new_memory = torch.cat([state[0], memory])
         new_key = torch.cat([state[1], next_key])
         new_val = torch.cat([state[2], next_val])
-        state[0] = new_memory[-self.max_memory_size:]
-        state[1] = new_key[-self.left_context_length:]
-        state[2] = new_val[-self.left_context_length:]
+        state[0] = new_memory[new_memory.size(0) - self.max_memory_size:]
+        state[1] = new_key[new_key.size(0) - self.left_context_length:]
+        state[2] = new_val[new_val.size(0) - self.left_context_length:]
         state[3] = state[3] + update_length
         return state
 

@@ -749,7 +704,8 @@ class EmformerLayer(nn.Module):
           memory (torch.Tensor):
             Memory elements, with shape (M, B, D).
           attention_mask (torch.Tensor):
-            Attention mask for underlying attention module.
+            Attention mask for underlying attention module,
+            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
 
         Returns:
           A tuple containing 3 tensors:

@@ -819,7 +775,7 @@ class EmformerLayer(nn.Module):
           (Tensor, Tensor, List[torch.Tensor], Tensor):
           - output utterance, with shape (U, B, D);
           - output right_context, with shape (R, B, D);
-          - output memory, with shape (M, B, D);
+          - output memory, with shape (1, B, D) or (0, B, D).
           - output state.
         """
         (

@@ -883,15 +839,6 @@ class EmformerEncoder(nn.Module):
         If ``true``, applies tanh to memory elements. (default: ``false``)
       negative_inf (float, optional):
         Value to use for negative infinity in attention weights. (default: -1e8)
-
-    examples:
-    >>> emformer = emformer(512, 8, 2048, 20, 4, right_context_length=1)
-    >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
-    >>> lengths = torch.randint(1, 200, (128,))  # batch
-    >>> output = emformer(input, lengths)
-    >>> input = torch.rand(128, 5, 512)
-    >>> lengths = torch.ones(128) * 5
-    >>> output, lengths, states = emformer.infer(input, lengths, None)
     """
 
     def __init__(
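The doctest-style examples removed above are not replaced elsewhere in emformer.py. For reference, a usage sketch consistent with test_emformer_encoder_forward from the test_emformer.py added in this commit (all sizes are arbitrary, and the constructor arguments are assumed from that test):

# Hedged usage sketch based on test_emformer_encoder_forward in this commit.
import torch
from emformer import EmformerEncoder

B, D, U, R, L = 2, 256, 12, 2, 5
encoder = EmformerEncoder(
    chunk_length=2,
    d_model=D,
    dim_feedforward=1024,
    num_encoder_layers=2,
    left_context_length=L,
    right_context_length=R,
    max_memory_size=5,
)
x = torch.randn(U + R, B, D)       # utterance right-padded with right-context frames
lengths = torch.full((B,), U + R)  # per-utterance lengths, including the right context
output, output_lengths = encoder(x, lengths)
# output has shape (U, B, D); output_lengths equals clamp(lengths - R, min=0)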
@@ -913,7 +860,7 @@ class EmformerEncoder(nn.Module):
         super().__init__()
 
         self.use_memory = max_memory_size > 0
-        self.memory_op = nn.AvgPool1d(
+        self.init_memory_op = nn.AvgPool1d(
             kernel_size=chunk_length,
             stride=chunk_length,
             ceil_mode=True,

@@ -957,7 +904,7 @@ class EmformerEncoder(nn.Module):
             start = (seg_idx + 1) * self.chunk_length
             end = start + self.right_context_length
             right_context_blocks.append(x[start:end])
-        right_context_blocks.append(x[-self.right_context_length:])
+        right_context_blocks.append(x[T - self.right_context_length:])
         return torch.cat(right_context_blocks)
 
     def _gen_attention_mask_col_widths(

@@ -1095,31 +1042,34 @@ class EmformerEncoder(nn.Module):
             with shape (U + right_context_length, B, D).
           lengths (torch.Tensor):
             With shape (B,) and i-th element representing number of valid
-            utterance frames for i-th batch element in x.
-            It is the true lengths without containing the right_context.
+            utterance frames for i-th batch element in x, which contains the
+            right_context at the end.
 
         Returns:
-          (Tensor, Tensor):
+          A tuple of 2 tensors:
           - output utterance frames, with shape (U, B, D).
-          - output lengths, with shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in output frames.
+          - output_lengths, with shape (B,), without containing the
+            right_context at the end.
         """
-        assert x.size(0) == torch.max(lengths).item() + \
-            self.right_context_length
+        # assert x.size(0) == torch.max(lengths).item()
         right_context = self._gen_right_context(x)
-        utterance = x[:-self.right_context_length]
+        utterance = x[:x.size(0) - self.right_context_length]
+        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
-            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
+            self.init_memory_op(
+                utterance.permute(1, 2, 0)
+            ).permute(2, 0, 1)[:-1]
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
         output = utterance
         for layer in self.emformer_layers:
-            output, right_context, memory = \
-                layer(output, lengths, right_context, memory, attention_mask)
+            output, right_context, memory = layer(
+                output, output_lengths, right_context, memory, attention_mask
+            )
 
-        return output, lengths
+        return output, output_lengths
 
     @torch.jit.export
     def infer(

@@ -1137,11 +1087,11 @@ class EmformerEncoder(nn.Module):
         Args:
           x (torch.Tensor):
             Utterance frames right-padded with right context frames,
-            with shape (chunk_length + right_context_length, B, D).
+            with shape (U + right_context_length, B, D).
           lengths (torch.Tensor):
             With shape (B,) and i-th element representing number of valid
-            utterance frames for i-th batch element in x.
-            It contains the right_context.
+            utterance frames for i-th batch element in x, which contains the
+            right_context at the end.
           states (List[List[torch.Tensor]], optional):
             Cached states from proceeding chunk's computation, where each
             element (List[torch.Tensor]) corresponding to each emformer layer.

@@ -1150,8 +1100,8 @@ class EmformerEncoder(nn.Module):
         Returns:
           (Tensor, Tensor, List[List[torch.Tensor]]):
           - output utterance frames, with shape (U, B, D).
-          - output lengths, with shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in output frames.
+          - output lengths, with shape (B,), without containing the
+            right_context at the end.
           - updated states from current chunk's computation.
         """
         assert x.size(0) == self.chunk_length + self.right_context_length, (

@@ -1159,23 +1109,24 @@ class EmformerEncoder(nn.Module):
             f"expected size of {self.chunk_length + self.right_context_length} "
             f"for dimension 1 of x, but got {x.size(1)}."
         )
-        right_context = x[-self.right_context_length:]
-        utterance = x[:-self.right_context_length]
+        right_context_start_idx = x.size(0) - self.right_context_length
+        right_context = x[right_context_start_idx:]
+        utterance = x[:right_context_start_idx]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         memory = (
-            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
         output = utterance
         output_states: List[List[torch.Tensor]] = []
         for layer_idx, layer in enumerate(self.emformer_layers):
-            output, right_context, output_state, memory = layer.infer(
+            output, right_context, memory, output_state = layer.infer(
                 output,
                 output_lengths,
                 right_context,
-                None if states is None else states[layer_idx],
                 memory,
+                None if states is None else states[layer_idx],
             )
             output_states.append(output_state)
 
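For the streaming path, a chunk-by-chunk sketch consistent with test_emformer_encoder_infer from this commit (sizes assumed from that test): infer() expects exactly chunk_length + right_context_length frames per call and threads the cached states through successive calls.

# Hedged streaming sketch mirroring test_emformer_encoder_infer in this commit.
import torch
from emformer import EmformerEncoder

B, D, R, L = 2, 256, 2, 5
chunk_length = 2
encoder = EmformerEncoder(
    chunk_length=chunk_length,
    d_model=D,
    dim_feedforward=1024,
    num_encoder_layers=2,
    left_context_length=L,
    right_context_length=R,
    max_memory_size=3,
)

states = None  # no cached state before the first chunk
for chunk_idx in range(3):
    x = torch.randn(chunk_length + R, B, D)
    lengths = torch.full((B,), chunk_length + R)
    output, output_lengths, states = encoder.infer(x, lengths, states)
    # output has shape (chunk_length, B, D); states carry the memory bank,
    # the left-context key/value caches and the processed-frame count forward.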
@@ -1272,24 +1223,23 @@ class Emformer(EncoderInterface):
             with shape (B, U + right_context_length, D).
           x_lens (torch.Tensor):
             With shape (B,) and i-th element representing number of valid
-            utterance frames for i-th batch element in x.
-            It is the true lengths without containing the right_context.
+            utterance frames for i-th batch element in x, containing the
+            right_context at the end.
 
         Returns:
           (Tensor, Tensor):
           - output logits, with shape (B, U // 4, D).
-          - logits lengths, with shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in output frames.
+          - logits lengths, with shape (B,), without containing the
+            right_context at the end.
         """
+        # TODO: x.shape
         x = self.encoder_embed(x)
         x = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
         lengths = x_lens // 4
-        assert x.size(0) == \
-            lengths.max().item() + self.right_context_length // 4
+        assert x.size(0) == lengths.max().item()
 
         output, output_lengths = self.encoder(x, lengths)  # (T, N, C)
 
         logits = self.encoder_output_layer(output)
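Because Emformer.forward applies a subsampling factor of 4 before running the encoder, the length arithmetic above operates on subsampled frames. With the sizes used in test_emformer_forward below (U = 48 utterance frames plus R = 16 right-context frames at the feature-frame level), lengths = x_lens // 4, the encoder effectively sees a right context of R // 4 = 4 subsampled frames, the logits have shape (B, U // 4, output_dim), and the returned lengths are clamp(x_lens // 4 - R // 4, min=0).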
@@ -1316,8 +1266,8 @@ class Emformer(EncoderInterface):
             with shape (B, U + right_context_length, D).
           lengths (torch.Tensor):
             With shape (B,) and i-th element representing number of valid
-            utterance frames for i-th batch element in x.
-            It is the true lengths without containing the right_context.
+            utterance frames for i-th batch element in x, containing the
+            right_context at the end.
           states (List[List[torch.Tensor]], optional):
             Cached states from proceeding chunk's computation, where each
             element (List[torch.Tensor]) corresponding to each emformer layer.

@@ -1325,8 +1275,8 @@ class Emformer(EncoderInterface):
         Returns:
           (Tensor, Tensor):
           - output logits, with shape (B, U // 4, D).
-          - logits lengths, with shape (B,) and i-th element representing
-            number of valid frames for i-th batch element in output frames.
+          - logits lengths, with shape (B,), without containing the
+            right_context at the end.
           - updated states from current chunk's computation.
         """
         x = self.encoder_embed(x)
emformer_pruned_transducer_stateless/test_emformer.py (new file)

@@ -0,0 +1,345 @@
import torch


def test_emformer_attention_forward():
    from emformer import EmformerAttention

    B, D = 2, 256
    U, R = 12, 2
    chunk_length = 2
    attention = EmformerAttention(embed_dim=D, nhead=8)

    for use_memory in [True, False]:
        if use_memory:
            S = U // chunk_length
            M = S - 1
        else:
            S, M = 0, 0

        Q, KV = R + U + S, M + R + U
        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        summary = torch.randn(S, B, D)
        memory = torch.randn(M, B, D)
        attention_mask = torch.rand(Q, KV) >= 0.5

        output_right_context_utterance, output_memory = attention(
            utterance,
            lengths,
            right_context,
            summary,
            memory,
            attention_mask,
        )
        assert output_right_context_utterance.shape == (R + U, B, D)
        assert output_memory.shape == (M, B, D)


def test_emformer_attention_infer():
    from emformer import EmformerAttention

    B, D = 2, 256
    R, L = 4, 2
    chunk_length = 2
    U = chunk_length
    attention = EmformerAttention(embed_dim=D, nhead=8)

    for use_memory in [True, False]:
        if use_memory:
            S, M = 1, 3
        else:
            S, M = 0, 0

        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        summary = torch.randn(S, B, D)
        memory = torch.randn(M, B, D)
        left_context_key = torch.randn(L, B, D)
        left_context_val = torch.randn(L, B, D)

        output_right_context_utterance, output_memory, next_key, next_val = \
            attention.infer(
                utterance,
                lengths,
                right_context,
                summary,
                memory,
                left_context_key,
                left_context_val,
            )
        assert output_right_context_utterance.shape == (R + U, B, D)
        assert output_memory.shape == (S, B, D)
        assert next_key.shape == (L + U, B, D)
        assert next_val.shape == (L + U, B, D)


def test_emformer_layer_forward():
    from emformer import EmformerLayer

    B, D = 2, 256
    U, R, L = 12, 2, 5
    chunk_length = 2

    for use_memory in [True, False]:
        if use_memory:
            S = U // chunk_length
            M = S - 1
        else:
            S, M = 0, 0

        layer = EmformerLayer(
            d_model=D,
            nhead=8,
            dim_feedforward=1024,
            chunk_length=chunk_length,
            left_context_length=L,
            max_memory_size=M,
        )

        Q, KV = R + U + S, M + R + U
        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        memory = torch.randn(M, B, D)
        attention_mask = torch.rand(Q, KV) >= 0.5

        output_utterance, output_right_context, output_memory = layer(
            utterance,
            lengths,
            right_context,
            memory,
            attention_mask,
        )
        assert output_utterance.shape == (U, B, D)
        assert output_right_context.shape == (R, B, D)
        assert output_memory.shape == (M, B, D)


def test_emformer_layer_infer():
    from emformer import EmformerLayer

    B, D = 2, 256
    R, L = 2, 5
    chunk_length = 2
    U = chunk_length

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0

        layer = EmformerLayer(
            d_model=D,
            nhead=8,
            dim_feedforward=1024,
            chunk_length=chunk_length,
            left_context_length=L,
            max_memory_size=M,
        )

        utterance = torch.randn(U, B, D)
        lengths = torch.randint(1, U + 1, (B,))
        lengths[0] = U
        right_context = torch.randn(R, B, D)
        memory = torch.randn(M, B, D)
        state = None
        output_utterance, output_right_context, output_memory, output_state = \
            layer.infer(
                utterance,
                lengths,
                right_context,
                memory,
                state,
            )
        assert output_utterance.shape == (U, B, D)
        assert output_right_context.shape == (R, B, D)
        if use_memory:
            assert output_memory.shape == (1, B, D)
        else:
            assert output_memory.shape == (0, B, D)
        assert len(output_state) == 4
        assert output_state[0].shape == (M, B, D)
        assert output_state[1].shape == (L, B, D)
        assert output_state[2].shape == (L, B, D)
        assert output_state[3].shape == (1, B)


def test_emformer_encoder_forward():
    from emformer import EmformerEncoder

    B, D = 2, 256
    U, R, L = 12, 2, 5
    chunk_length = 2

    for use_memory in [True, False]:
        if use_memory:
            S = U // chunk_length
            M = S - 1
        else:
            S, M = 0, 0

        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=2,
            left_context_length=L,
            right_context_length=R,
            max_memory_size=M,
        )

        x = torch.randn(U + R, B, D)
        lengths = torch.randint(1, U + R + 1, (B,))
        lengths[0] = U + R

        output, output_lengths = encoder(x, lengths)
        assert output.shape == (U, B, D)
        assert torch.equal(
            output_lengths, torch.clamp(lengths - R, min=0)
        )


def test_emformer_encoder_infer():
    from emformer import EmformerEncoder

    B, D = 2, 256
    R, L = 2, 5
    chunk_length = 2
    U = chunk_length
    num_chunks = 3
    num_encoder_layers = 2

    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0

        encoder = EmformerEncoder(
            chunk_length=chunk_length,
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            left_context_length=L,
            right_context_length=R,
            max_memory_size=M,
        )

        states = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(U + R, B, D)
            lengths = torch.randint(1, U + R + 1, (B,))
            lengths[0] = U + R
            output, output_lengths, states = \
                encoder.infer(x, lengths, states)
            assert output.shape == (U, B, D)
            assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0))
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, B, D)
                assert state[1].shape == (L, B, D)
                assert state[2].shape == (L, B, D)
                assert torch.equal(
                    state[3], (chunk_idx + 1) * U * torch.ones_like(state[3])
                )


def test_emformer_forward():
    from emformer import Emformer
    num_features = 80
    output_dim = 1000
    chunk_length = 16
    L, R = 32, 16
    B, D, U = 2, 256, 48
    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0
        model = Emformer(
            num_features=num_features,
            output_dim=output_dim,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=D,
            left_context_length=L,
            right_context_length=R,
            max_memory_size=M,
            vgg_frontend=False,
        )
        x = torch.randn(B, U + R, num_features)
        x_lens = torch.randint(1, U + R + 1, (B,))
        x_lens[0] = U + R
        logits, output_lengths = model(x, x_lens)
        assert logits.shape == (B, U // 4, output_dim)
        assert torch.equal(
            output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0)
        )


def test_emformer_infer():
    from emformer import Emformer
    num_features = 80
    output_dim = 1000
    chunk_length = 16
    U = chunk_length
    L, R = 32, 16
    B, D = 2, 256
    num_chunks = 3
    num_encoder_layers = 2
    for use_memory in [True, False]:
        if use_memory:
            M = 3
        else:
            M = 0
        model = Emformer(
            num_features=num_features,
            output_dim=output_dim,
            chunk_length=chunk_length,
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
            left_context_length=L,
            right_context_length=R,
            max_memory_size=M,
            vgg_frontend=False,
        )
        states = None
        for chunk_idx in range(num_chunks):
            x = torch.randn(B, U + R, num_features)
            x_lens = torch.randint(1, U + R + 1, (B,))
            x_lens[0] = U + R
            logits, output_lengths, states = \
                model.infer(x, x_lens, states)
            assert logits.shape == (B, U // 4, output_dim)
            assert torch.equal(
                output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0)
            )
            assert len(states) == num_encoder_layers
            for state in states:
                assert len(state) == 4
                assert state[0].shape == (M, B, D)
                assert state[1].shape == (L // 4, B, D)
                assert state[2].shape == (L // 4, B, D)
                assert torch.equal(
                    state[3],
                    (chunk_idx + 1) * U // 4 * torch.ones_like(state[3])
                )


if __name__ == "__main__":
    test_emformer_attention_forward()
    test_emformer_attention_infer()
    test_emformer_layer_forward()
    test_emformer_layer_infer()
    test_emformer_encoder_forward()
    test_emformer_encoder_infer()
    test_emformer_forward()
    test_emformer_infer()
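The new tests are plain functions driven by the __main__ block above, so they can presumably be run directly, e.g. python emformer_pruned_transducer_stateless/test_emformer.py from a directory where emformer.py is importable, or collected automatically by pytest.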