from local

This commit is contained in:
dohe0342 2023-01-24 14:09:46 +09:00
parent 8fab7905b8
commit 52ae2ba762
2 changed files with 17 additions and 0 deletions

View File

@ -0,0 +1,17 @@
import math
import torch.nn.functional as F
def pad_to_multiple(x, multiple, dim=-1, value=0):
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
if x is None:
return None, 0
tsz = x.size(dim)
m = tsz / multiple
remainder = math.ceil(m) * multiple - tsz
if m.is_integer():
return x, 0
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
~