mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 06:34:20 +00:00)
Remove unnecessary copy
This commit is contained in:
parent f6619a0b20
commit e718c7ac88
@@ -569,23 +569,22 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     values, indexes = compute_k_largest(P, K)
     prod_values, prod_indexes = compute_products(values, indexes)
 
-    # combined_values, combined_indexes: (B, K) these are the top-K
+    # combined_values, combined_indexes: (*, K) these are the top-K
     # most-probable combinations of (integerized probabilities and their
     # indexes, from largest to smallest probability
     combined_values, combined_indexes = compute_k_largest(prod_values, K)
 
     # let combined_indexes contain the original N-tuples
     combined_indexes_shape = list(combined_indexes.shape) + [N]
-    # combined_indexes: (B, K, N)
+    # combined_indexes: (*, K, N)
     combined_indexes = torch.gather(prod_indexes, dim=-2,
                                     index=combined_indexes.unsqueeze(-1).expand(combined_indexes_shape))
 
-    P_cumsum = torch.cumsum(P, dim=-1)  # (B, N, M)
-    P_cumsum_cat = torch.cat((torch.zeros(*P_cumsum.shape[:-1], 1, dtype=P_cumsum.dtype,
-                                          device=P_cumsum.device),
-                              P_cumsum), dim=-1)
+    P_cumsum_cat = torch.zeros(*P.shape[:-1], M+1, dtype=P.dtype,
+                               device=P.device)  # (*, N, M+1)
+    P_cumsum = torch.cumsum(P, dim=-1, out=P_cumsum_cat[...,1:])  # (*, N, M)
     P_cumsum_exclusive = P_cumsum_cat[...,:-1]
     P_cumsum = P_cumsum_cat[...,1:]
 
     # P_sum is the total sum of the individual softmaxes/distributions.
     # Shape: (*, N)
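The torch.cat version removed above allocates the cumulative sum and then copies it next to a zeros column; the replacement writes the cumsum straight into a preallocated (*, N, M+1) buffer, so the inclusive and exclusive cumsums become two views of the same storage. A minimal standalone sketch of that pattern (illustrative shapes and variable names, not icefall's actual helpers):

import torch

B, N, M = 2, 3, 5                                    # illustrative sizes
P = torch.randint(1, 10, (B, N, M))                  # stand-in for integerized probs

# old pattern: cumsum, then cat a zeros column in front (extra allocation + copy)
P_cumsum_old = torch.cumsum(P, dim=-1)               # (B, N, M)
P_cumsum_cat_old = torch.cat(
    (torch.zeros(B, N, 1, dtype=P.dtype, device=P.device), P_cumsum_old),
    dim=-1)                                          # (B, N, M+1)

# new pattern: one zero-initialized buffer; cumsum writes into its tail view
P_cumsum_cat = torch.zeros(B, N, M + 1, dtype=P.dtype, device=P.device)
torch.cumsum(P, dim=-1, out=P_cumsum_cat[..., 1:])
P_cumsum_exclusive = P_cumsum_cat[..., :-1]          # view: 0, p0, p0+p1, ...
P_cumsum = P_cumsum_cat[..., 1:]                     # view: p0, p0+p1, ..., total

assert torch.equal(P_cumsum_cat, P_cumsum_cat_old)   # same values, one less copy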
@@ -596,10 +595,10 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
 
     # P_sum_product is the inclusive cumulative product of P_sum, multiplied
     # over the N axis.
-    # Shape: (B,)
+    # Shape: (*,)
     P_sum_cumprod = torch.cumprod(P_sum, dim=-1)
     # P_prev_sum_cumprod is the exclusive-product version of P_sum_cumprod, i.e.
-    # contains the product over previous elements of P_sum. Shape: (B,)
+    # contains the product over previous elements of P_sum. Shape: (*,)
     P_sum_product = P_sum_cumprod[...,-1]
     P_prev_sum_cumprod = P_sum_cumprod // P_sum
 
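P_sum here appears to hold positive integers (the surrounding comments speak of integerized probabilities), so the exclusive cumulative product can be recovered from the inclusive one by integer division, which is what P_sum_cumprod // P_sum does in this hunk. A tiny self-contained sketch with made-up values:

import torch

P_sum = torch.tensor([[2, 3, 4],
                      [5, 1, 2]])                # (*, N), positive integers
P_sum_cumprod = torch.cumprod(P_sum, dim=-1)     # inclusive: [[2, 6, 24], [5, 5, 10]]
P_prev_sum_cumprod = P_sum_cumprod // P_sum      # exclusive: [[1, 2, 6], [1, 5, 5]]
P_sum_product = P_sum_cumprod[..., -1]           # product over the N axis: [24, 10]

assert torch.equal(P_prev_sum_cumprod, torch.tensor([[1, 2, 6], [1, 5, 5]]))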
@@ -608,7 +607,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     P_cumsum_exclusive_scaled = P_cumsum_cat_scaled[...,:-1]
     P_cumsum_scaled = P_cumsum_cat_scaled[...,1:]
 
-    # combined_cumsums: (B, K)
+    # combined_cumsums: (*, K)
     combined_cumsums = get_combined_cumsums(P,
                                             P_cumsum_exclusive_scaled,
                                             combined_indexes)