Remove unnecessary copy

This commit is contained in:
Daniel Povey 2022-04-25 20:41:00 +08:00
parent f6619a0b20
commit e718c7ac88

View File

@ -569,23 +569,22 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
values, indexes = compute_k_largest(P, K)
prod_values, prod_indexes = compute_products(values, indexes)
# combined_values, combined_indexes: (B, K) these are the top-K
# combined_values, combined_indexes: (*, K) these are the top-K
# most-probable combinations of (integerized) probabilities and their
# indexes, from largest to smallest probability
combined_values, combined_indexes = compute_k_largest(prod_values, K)
# let combined_indexes contain the original N-tuples
combined_indexes_shape = list(combined_indexes.shape) + [N]
# combined_indexes: (B, K, N)
# combined_indexes: (*, K, N)
combined_indexes = torch.gather(prod_indexes, dim=-2,
index=combined_indexes.unsqueeze(-1).expand(combined_indexes_shape))
P_cumsum = torch.cumsum(P, dim=-1) # (B, N, M)
P_cumsum_cat = torch.cat((torch.zeros(*P_cumsum.shape[:-1], 1, dtype=P_cumsum.dtype,
device=P_cumsum.device),
P_cumsum), dim=-1)
P_cumsum_cat = torch.zeros(*P.shape[:-1], M+1, dtype=P.dtype,
device=P.device) # # (*, N, M+1)
P_cumsum = torch.cumsum(P, dim=-1, out=P_cumsum_cat[...,1:]) # (*, N, M)
P_cumsum_exclusive = P_cumsum_cat[...,:-1]
P_cumsum = P_cumsum_cat[...,1:]
# P_sum is the total sum of the individual softmaxes/distributions.
# Shape: (*, N)
@ -596,10 +595,10 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
# P_sum_cumprod is the inclusive cumulative product of P_sum, multiplied
# over the N axis.
# Shape: (B,)
# Shape: (*,)
P_sum_cumprod = torch.cumprod(P_sum, dim=-1)
# P_prev_sum_cumprod is the exclusive-product version of P_sum_cumprod, i.e.
# contains the product over previous elements of P_sum. Shape: (B,)
# contains the product over previous elements of P_sum. Shape: (*,)
P_sum_product = P_sum_cumprod[...,-1]
P_prev_sum_cumprod = P_sum_cumprod // P_sum
@ -608,7 +607,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
P_cumsum_exclusive_scaled = P_cumsum_cat_scaled[...,:-1]
P_cumsum_scaled = P_cumsum_cat_scaled[...,1:]
# combined_cumsums: (B, K)
# combined_cumsums: (*, K)
combined_cumsums = get_combined_cumsums(P,
P_cumsum_exclusive_scaled,
combined_indexes)