mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 07:34:21 +00:00
update test functions for conv_emformer_transducer/emformer.py
This commit is contained in:
parent
32420cc3e4
commit
df7919f4bf
@ -14,8 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# It is modified based on
|
||||
# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
|
||||
# It is modified based on https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py. # noqa
|
||||
|
||||
import math
|
||||
import warnings
|
||||
@ -56,8 +55,6 @@ class EmformerAttention(nn.Module):
|
||||
Embedding dimension.
|
||||
nhead (int):
|
||||
Number of attention heads in each Emformer layer.
|
||||
dropout (float, optional):
|
||||
Dropout probability. (Default: 0.0)
|
||||
tanh_on_mem (bool, optional):
|
||||
If ``True``, applies tanh to memory elements. (Default: ``False``)
|
||||
negative_inf (float, optional):
|
||||
@ -68,7 +65,6 @@ class EmformerAttention(nn.Module):
|
||||
self,
|
||||
embed_dim: int,
|
||||
nhead: int,
|
||||
dropout: float = 0.0,
|
||||
tanh_on_mem: bool = False,
|
||||
negative_inf: float = -1e8,
|
||||
):
|
||||
@ -82,7 +78,6 @@ class EmformerAttention(nn.Module):
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.nhead = nhead
|
||||
self.dropout = dropout
|
||||
self.tanh_on_mem = tanh_on_mem
|
||||
self.negative_inf = negative_inf
|
||||
|
||||
@ -154,9 +149,7 @@ class EmformerAttention(nn.Module):
|
||||
attention_probs = nn.functional.softmax(
|
||||
attention_weights_float, dim=-1
|
||||
).type_as(attention_weights)
|
||||
attention_probs = nn.functional.dropout(
|
||||
attention_probs, p=float(self.dropout), training=self.training
|
||||
)
|
||||
|
||||
return attention_probs
|
||||
|
||||
def _forward_impl(
|
||||
@ -481,7 +474,6 @@ class EmformerLayer(nn.Module):
|
||||
self.attention = EmformerAttention(
|
||||
embed_dim=d_model,
|
||||
nhead=nhead,
|
||||
dropout=0.0,
|
||||
tanh_on_mem=tanh_on_mem,
|
||||
negative_inf=negative_inf,
|
||||
)
|
||||
|
@ -366,6 +366,216 @@ def test_emformer_infer():
|
||||
assert conv_cache.shape == (B, D, K - 1)
|
||||
|
||||
|
||||
def test_emformer_encoder_layer_forward_infer_consistency():
|
||||
from emformer import EmformerEncoder
|
||||
|
||||
chunk_length = 4
|
||||
num_chunks = 3
|
||||
U = chunk_length * num_chunks
|
||||
L, R = 1, 2
|
||||
D = 256
|
||||
num_encoder_layers = 1
|
||||
memory_sizes = [0, 3]
|
||||
K = 3
|
||||
|
||||
for M in memory_sizes:
|
||||
encoder = EmformerEncoder(
|
||||
chunk_length=chunk_length,
|
||||
d_model=D,
|
||||
dim_feedforward=1024,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
left_context_length=L,
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
dropout=0.1,
|
||||
cnn_module_kernel=K,
|
||||
causal=True,
|
||||
)
|
||||
encoder.eval()
|
||||
encoder_layer = encoder.emformer_layers[0]
|
||||
|
||||
x = torch.randn(U + R, 1, D)
|
||||
lengths = torch.tensor([U])
|
||||
right_context = encoder._gen_right_context(x)
|
||||
utterance = x[: x.size(0) - R]
|
||||
attention_mask = encoder._gen_attention_mask(utterance)
|
||||
memory = (
|
||||
encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
|
||||
:-1
|
||||
]
|
||||
if encoder.use_memory
|
||||
else torch.empty(0).to(dtype=x.dtype, device=x.device)
|
||||
)
|
||||
(
|
||||
forward_output_utterance,
|
||||
forward_output_right_context,
|
||||
forward_output_memory,
|
||||
) = encoder_layer(
|
||||
utterance,
|
||||
lengths,
|
||||
right_context,
|
||||
memory,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
state = None
|
||||
conv_cache = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
start_idx = chunk_idx * chunk_length
|
||||
end_idx = start_idx + chunk_length
|
||||
chunk = x[start_idx:end_idx]
|
||||
chunk_right_context = x[end_idx : end_idx + R] # noqa
|
||||
chunk_length = torch.tensor([chunk_length])
|
||||
chunk_memory = (
|
||||
encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
|
||||
if encoder.use_memory
|
||||
else torch.empty(0).to(dtype=x.dtype, device=x.device)
|
||||
)
|
||||
(
|
||||
infer_output_chunk,
|
||||
infer_right_context,
|
||||
infer_output_memory,
|
||||
state,
|
||||
conv_cache,
|
||||
) = encoder_layer.infer(
|
||||
chunk,
|
||||
chunk_length,
|
||||
chunk_right_context,
|
||||
chunk_memory,
|
||||
state,
|
||||
conv_cache,
|
||||
)
|
||||
forward_output_chunk = forward_output_utterance[start_idx:end_idx]
|
||||
assert torch.allclose(
|
||||
infer_output_chunk,
|
||||
forward_output_chunk,
|
||||
atol=1e-5,
|
||||
rtol=0.0,
|
||||
)
|
||||
|
||||
|
||||
def test_emformer_encoder_forward_infer_consistency():
|
||||
from emformer import EmformerEncoder
|
||||
|
||||
chunk_length = 4
|
||||
num_chunks = 3
|
||||
U = chunk_length * num_chunks
|
||||
L, R = 1, 2
|
||||
D = 256
|
||||
num_encoder_layers = 3
|
||||
K = 3
|
||||
memory_sizes = [0, 3]
|
||||
|
||||
for M in memory_sizes:
|
||||
encoder = EmformerEncoder(
|
||||
chunk_length=chunk_length,
|
||||
d_model=D,
|
||||
dim_feedforward=1024,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
left_context_length=L,
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
dropout=0.1,
|
||||
cnn_module_kernel=K,
|
||||
causal=True,
|
||||
)
|
||||
encoder.eval()
|
||||
|
||||
x = torch.randn(U + R, 1, D)
|
||||
lengths = torch.tensor([U + R])
|
||||
|
||||
forward_output, forward_output_lengths = encoder(x, lengths)
|
||||
|
||||
states = None
|
||||
conv_caches = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
start_idx = chunk_idx * chunk_length
|
||||
end_idx = start_idx + chunk_length
|
||||
chunk = x[start_idx : end_idx + R] # noqa
|
||||
chunk_right_context = x[end_idx : end_idx + R] # noqa
|
||||
chunk_length = torch.tensor([chunk_length])
|
||||
(
|
||||
infer_output_chunk,
|
||||
infer_output_lengths,
|
||||
states,
|
||||
conv_caches,
|
||||
) = encoder.infer(
|
||||
chunk,
|
||||
chunk_length,
|
||||
states,
|
||||
conv_caches,
|
||||
)
|
||||
forward_output_chunk = forward_output[start_idx:end_idx]
|
||||
assert torch.allclose(
|
||||
infer_output_chunk,
|
||||
forward_output_chunk,
|
||||
atol=1e-5,
|
||||
rtol=0.0,
|
||||
)
|
||||
|
||||
|
||||
def test_emformer_forward_infer_consistency():
|
||||
from emformer import Emformer
|
||||
|
||||
num_features = 80
|
||||
output_dim = 1000
|
||||
chunk_length = 8
|
||||
num_chunks = 3
|
||||
U = chunk_length * num_chunks
|
||||
L, R = 128, 4
|
||||
D = 256
|
||||
num_encoder_layers = 2
|
||||
K = 3
|
||||
memory_sizes = [0, 3]
|
||||
|
||||
for M in memory_sizes:
|
||||
model = Emformer(
|
||||
num_features=num_features,
|
||||
output_dim=output_dim,
|
||||
chunk_length=chunk_length,
|
||||
subsampling_factor=4,
|
||||
d_model=D,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
cnn_module_kernel=K,
|
||||
left_context_length=L,
|
||||
right_context_length=R,
|
||||
max_memory_size=M,
|
||||
dropout=0.1,
|
||||
vgg_frontend=False,
|
||||
causal=True,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
x = torch.randn(1, U + R + 3, num_features)
|
||||
x_lens = torch.tensor([x.size(1)])
|
||||
|
||||
# forward mode
|
||||
forward_logits, _ = model(x, x_lens)
|
||||
|
||||
states = None
|
||||
conv_caches = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
start_idx = chunk_idx * chunk_length
|
||||
end_idx = start_idx + chunk_length
|
||||
chunk = x[:, start_idx : end_idx + R + 3] # noqa
|
||||
lengths = torch.tensor([chunk.size(1)])
|
||||
(
|
||||
infer_chunk_logits,
|
||||
output_lengths,
|
||||
states,
|
||||
conv_caches,
|
||||
) = model.infer(chunk, lengths, states, conv_caches)
|
||||
forward_chunk_logits = forward_logits[
|
||||
:, start_idx // 4 : end_idx // 4 # noqa
|
||||
]
|
||||
assert torch.allclose(
|
||||
infer_chunk_logits,
|
||||
forward_chunk_logits,
|
||||
atol=1e-5,
|
||||
rtol=0.0,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_emformer_attention_forward()
|
||||
test_emformer_attention_infer()
|
||||
@ -375,3 +585,6 @@ if __name__ == "__main__":
|
||||
test_emformer_encoder_infer()
|
||||
test_emformer_forward()
|
||||
test_emformer_infer()
|
||||
test_emformer_encoder_layer_forward_infer_consistency()
|
||||
test_emformer_encoder_forward_infer_consistency()
|
||||
test_emformer_forward_infer_consistency()
|
||||
|
Loading…
x
Reference in New Issue
Block a user