mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Update utils.py
This commit is contained in:
parent
0f75112385
commit
46c077081b
@ -38,7 +38,7 @@ def default(v, d):
|
|||||||
|
|
||||||
|
|
||||||
def lens_to_mask(
|
def lens_to_mask(
|
||||||
t: int["b"], length: int | None = None # noqa: F722 F821
|
t: int["b"], length: int | None = None # noqa: F722 F821
|
||||||
) -> bool["b n"]: # noqa: F722 F821
|
) -> bool["b n"]: # noqa: F722 F821
|
||||||
if not exists(length):
|
if not exists(length):
|
||||||
length = t.amax()
|
length = t.amax()
|
||||||
@ -48,7 +48,7 @@ def lens_to_mask(
|
|||||||
|
|
||||||
|
|
||||||
def mask_from_start_end_indices(
|
def mask_from_start_end_indices(
|
||||||
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
|
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
|
||||||
):
|
):
|
||||||
max_seq_len = seq_len.max().item()
|
max_seq_len = seq_len.max().item()
|
||||||
seq = torch.arange(max_seq_len, device=start.device).long()
|
seq = torch.arange(max_seq_len, device=start.device).long()
|
||||||
@ -58,7 +58,7 @@ def mask_from_start_end_indices(
|
|||||||
|
|
||||||
|
|
||||||
def mask_from_frac_lengths(
|
def mask_from_frac_lengths(
|
||||||
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
|
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
|
||||||
):
|
):
|
||||||
lengths = (frac_lengths * seq_len).long()
|
lengths = (frac_lengths * seq_len).long()
|
||||||
max_start = seq_len - lengths
|
max_start = seq_len - lengths
|
||||||
@ -71,7 +71,7 @@ def mask_from_frac_lengths(
|
|||||||
|
|
||||||
|
|
||||||
def maybe_masked_mean(
|
def maybe_masked_mean(
|
||||||
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
|
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
|
||||||
) -> float["b d"]: # noqa: F722 F821
|
) -> float["b d"]: # noqa: F722 F821
|
||||||
if not exists(mask):
|
if not exists(mask):
|
||||||
return t.mean(dim=1)
|
return t.mean(dim=1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user