mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove unused debug statement.
This commit is contained in:
parent
8ccd061051
commit
db543866d8
@ -471,7 +471,7 @@ class Zipformer(EncoderInterface):
|
|||||||
seq_len = x.shape[0]
|
seq_len = x.shape[0]
|
||||||
|
|
||||||
# t is frame index, shape (seq_len,)
|
# t is frame index, shape (seq_len,)
|
||||||
t = torch.arange(seq_len, dtype=torch.int32)
|
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
|
||||||
# c is chunk index for each frame, shape (seq_len,)
|
# c is chunk index for each frame, shape (seq_len,)
|
||||||
c = t // chunk_size
|
c = t // chunk_size
|
||||||
src_c = c
|
src_c = c
|
||||||
@ -479,7 +479,7 @@ class Zipformer(EncoderInterface):
|
|||||||
|
|
||||||
attn_mask = torch.logical_or(src_c > tgt_c,
|
attn_mask = torch.logical_or(src_c > tgt_c,
|
||||||
src_c < tgt_c - left_context_chunks)
|
src_c < tgt_c - left_context_chunks)
|
||||||
if __name__ == "__main__" or random.random() < 0.1:
|
if __name__ == "__main__":
|
||||||
logging.info(f"attn_mask = {attn_mask}")
|
logging.info(f"attn_mask = {attn_mask}")
|
||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
@ -1412,7 +1412,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
assert attn_mask.dtype == torch.bool
|
assert attn_mask.dtype == torch.bool
|
||||||
attn_scores.masked_fill_(attn_mask, float("-inf"))
|
attn_scores = attn_scores.masked_fill(attn_mask, float("-inf"))
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
|
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user