mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add doc about memory
This commit is contained in:
parent
fb8e926521
commit
435a01dbdf
@ -730,15 +730,14 @@ class EmformerAttention(nn.Module):
|
|||||||
matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
|
matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
|
||||||
# (B, nhead, U, U + right_context_length) for training and validation mode; # noqa
|
# (B, nhead, U, U + right_context_length) for training and validation mode; # noqa
|
||||||
# (B, nhead, U, tot_left_length + U + right_context_length) for inference mode. # noqa
|
# (B, nhead, U, tot_left_length + U + right_context_length) for inference mode. # noqa
|
||||||
matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
|
matrix_bd_utterance = matrix_bd_utterance.view(B * self.nhead, U, -1)
|
||||||
B * self.nhead, U, -1
|
|
||||||
)
|
|
||||||
matrix_bd = torch.zeros_like(matrix_ac)
|
matrix_bd = torch.zeros_like(matrix_ac)
|
||||||
if left_context_key is not None and left_context_val is not None:
|
if left_context_key is not None and left_context_val is not None:
|
||||||
# inference mode
|
# inference mode
|
||||||
# key: [memory, right context, left context, utterance]
|
# key: [memory, right context, left context, utterance]
|
||||||
# for memory
|
# for memory
|
||||||
if M > 0:
|
if M > 0:
|
||||||
|
# take average over the chunk frames for the memory vector
|
||||||
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
|
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
|
||||||
matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
|
matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
|
||||||
kernel_size=(1, self.chunk_length),
|
kernel_size=(1, self.chunk_length),
|
||||||
@ -758,6 +757,7 @@ class EmformerAttention(nn.Module):
|
|||||||
# key: [memory, right context, utterance]
|
# key: [memory, right context, utterance]
|
||||||
# for memory
|
# for memory
|
||||||
if M > 0:
|
if M > 0:
|
||||||
|
# take average over the chunk frames for the memory vector
|
||||||
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
|
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
|
||||||
matrix_bd_utterance[:, :, :U].unsqueeze(1),
|
matrix_bd_utterance[:, :, :U].unsqueeze(1),
|
||||||
kernel_size=(1, self.chunk_length),
|
kernel_size=(1, self.chunk_length),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user