update test functions for conv_emformer_transducer/emformer.py

yaozengwei 2022-04-14 19:16:30 +08:00
parent 32420cc3e4
commit df7919f4bf
2 changed files with 215 additions and 10 deletions


@@ -14,8 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# It is modified based on
-# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
+# It is modified based on https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.  # noqa
 
 import math
 import warnings
@@ -56,8 +55,6 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
-      dropout (float, optional):
-        Dropout probability. (Default: 0.0)
       tanh_on_mem (bool, optional):
         If ``True``, applies tanh to memory elements. (Default: ``False``)
       negative_inf (float, optional):
@@ -68,7 +65,6 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
-        dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -82,7 +78,6 @@ class EmformerAttention(nn.Module):
         self.embed_dim = embed_dim
         self.nhead = nhead
-        self.dropout = dropout
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
@@ -154,9 +149,7 @@ class EmformerAttention(nn.Module):
         attention_probs = nn.functional.softmax(
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
-        attention_probs = nn.functional.dropout(
-            attention_probs, p=float(self.dropout), training=self.training
-        )
         return attention_probs
 
     def _forward_impl(
@@ -481,7 +474,6 @@ class EmformerLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
-            dropout=0.0,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
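
Note on the change above: EmformerLayer always constructed EmformerAttention with dropout=0.0, so the nn.functional.dropout call being deleted was already an identity; dropping the parameter is pure cleanup and cannot change model behavior. A quick standalone check (plain PyTorch, independent of this recipe):

import torch
import torch.nn as nn

scores = torch.randn(2, 4, 4)
probs = nn.functional.softmax(scores, dim=-1)
# With p=0.0 the Bernoulli keep-mask is all ones and the 1/(1 - p)
# rescale is 1.0, so dropout returns its input unchanged even in
# training mode.
assert torch.allclose(
    probs, nn.functional.dropout(probs, p=0.0, training=True)
)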


@@ -366,6 +366,216 @@ def test_emformer_infer():
             assert conv_cache.shape == (B, D, K - 1)
 
 
+def test_emformer_encoder_layer_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 1
+    memory_sizes = [0, 3]
+    K = 3
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            cnn_module_kernel=K,
+            causal=True,
+        )
+        encoder.eval()
+        encoder_layer = encoder.emformer_layers[0]
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U])
+
+        # forward mode over the full utterance
+        right_context = encoder._gen_right_context(x)
+        utterance = x[: x.size(0) - R]
+        attention_mask = encoder._gen_attention_mask(utterance)
+        memory = (
+            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )[:-1]
+            if encoder.use_memory
+            else torch.empty(0).to(dtype=x.dtype, device=x.device)
+        )
+        (
+            forward_output_utterance,
+            forward_output_right_context,
+            forward_output_memory,
+        ) = encoder_layer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+        )
+
+        # streaming inference mode, chunk by chunk
+        state = None
+        conv_cache = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx:end_idx]
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            # use a fresh name here; rebinding chunk_length to a tensor
+            # would break the index arithmetic on the next iteration
+            chunk_lengths = torch.tensor([chunk_length])
+            chunk_memory = (
+                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(
+                    2, 0, 1
+                )
+                if encoder.use_memory
+                else torch.empty(0).to(dtype=x.dtype, device=x.device)
+            )
+            (
+                infer_output_chunk,
+                infer_right_context,
+                infer_output_memory,
+                state,
+                conv_cache,
+            ) = encoder_layer.infer(
+                chunk,
+                chunk_lengths,
+                chunk_right_context,
+                chunk_memory,
+                state,
+                conv_cache,
+            )
+            forward_output_chunk = forward_output_utterance[start_idx:end_idx]
+            assert torch.allclose(
+                infer_output_chunk,
+                forward_output_chunk,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
+def test_emformer_encoder_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 3
+    K = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            cnn_module_kernel=K,
+            causal=True,
+        )
+        encoder.eval()
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U + R])
+
+        # forward mode over the full utterance
+        forward_output, forward_output_lengths = encoder(x, lengths)
+
+        # streaming inference mode; each chunk already carries its R
+        # right-context frames, so no separate right-context slice is needed
+        states = None
+        conv_caches = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx : end_idx + R]  # noqa
+            chunk_lengths = torch.tensor([chunk_length])
+            (
+                infer_output_chunk,
+                infer_output_lengths,
+                states,
+                conv_caches,
+            ) = encoder.infer(
+                chunk,
+                chunk_lengths,
+                states,
+                conv_caches,
+            )
+            forward_output_chunk = forward_output[start_idx:end_idx]
+            assert torch.allclose(
+                infer_output_chunk,
+                forward_output_chunk,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
+def test_emformer_forward_infer_consistency():
+    from emformer import Emformer
+
+    num_features = 80
+    output_dim = 1000
+    chunk_length = 8
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 128, 4
+    D = 256
+    num_encoder_layers = 2
+    K = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        model = Emformer(
+            num_features=num_features,
+            output_dim=output_dim,
+            chunk_length=chunk_length,
+            subsampling_factor=4,
+            d_model=D,
+            num_encoder_layers=num_encoder_layers,
+            cnn_module_kernel=K,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.1,
+            vgg_frontend=False,
+            causal=True,
+        )
+        model.eval()
+
+        x = torch.randn(1, U + R + 3, num_features)
+        x_lens = torch.tensor([x.size(1)])
+
+        # forward mode
+        forward_logits, _ = model(x, x_lens)
+
+        # streaming inference mode; each chunk carries R right-context
+        # frames plus 3 extra frames for the subsampling frontend
+        states = None
+        conv_caches = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
+            lengths = torch.tensor([chunk.size(1)])
+            (
+                infer_chunk_logits,
+                output_lengths,
+                states,
+                conv_caches,
+            ) = model.infer(chunk, lengths, states, conv_caches)
+            # subsampling_factor == 4, so output frames are input // 4
+            forward_chunk_logits = forward_logits[
+                :, start_idx // 4 : end_idx // 4  # noqa
+            ]
+            assert torch.allclose(
+                infer_chunk_logits,
+                forward_chunk_logits,
+                atol=1e-5,
+                rtol=0.0,
+            )
+
+
 if __name__ == "__main__":
     test_emformer_attention_forward()
     test_emformer_attention_infer()
@@ -375,3 +585,6 @@ if __name__ == "__main__":
     test_emformer_encoder_infer()
     test_emformer_forward()
     test_emformer_infer()
+    test_emformer_encoder_layer_forward_infer_consistency()
+    test_emformer_encoder_forward_infer_consistency()
+    test_emformer_forward_infer_consistency()
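
All three new tests follow the same pattern: run the full utterance through forward() once, then replay it chunk by chunk through infer() while threading the attention states and convolution caches between calls, and assert that each streamed chunk matches the corresponding slice of the full-utterance output. Below is a minimal self-contained sketch of that pattern, using a toy causal convolution with an explicit cache rather than the recipe's Emformer (the class and its infer() signature are illustrative, not the recipe's API):

import torch
import torch.nn as nn


class ToyCausalConv(nn.Module):
    """1-D causal convolution whose infer() carries a left-context cache."""

    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(channels, channels, kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time); left-pad so each output frame
        # depends only on current and past input frames.
        x = nn.functional.pad(x, (self.kernel_size - 1, 0))
        return self.conv(x)

    def infer(self, chunk: torch.Tensor, cache=None):
        # chunk: (batch, channels, chunk_time); the cache holds the last
        # kernel_size - 1 input frames of the previous chunk.
        if cache is None:
            cache = chunk.new_zeros(
                chunk.size(0), chunk.size(1), self.kernel_size - 1
            )
        x = torch.cat([cache, chunk], dim=2)
        new_cache = x[:, :, -(self.kernel_size - 1) :]
        return self.conv(x), new_cache


torch.manual_seed(0)
model = ToyCausalConv(channels=8, kernel_size=3).eval()
x = torch.randn(1, 8, 12)
full_output = model(x)

cache = None
chunks = []
for start in range(0, 12, 4):  # three chunks of length 4
    y, cache = model.infer(x[:, :, start : start + 4], cache)
    chunks.append(y)
assert torch.allclose(full_output, torch.cat(chunks, dim=2), atol=1e-5)

The Emformer tests additionally carry per-layer attention states and a memory bank, and they compare with atol=1e-5 and rtol=0.0 since the streamed computation can reorder floating-point operations relative to the full forward pass.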