From e718c7ac8863471c3595db25e1ea1762939cacdc Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 25 Apr 2022 20:41:00 +0800
Subject: [PATCH] Remove unnecessary copy

---
 .../ASR/pruned2_knowledge/sampling.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index cf5b09edc..b9c2703b4 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -569,23 +569,22 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     values, indexes = compute_k_largest(P, K)
     prod_values, prod_indexes = compute_products(values, indexes)
 
-    # combined_values, combined_indexes: (B, K) these are the top-K
+    # combined_values, combined_indexes: (*, K) these are the top-K
     # most-probable combinations of (integerized probabilities and their
     # indexes), from largest to smallest probability
     combined_values, combined_indexes = compute_k_largest(prod_values, K)
 
     # let combined_indexes contain the original N-tuples
     combined_indexes_shape = list(combined_indexes.shape) + [N]
-    # combined_indexes: (B, K, N)
+    # combined_indexes: (*, K, N)
     combined_indexes = torch.gather(prod_indexes, dim=-2,
                                     index=combined_indexes.unsqueeze(-1).expand(combined_indexes_shape))
 
-    P_cumsum = torch.cumsum(P, dim=-1) # (B, N, M)
-    P_cumsum_cat = torch.cat((torch.zeros(*P_cumsum.shape[:-1], 1, dtype=P_cumsum.dtype,
-                                          device=P_cumsum.device),
-                              P_cumsum), dim=-1)
+
+    P_cumsum_cat = torch.zeros(*P.shape[:-1], M+1, dtype=P.dtype,
+                               device=P.device)  # (*, N, M+1)
+    P_cumsum = torch.cumsum(P, dim=-1, out=P_cumsum_cat[...,1:])  # (*, N, M)
     P_cumsum_exclusive = P_cumsum_cat[...,:-1]
-    P_cumsum = P_cumsum_cat[...,1:]
 
     # P_sum is the total sum of the individual softmaxes/distributions.
     # Shape: (*, N)
@@ -596,10 +595,10 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
 
     # P_sum_product is the inclusive cumulative product of P_sum, multiplied
     # over the N axis.
-    # Shape: (B,)
+    # Shape: (*,)
     P_sum_cumprod = torch.cumprod(P_sum, dim=-1)
     # P_prev_sum_cumprod is the exclusive-product version of P_sum_cumprod, i.e.
-    # contains the product over previous elements of P_sum. Shape: (B,)
+    # contains the product over previous elements of P_sum. Shape: (*,)
     P_sum_product = P_sum_cumprod[...,-1]
     P_prev_sum_cumprod = P_sum_cumprod // P_sum
 
@@ -608,7 +607,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
     P_cumsum_exclusive_scaled = P_cumsum_cat_scaled[...,:-1]
     P_cumsum_scaled = P_cumsum_cat_scaled[...,1:]
 
-    # combined_cumsums: (B, K)
+    # combined_cumsums: (*, K)
     combined_cumsums = get_combined_cumsums(P,
                                             P_cumsum_exclusive_scaled,
                                             combined_indexes)
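
A quick illustration of the change (a minimal standalone sketch, not part of the commit): the old code materialized the cumulative sum and then copied it a second time when concatenating a zero column onto it; the new code preallocates the zero-filled (*, N, M+1) buffer once and lets torch.cumsum write straight into its last M columns via its out= argument, avoiding the separate cat() copy. The shapes B, N, M and the integer test tensor below are arbitrary stand-ins, not values from sampling.py:

import torch

B, N, M = 2, 3, 4
P = torch.randint(1, 10, (B, N, M))  # stand-in for the integerized probabilities

# Old approach: cumsum allocates one tensor, cat() copies it into another.
old = torch.cat((torch.zeros(B, N, 1, dtype=P.dtype),
                 torch.cumsum(P, dim=-1)), dim=-1)

# New approach: a single zero-filled allocation; cumsum writes into a view
# of it, so columns 0 stays zero and columns 1..M hold the inclusive cumsum.
new = torch.zeros(B, N, M + 1, dtype=P.dtype)
torch.cumsum(P, dim=-1, out=new[..., 1:])

assert torch.equal(old, new)
# new[..., :-1] is the exclusive cumsum; new[..., 1:] is the inclusive one.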