add doc about memory

This commit is contained in:
yaozengwei 2022-05-25 17:09:49 +08:00
parent fb8e926521
commit 435a01dbdf

View File

@@ -730,15 +730,14 @@ class EmformerAttention(nn.Module):
matrix_bd_utterance = self._rel_shift(matrix_bd_utterance) matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
# (B, nhead, U, U + right_context_length) for training and validation mode; # noqa # (B, nhead, U, U + right_context_length) for training and validation mode; # noqa
# (B, nhead, U, tot_left_length + U + right_context_length) for inference mode. # noqa # (B, nhead, U, tot_left_length + U + right_context_length) for inference mode. # noqa
matrix_bd_utterance = matrix_bd_utterance.contiguous().view( matrix_bd_utterance = matrix_bd_utterance.view(B * self.nhead, U, -1)
B * self.nhead, U, -1
)
matrix_bd = torch.zeros_like(matrix_ac) matrix_bd = torch.zeros_like(matrix_ac)
if left_context_key is not None and left_context_val is not None: if left_context_key is not None and left_context_val is not None:
# inference mode # inference mode
# key: [memory, right context, left context, utterance] # key: [memory, right context, left context, utterance]
# for memory # for memory
if M > 0: if M > 0:
# take average over the chunk frames for the memory vector
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d( matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1), matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
kernel_size=(1, self.chunk_length), kernel_size=(1, self.chunk_length),
@@ -758,6 +757,7 @@ class EmformerAttention(nn.Module):
# key: [memory, right context, utterance] # key: [memory, right context, utterance]
# for memory # for memory
if M > 0: if M > 0:
# take average over the chunk frames for the memory vector
matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d( matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
matrix_bd_utterance[:, :, :U].unsqueeze(1), matrix_bd_utterance[:, :, :U].unsqueeze(1),
kernel_size=(1, self.chunk_length), kernel_size=(1, self.chunk_length),