mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Initial drafts/work on bidirectional conformer
This commit is contained in:
parent
2b0370eb18
commit
cfdfcf657d
File diff suppressed because it is too large
Load Diff
@ -939,7 +939,7 @@ def decoder_padding_mask(
|
||||
return ys_mask
|
||||
|
||||
|
||||
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
|
||||
def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
|
||||
"""Generate a square mask for the sequence. The masked positions are
|
||||
filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
||||
The mask can be used for masked self-attention.
|
||||
@ -956,7 +956,7 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
|
||||
Returns:
|
||||
A square mask of dimension (sz, sz)
|
||||
"""
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1)
|
||||
mask = (
|
||||
mask.float()
|
||||
.masked_fill(mask == 0, float("-inf"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user