minor fix of doc of pos_emb

yaozengwei 2022-05-10 12:15:14 +08:00
parent 5de9d0a19a
commit 61ecd3764d
2 changed files with 11 additions and 2 deletions


@@ -480,6 +480,7 @@ def main():
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)


@@ -550,8 +550,8 @@ class EmformerAttention(nn.Module):
Cached attention value of left context from preceding computation,
with shape (L, B, D).
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For infer mode, PE = L+2*U-1.
Position encoding embedding, with shape (PE, D),
where PE = L + 2 * U - 1.
Returns:
A tuple containing 4 tensors:
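The corrected docstring can be checked by a direct counting argument: a query position i lies in the utterance (length U) and a key position j in the concatenation of the cached left context (length L) and the utterance, so the relative distance i - j takes every integer value from -(U - 1) up to L + U - 1, i.e. L + 2 * U - 1 distinct values. A minimal sketch (plain Python, not part of the patch; the function name is made up for illustration):

# Count the distinct relative distances i - j between query positions
# (utterance frames) and key positions (left context + utterance frames).
# L and U follow the naming in the docstring above.
def num_relative_distances(L: int, U: int) -> int:
    query_positions = range(L, L + U)   # utterance frames
    key_positions = range(0, L + U)     # left-context frames + utterance frames
    distances = {i - j for i in query_positions for j in key_positions}
    assert min(distances) == -(U - 1) and max(distances) == L + U - 1
    return len(distances)

assert num_relative_distances(L=32, U=8) == 32 + 2 * 8 - 1  # PE = L + 2*U - 1
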
@@ -1264,6 +1264,10 @@ class EmformerEncoder(nn.Module):
right_context at the end.
"""
U = x.size(0) - self.right_context_length
# for query of [utterance] (i), key-value [utterance] (j),
# the max relative distance i - j is U - 1
# the min relative distance i - j is -(U - 1)
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
right_context = self._gen_right_context(x)
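In this non-streaming forward there is no cached left context in the key, so the bounds in the new comment reduce to distances in [-(U - 1), U - 1], and pos_len = neg_len = U yields 2 * U - 1 position embeddings. As a hedged sketch, a relative positional encoding taking separate pos_len/neg_len arguments could look roughly like the following (the actual RelPositionalEncoding behind encoder_pos in icefall may differ in details such as caching and dtype handling):

import torch

# Hypothetical sketch: one sinusoidal embedding per relative distance,
# ordered from (pos_len - 1) down to -(neg_len - 1), giving
# pos_len + neg_len - 1 rows in total.
def rel_pos_emb(pos_len: int, neg_len: int, d_model: int) -> torch.Tensor:
    distances = torch.arange(pos_len - 1, -neg_len, -1.0).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * (-torch.log(torch.tensor(10000.0)) / d_model)
    )
    pe = torch.zeros(pos_len + neg_len - 1, d_model)
    pe[:, 0::2] = torch.sin(distances * div_term)
    pe[:, 1::2] = torch.cos(distances * div_term)
    return pe  # shape (PE, D), PE = pos_len + neg_len - 1

U, D = 8, 16
assert rel_pos_emb(pos_len=U, neg_len=U, d_model=D).shape == (2 * U - 1, D)
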
@@ -1329,8 +1333,12 @@
f"expected size of {self.chunk_length + self.right_context_length} "
f"for dimension 1 of x, but got {x.size(1)}."
)
pos_len = self.chunk_length + self.left_context_length
neg_len = self.chunk_length
# for query of [utterance] (i), key-value [left_context, utterance] (j),
# the max relative distance i - j is L + U - 1
# the min relative distance i - j is -(U - 1)
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
right_context = x[self.chunk_length :]
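During streaming inference the query is a single chunk of length U = chunk_length, while the key additionally contains L = left_context_length cached frames, so the comment above gives distances in [-(U - 1), L + U - 1] and pos_len + neg_len - 1 = L + 2 * U - 1 embeddings, matching the corrected docstring. A small numeric check (the chunk and context lengths are illustrative values, not taken from the patch):

chunk_length = 32         # U: illustrative value
left_context_length = 64  # L: illustrative value

pos_len = chunk_length + left_context_length  # as in the hunk above
neg_len = chunk_length

# number of relative position embeddings needed by the attention
PE = pos_len + neg_len - 1
assert PE == left_context_length + 2 * chunk_length - 1  # PE = L + 2*U - 1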