Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)
delete duplicated dropout in emformer attention and update emformer test codes.

commit 4130892971 (parent c2808f8541)
@@ -183,9 +183,9 @@ class EmformerAttention(nn.Module):
         attention_probs = nn.functional.softmax(
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
-        attention_probs = nn.functional.dropout(
-            attention_probs, p=float(self.dropout), training=self.training
-        )
+        # attention_probs = nn.functional.dropout(
+        #     attention_probs, p=float(self.dropout), training=self.training
+        # )
         return attention_probs

     def _forward_impl(
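Note (not part of the commit): the hunk above removes the second dropout on the attention probabilities. Assuming "duplicated" means dropout was already applied once elsewhere on the same attention path, the standalone sketch below illustrates why stacking two dropout calls is undesirable: with probability p each, roughly 1 - (1 - p)^2 of the entries get zeroed and the survivors are rescaled by 1 / (1 - p) twice.

import torch
import torch.nn.functional as F

# Hypothetical illustration, not icefall code: applying F.dropout twice with
# probability p zeroes about 1 - (1 - p) ** 2 of the entries instead of the
# intended p, and rescales the surviving entries twice.
torch.manual_seed(0)
p = 0.1
x = torch.ones(1_000_000)

once = F.dropout(x, p=p, training=True)
twice = F.dropout(once, p=p, training=True)

print((once == 0).float().mean().item())   # ~0.10
print((twice == 0).float().mean().item())  # ~0.19, i.e. 1 - (1 - 0.1) ** 2
print(once.max().item(), twice.max().item())  # ~1.11 vs ~1.23 (rescaled twice)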
@@ -955,16 +955,15 @@ class EmformerEncoder(nn.Module):
     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
        """Hard copy each chunk's right context and concat them."""
         T = x.shape[0]
-        num_segs = math.ceil(
+        num_chunks = math.ceil(
             (T - self.right_context_length) / self.chunk_length
         )
         right_context_blocks = []
-        for seg_idx in range(num_segs - 1):
+        for seg_idx in range(num_chunks - 1):
             start = (seg_idx + 1) * self.chunk_length
             end = start + self.right_context_length
             right_context_blocks.append(x[start:end])
-        last_right_context_start_idx = T - self.right_context_length
-        right_context_blocks.append(x[last_right_context_start_idx:])
+        right_context_blocks.append(x[T - self.right_context_length :])  # noqa
         return torch.cat(right_context_blocks)

     def _gen_attention_mask_col_widths(
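For reference, a toy walk-through of the loop in the hunk above (a standalone sketch with assumed values, not icefall code): with a 14-frame input, chunk_length=4 and right_context_length=2, the hard-copied right-context frames are 4-5, 8-9 and 12-13.

import math
import torch

# Standalone sketch mirroring _gen_right_context on a 1-D toy "input" so the
# selected frame indices are visible directly.
chunk_length = 4
right_context_length = 2
T = 14  # e.g. 12 utterance frames + 2 right-context frames
x = torch.arange(T)

num_chunks = math.ceil((T - right_context_length) / chunk_length)
right_context_blocks = []
for seg_idx in range(num_chunks - 1):
    start = (seg_idx + 1) * chunk_length
    end = start + right_context_length
    right_context_blocks.append(x[start:end])
right_context_blocks.append(x[T - right_context_length:])

print(torch.cat(right_context_blocks))  # tensor([ 4,  5,  8,  9, 12, 13])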
@@ -342,12 +342,218 @@ def test_emformer_infer():
        )


+def test_emformer_attention_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 1
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.0,
+        )
+        encoder_layer = encoder.emformer_layers[0]
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U])
+        right_context = encoder._gen_right_context(x)
+        utterance = x[: x.size(0) - R]
+        attention_mask = encoder._gen_attention_mask(utterance)
+        memory = (
+            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
+            if encoder.use_memory
+            else torch.empty(0).to(dtype=x.dtype, device=x.device)
+        )
+        (
+            forward_output_right_context_utterance,
+            forward_output_memory,
+        ) = encoder_layer._apply_attention_forward(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+        )
+        forward_output_utterance = forward_output_right_context_utterance[
+            right_context.size(0) :  # noqa
+        ]
+
+        state = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx:end_idx]
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_length = torch.tensor([chunk_length])
+            chunk_memory = (
+                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
+                if encoder.use_memory
+                else torch.empty(0).to(dtype=x.dtype, device=x.device)
+            )
+            (
+                infer_output_right_context_utterance,
+                infer_output_memory,
+                state,
+            ) = encoder_layer._apply_attention_infer(
+                chunk,
+                chunk_length,
+                chunk_right_context,
+                chunk_memory,
+                state,
+            )
+            infer_output_utterance = infer_output_right_context_utterance[
+                chunk_right_context.size(0) :  # noqa
+            ]
+            print(
+                infer_output_utterance
+                - forward_output_utterance[start_idx:end_idx]
+            )
+
+
+def test_emformer_layer_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 1
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.0,
+        )
+        encoder_layer = encoder.emformer_layers[0]
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U])
+        right_context = encoder._gen_right_context(x)
+        utterance = x[: x.size(0) - R]
+        attention_mask = encoder._gen_attention_mask(utterance)
+        memory = (
+            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
+            if encoder.use_memory
+            else torch.empty(0).to(dtype=x.dtype, device=x.device)
+        )
+        (
+            forward_output_utterance,
+            forward_output_right_context,
+            forward_output_memory,
+        ) = encoder_layer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+        )
+
+        state = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx:end_idx]
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_length = torch.tensor([chunk_length])
+            chunk_memory = (
+                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
+                if encoder.use_memory
+                else torch.empty(0).to(dtype=x.dtype, device=x.device)
+            )
+            (
+                infer_output_utterance,
+                infer_right_context,
+                infer_output_memory,
+                state,
+            ) = encoder_layer.infer(
+                chunk,
+                chunk_length,
+                chunk_right_context,
+                chunk_memory,
+                state,
+            )
+            print(
+                infer_output_utterance
+                - forward_output_utterance[start_idx:end_idx]
+            )
+
+
+def test_emformer_encoder_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.0,
+        )
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U + R])
+
+        forward_output, forward_output_lengths = encoder(x, lengths)
+
+        states = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx : end_idx + R]  # noqa
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_length = torch.tensor([chunk_length])
+            infer_output, infer_output_lengths, states = encoder.infer(
+                chunk,
+                chunk_length,
+                states,
+            )
+            print(infer_output - forward_output[start_idx:end_idx])
+
+
 if __name__ == "__main__":
-    test_emformer_attention_forward()
-    test_emformer_attention_infer()
-    test_emformer_layer_forward()
-    test_emformer_layer_infer()
-    test_emformer_encoder_forward()
-    test_emformer_encoder_infer()
-    test_emformer_forward()
-    test_emformer_infer()
+    # test_emformer_attention_forward()
+    # test_emformer_attention_infer()
+    # test_emformer_layer_forward()
+    # test_emformer_layer_infer()
+    # test_emformer_encoder_forward()
+    # test_emformer_encoder_infer()
+    # test_emformer_forward()
+    # test_emformer_infer()
+    # test_emformer_attention_forward_infer_consistency()
+    # test_emformer_layer_forward_infer_consistency()
+    test_emformer_encoder_forward_infer_consistency()
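The new consistency tests only print the difference between the streaming (infer) and the non-streaming (forward) outputs for each chunk. A possible tightening, shown here as a hypothetical sketch rather than anything in the commit, is to assert agreement with torch.allclose inside the chunk loop instead of printing:

import torch

# Hypothetical helper, not in the commit: replace the print(...) calls in the
# consistency tests with an assertion that the chunk-wise streaming output
# matches the corresponding slice of the full non-streaming output.
def assert_consistent(
    infer_out: torch.Tensor, forward_out: torch.Tensor, atol: float = 1e-5
) -> None:
    max_diff = (infer_out - forward_out).abs().max().item()
    assert torch.allclose(infer_out, forward_out, atol=atol), (
        f"streaming/non-streaming mismatch, max abs diff = {max_diff}"
    )

# e.g. inside the loop of test_emformer_encoder_forward_infer_consistency:
# assert_consistent(infer_output, forward_output[start_idx:end_idx])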