Update utils.py

This commit is contained in:
zr_jin 2025-01-27 16:02:06 +08:00 committed by GitHub
parent d679567814
commit 361f3b2061
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,8 +38,8 @@ def default(v, d):
def lens_to_mask(
t: int["b"], length: int | None = None
) -> bool["b n"]: # noqa: F722 F821
t: int["b"], length: int | None = None # noqa: F722 F821
) -> bool["b n"]:
if not exists(length):
length = t.amax()
@ -48,8 +48,8 @@ def lens_to_mask(
def mask_from_start_end_indices(
seq_len: int["b"], start: int["b"], end: int["b"]
): # noqa: F722 F821
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
start_mask = seq[None, :] >= start[:, None]
@ -58,8 +58,8 @@ def mask_from_start_end_indices(
def mask_from_frac_lengths(
seq_len: int["b"], frac_lengths: float["b"]
): # noqa: F722 F821
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
@ -71,8 +71,8 @@ def mask_from_frac_lengths(
def maybe_masked_mean(
t: float["b n d"], mask: bool["b n"] = None
) -> float["b d"]: # noqa: F722
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
) -> float["b d"]:
if not exists(mask):
return t.mean(dim=1)