mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Initial drafts/work on bidirectional conformer
This commit is contained in:
parent
2b0370eb18
commit
cfdfcf657d
File diff suppressed because it is too large
Load Diff
@ -939,7 +939,7 @@ def decoder_padding_mask(
|
|||||||
return ys_mask
|
return ys_mask
|
||||||
|
|
||||||
|
|
||||||
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
|
def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
|
||||||
"""Generate a square mask for the sequence. The masked positions are
|
"""Generate a square mask for the sequence. The masked positions are
|
||||||
filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
||||||
The mask can be used for masked self-attention.
|
The mask can be used for masked self-attention.
|
||||||
@ -956,7 +956,7 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
|
|||||||
Returns:
|
Returns:
|
||||||
A square mask of dimension (sz, sz)
|
A square mask of dimension (sz, sz)
|
||||||
"""
|
"""
|
||||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1)
|
||||||
mask = (
|
mask = (
|
||||||
mask.float()
|
mask.float()
|
||||||
.masked_fill(mask == 0, float("-inf"))
|
.masked_fill(mask == 0, float("-inf"))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user