mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Update utils.py
This commit is contained in:
parent
0f75112385
commit
46c077081b
@ -38,7 +38,7 @@ def default(v, d):
|
||||
|
||||
|
||||
def lens_to_mask(
|
||||
t: int["b"], length: int | None = None # noqa: F722 F821
|
||||
t: int["b"], length: int | None = None # noqa: F722 F821
|
||||
) -> bool["b n"]: # noqa: F722 F821
|
||||
if not exists(length):
|
||||
length = t.amax()
|
||||
@ -48,7 +48,7 @@ def lens_to_mask(
|
||||
|
||||
|
||||
def mask_from_start_end_indices(
|
||||
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
|
||||
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
|
||||
):
|
||||
max_seq_len = seq_len.max().item()
|
||||
seq = torch.arange(max_seq_len, device=start.device).long()
|
||||
@ -58,7 +58,7 @@ def mask_from_start_end_indices(
|
||||
|
||||
|
||||
def mask_from_frac_lengths(
|
||||
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
|
||||
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
|
||||
):
|
||||
lengths = (frac_lengths * seq_len).long()
|
||||
max_start = seq_len - lengths
|
||||
@ -71,7 +71,7 @@ def mask_from_frac_lengths(
|
||||
|
||||
|
||||
def maybe_masked_mean(
|
||||
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
|
||||
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
|
||||
) -> float["b d"]: # noqa: F722 F821
|
||||
if not exists(mask):
|
||||
return t.mean(dim=1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user