diff --git a/egs/LJSpeech/ASR/pruned_transducer_stateless_d2v_v2/utils.py b/egs/LJSpeech/ASR/pruned_transducer_stateless_d2v_v2/utils.py
new file mode 100644
index 000000000..91ffcd85c
--- /dev/null
+++ b/egs/LJSpeech/ASR/pruned_transducer_stateless_d2v_v2/utils.py
@@ -0,0 +1,17 @@
+import math
+import torch.nn.functional as F
+
+
+def pad_to_multiple(x, multiple, dim=-1, value=0):
+    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
+    if x is None:
+        return None, 0
+    tsz = x.size(dim)
+    m = tsz / multiple
+    remainder = math.ceil(m) * multiple - tsz
+    if m.is_integer():
+        return x, 0
+    pad_offset = (0,) * (-1 - dim) * 2
+
+    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
+
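
For reviewers, a minimal usage sketch of the new helper. `pad_to_multiple` right-pads `x` along `dim` with `value` until its size on that axis is a multiple of `multiple`, returning the padded tensor together with the number of padded positions (note that `pad_offset` is built from `-1 - dim`, so `dim` is expected to be a negative index). The shapes and values below are illustrative, not part of the patch:

```python
import torch

from utils import pad_to_multiple  # the helper added by this patch

# A (batch, time, feature) tensor whose time axis is not a multiple of 4.
x = torch.randn(2, 10, 8)
padded, remainder = pad_to_multiple(x, multiple=4, dim=-2)
print(padded.shape, remainder)  # torch.Size([2, 12, 8]) 2

# Already-aligned inputs come back unchanged with a remainder of 0.
y = torch.randn(2, 12, 8)
_, r = pad_to_multiple(y, multiple=4, dim=-2)
print(r)  # 0
```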