Remove unnecessary check

This commit is contained in:
Daniel Povey 2022-04-25 20:37:06 +08:00
parent 7d457a7781
commit f6619a0b20

View File

@ -570,7 +570,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
prod_values, prod_indexes = compute_products(values, indexes)
# combined_values, combined_indexes: (B, K) these are the top-K
# most-probable combinations of (integerized_ probabilities and their
# most-probable combinations of (integerized probabilities and their
# indexes, from largest to smallest probability
combined_values, combined_indexes = compute_k_largest(prod_values, K)
@ -651,8 +651,9 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
samples)
# TODO: could remove the next call
check_shifted_samples(combined_cumsums, delta_P,
shifted_samples, P_sum_product)
if random.random() < 0.01:
check_shifted_samples(combined_cumsums, delta_P,
shifted_samples, P_sum_product)
indexes = get_indexes_for_samples(P, P_cumsum,
P_cumsum_exclusive,