Update utils.py

This commit is contained in:
zr_jin 2025-01-27 16:11:37 +08:00 committed by GitHub
parent 0f75112385
commit 46c077081b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,7 +38,7 @@ def default(v, d):
def lens_to_mask(
t: int["b"], length: int | None = None # noqa: F722 F821
t: int["b"], length: int | None = None # noqa: F722 F821
) -> bool["b n"]: # noqa: F722 F821
if not exists(length):
length = t.amax()
@ -48,7 +48,7 @@ def lens_to_mask(
def mask_from_start_end_indices(
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
@ -58,7 +58,7 @@ def mask_from_start_end_indices(
def mask_from_frac_lengths(
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
@ -71,7 +71,7 @@ def mask_from_frac_lengths(
def maybe_masked_mean(
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
) -> float["b d"]: # noqa: F722 F821
if not exists(mask):
return t.mean(dim=1)