From a266922678f260816b2a888d6123f1bedbc462fa Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 24 Apr 2022 22:29:11 +0800
Subject: [PATCH] First version of sampling.py, tests run.

---
 .../ASR/pruned2_knowledge/sampling.py         | 1160 +++++++++++++++++
 1 file changed, 1160 insertions(+)
 create mode 100644 egs/librispeech/ASR/pruned2_knowledge/sampling.py

diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
new file mode 100644
index 000000000..85df5a89a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -0,0 +1,1160 @@
+#!/usr/bin/env python3

+# This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py;
+# its git history is there.
+
+import torch
+from torch import Tensor
+from torch import nn
+from typing import Tuple, Optional
+from scaling import ScaledLinear
+
+# The main export of this file is the function sample_combined().
+
+# This file is not part of the implementation; it exists to test that the
+# algorithms described in NOTES.md are correct.
+
+
+
+def compute_k_largest(X, K):
+    """
+    Returns, for each row of X, the values and indexes of the K largest elements,
+    sorted from largest to smallest.
+    Args:
+      X: Tensor of any type, of shape (*, M) with M >= K
+      K: an integer with 0 < K <= M
+    Returns (values, indexes), with
+      values: the K most positive values of each row of X, shape (*, K)
+      indexes: the indexes [on the last axis] of the K most positive values of
+          each row of X, shape (*, K)
+    """
+    values, indexes = torch.sort(X, dim=-1, descending=True)
+    return values[...,:K], indexes[...,:K]
+
+def get_combined_cumsums(P,
+                         P_cumsum_exclusive_scaled,
+                         combined_indexes):
+    """
+    This is a function called while sampling from a distribution that's a product of
+    N categorical distributions, each of size M.
+
+    Args:
+      P: Tensor of int64 of shape (*, N, M), containing the individual integerized
+          probabilities of classes.
+      P_cumsum_exclusive_scaled: scaled exclusive-sum version of P_cumsum (which is
+          the cumulative sum of P along the M dimension), equal to
+             (P_cumsum - P) * prod_prev_totals,
+          where prod_prev_totals is the product of the largest (final) elements of
+          P_cumsum over previous n indexes.
+      combined_indexes: A tensor of int64 of shape (*, K, N), containing the top-K combinations
+          of indexes in {0,1,..,M-1} that have the most probability mass, from greatest to least.
+          We are interested in the (exclusive) cumulative sum at these points, i.e. for each index
+          in `combined_indexes` we are interested in the sum of all prior items.
+
+    Returns:
+      A Tensor of int64 of shape (*, K), containing the cumulative sum of the
+      probability masses of all combinations of indexes preceding each of the
+      ones in `combined_indexes`.
+
+    We assign probability mass to combinations of indexes over the N axes by
+    multiplying these integerized probabilities, and we're interested in the cumulative
+    sum of these products, assuming that index 0 varies fastest.  So the (exclusive)
+    cumsum of the combination with indexes m0, m1, m2, as this function computes it,
+    is given by:
+        P_cumsum_exclusive[..., 0, m0] * P[..., 1, m1] * P[..., 2, m2]
+      + P_cumsum_exclusive[..., 1, m1] * sum0 * P[..., 2, m2]
+      + P_cumsum_exclusive[..., 2, m2] * sum0 * sum1
+    where P_cumsum_exclusive = P_cumsum - P, and sum0 and sum1 are the total sums
+    of P[...,0,:] and P[...,1,:] (the last elements of the corresponding rows of
+    P_cumsum).
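+
+    As a small hand-worked example (illustrative numbers, not from NOTES.md):
+    take N=2, M=2, P[...,0,:] = [3, 1] (so sum0 = 4) and P[...,1,:] = [2, 5].
+    With index 0 varying fastest, the combinations (m0, m1) are ordered
+    (0,0), (1,0), (0,1), (1,1), with masses 6, 2, 15, 5 and exclusive cumsums
+    0, 6, 8, 23; e.g. for (m0, m1) = (1, 1) the formula above gives
+    3*5 + 2*4 = 23.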
+ """ + M = P.shape[-1] + N = P.shape[-2] + K = combined_indexes.shape[-2] + assert combined_indexes.shape[-1] == N + assert combined_indexes.shape[:-2] == P.shape[:-2] + + # ans: shape (*, K) + ans = torch.zeros(*combined_indexes.shape[:-1], dtype=P.dtype, device=P.device) + + # P_cumsum_selected_scaled, of shape (*, N, K), contains the individual looked-up + # exclusive-cumulative-sum values, i.e. the cumulative sum within the + # individual softmax/distribution, of all preceding items; + # these are pre-scaled by the product of total sum [P_sum] over previous + # n indexes. + P_cumsum_selected_scaled = P_cumsum_exclusive_scaled.gather(dim=-1, index=combined_indexes.transpose(-2, -1)) + + # P_selected, of shape (*, N, K) contains the individual probability values + # [corresponding to the indexes we want for the cumulative sum] + P_selected = P.gather(dim=-1, index=combined_indexes.transpose(-2, -1)) + + P_selected_cumprod = torch.cumprod(P_selected, dim=-2) + # P_selected_laterprod, of shape (*, N, K), contains the sum of + # P values for *later* n. + P_selected_laterprod = P_selected_cumprod[...,N-1:N,:] // P_selected_cumprod + + + # answer is sum over the N dimension, multipliying the + # indexes for n>0 by P_prev_sum_product, i.e. the product over previous + # sums. [Earlier indexes are considered to vary fastest, this was easiest + # to implement.] + # Shape: (*, K) + ans = (P_cumsum_selected_scaled * P_selected_laterprod).sum(dim=-2) + return ans + + +def compute_products(values, indexes): + """ + This is intended to be called on the outputs of compute_k_largest(). It computes the + products of different combinations of `values`, as follows: + + values: Tensor of shape (*, N, K) + indexes: Tensor of shape (*, N, K) + + The K refers to the K-best, e.g. K=4, which will have been computed by + compute_k_largest. `values` contains the K largest elements per row, of a source + tensor. We are computing all products of these, over the N axis. + + Returns: (values, indexes), where: + prod_values: Tensor of shape (*, K**N) containing the products of elements of `values`, + treating the dimensions in `*` as batch dimensions and taking products + along the N axis. + prod_indexes: Tensor of shape (*, K**N, N) containing the indexes of the original + elements that we took products of. + """ + assert values.shape == indexes.shape + K = values.shape[-1] + N = values.shape[-2] + + # assume (*) == (B,) and N==3 for example shapes. + # e.g. (B, 1, 1, 1) + unit_shape = list(values.shape[:-2]) + ([1] * N) + # e.g. (B, K, K, K) + full_shape = list(values.shape[:-2]) + ([K] * N) + # e.g. (B, K, K, K, N) + indexes_shape = list(values.shape[:-2]) + ([K] * N) + [N] + + prod_values = 1 + prod_indexes = torch.empty(*indexes_shape, dtype=indexes.dtype, + device=indexes.device) + + + for n in range(N): + shape = list(unit_shape) # copy it + shape[-N + n] = K # e.g. if n==1, shape might be (B, K, 1, 1) + this_values = values.select(dim=-2, index=n).reshape(shape) + this_src_indexes = indexes.select(dim=-2, index=n).reshape(shape) + this_dest_indexes = prod_indexes.select(dim=-1, index=n) # e.g. (B, K, K, K) + + this_dest_indexes[:] = this_src_indexes # will broadcast + prod_values = prod_values * this_values # will broadcast + + + values_shape = list(values.shape[:-2]) + [K**N] + indexes_shape = values_shape + [N] + return prod_values.reshape(values_shape), prod_indexes.reshape(indexes_shape) + + + +def compute_beta(P, K): + """ + See: ComputeBeta function [practical version] in NOTES.md. 
+    Args:
+        P: a tensor of shape (*, M), in practice containing integers in {1,2,..2**31+1},
+           but can be any integers >0 as far as this function is concerned, provided the
+           cumsum does not overflow.
+        K: an integer with 0 < K < M
+    Returns: a tensor of integers B of shape (*, 1) such that:
+        sum(min(P, B)) == K*B
+    [It will subtract a number in {0,1,..K-1} from one element of each row of P
+    to make this sum exact.]
+    """
+    M = P.shape[-1]
+    R, R_indexes = torch.sort(P, dim=-1)  # (*, M)
+    Q = torch.cumsum(R, dim=-1)
+    # Reference pseudocode was:
+    #for k in 0,1,...K-1, in any order:
+    #    # B_k is the value of B if k indexes take the l.h.s. of the "min" expression in min(B, P)
+    #    B_k = (Q[M-1-k] + K - k - 1) / (K - k)   # the "+ K - k - 1" is to ensure we round up
+    #    if (k == 0 or R[M-k] > B_k) and R[M-1-k] <= B_k:
+    #        return B_k
+
+    temp = torch.arange(K+1, dtype=R.dtype, device=R.device)
+    # Kk, of shape (K,), contains [1, 2, ..., K], representing K-k for k = [K-1, K-2, ..., 0]
+    Kk = temp[1:K+1]
+    # Kk1, of shape (K,), contains [0, 1, ..., K-1], representing K-k-1 for k = [K-1, K-2, ..., 0]
+    Kk1 = temp[0:K]
+
+    Q_part = Q[...,M-K:M]   # represents: Q[...,M-1-k] for k = K-1,K-2,...,1,0
+
+    B_k = Q_part // Kk  # shape (*, K)
+    remainder_k = Q_part - (B_k * Kk)   # shape (*, K)
+
+    large_int = (2**32 - 1)
+    R_part1 = torch.cat((R[...,M-K+1:M], torch.full((*R.shape[:-1], 1), large_int)), dim=-1)
+    R_part2 = R[...,M-K:M]
+
+    # is_ok corresponds to: "(k==0 or R[M-k] > B_k) and R[M-1-k] <= B_k" in NOTES.md
+    is_ok = (torch.logical_and(R_part1 > B_k, R_part2 <= B_k))  # shape: (*, K)
+
+    assert torch.all(torch.max(is_ok, dim=-1)[0] == 1)
+    B, indexes = torch.max(B_k * is_ok, dim=-1, keepdim=True)  # shape: (*, 1)
+    remainder = torch.max(remainder_k * is_ok, dim=-1, keepdim=True)[0]  # shape: (*, 1)
+    P_index = torch.gather(R_indexes[...,M-K:M], dim=-1, index=indexes)
+    P_val = torch.gather(P, dim=-1, index=P_index)
+    P_val -= remainder
+    P.scatter_(dim=-1, index=P_index, src=P_val)
+
+    P_min = torch.minimum(P, B)
+
+    P_min_sum = P_min.sum(dim=-1, keepdim=True)
+    assert torch.all(K * B == P_min_sum)
+    return B
+
+def compute_beta_prods(Psum, Ptop):
+    """
+    Version of compute_beta() with a different interface, which is intended to work with
+    products of softmaxes.  We are still assuming an integerized representation.
+
+    Args:
+      Psum: Tensor of shape (*,), treated as the batch dimension, which contains,
+            as torch.int64, the total integerized probability mass taken as a product
+            over the N dimension, e.g. for a tensor of shape (*, N, M) containing
+            integerized probabilities, we'd sum along the M dimension and take the
+            product along the N dimension.
+      Ptop: Tensor of shape (*, K), containing the probabilities for the top-K
+            possible outputs (each possible output is a combination of N indexes in
+            [0..M-1]).  The sum of Ptop must be less than Psum.
+
+    Returns: (B, delta_P)
+       B: Tensor of shape (*,) containing integers B satisfying:
+              sum(min(P, B)) == K*B                 (eqn:b1)
+          ... where conceptually, P is a matrix of shape (*, K**N)
+          that we do not materialize.
+          What this condition amounts to in terms of the args of this function
+          is that:
+              Psum + delta_P.sum(-1) == B*K
+          [Caution: the exact equality in (eqn:b1) is only true
+          once we subtract a small number in [0..K-1] from the next-largest
+          element of P that is not >B, to correct for rounding error;
+          this is accounted for in delta_P.]
+       delta_P: of shape (*, K), this contains the change, if any, that we have
+          to make to the top-K elements of the distribution before sampling.
+          Satisfies delta_P <= 0.  This combines two things: the
+          differences (min(P[i], B) - P[i]); and the values in [-(K-1)..0]
+          that we add to the largest item that's less than P to account
+          for rounding effects.
+    """
+    K = Ptop.shape[-1]
+    assert Psum.shape == Ptop.shape[:-1]
+
+    Ptop_cum = torch.cumsum(Ptop, dim=-1)  # cumsum of Ptop, i.e. inclusive-sum.  Shape (*, K)
+
+    # add a zero first element per row, so Ptop_cum_shift[...,0] is all-zeros and
+    # Ptop_cum_shift[...,1] contains the top-1.  The idea is that
+    # Ptop_cum_shift[...,k] contains the sum of the top k items.
+    Ptop_cum_shift = torch.cat((torch.zeros(*Ptop.shape[:-1], 1, dtype=Ptop.dtype,
+                                            device=Ptop.device),
+                                Ptop_cum[...,:K-1]), dim=-1)
+    # S1[...,k] contains, for each batch element, the sum of all but the k largest
+    # items.  It corresponds to s-1 in the math of NOTES.md, see "ComputeBeta function
+    # [mathematical version]".
+    # Shape is (*, K)
+    S1 = Psum.unsqueeze(-1) - Ptop_cum_shift
+
+    temp = torch.arange(K, -1, -1, device=Psum.device)  # [K, K-1, ..., 0]
+    # Kk, of shape (K,), contains [K, K-1, ..., 1], representing K-k for k = [0, 1, ..., K-1]
+    Kk = temp[0:K]
+    # Kk1, of shape (K,), contains [K-1, K-2, ..., 0], representing K-k-1 for k = [0, 1, ..., K-1]
+    Kk1 = temp[1:K+1]
+
+    # The following corresponds to:
+    #    beta = (1 - s_k) / (K - k)
+    # in NOTES.md.  This is integer division, we are rounding down.
+    # B_k[...,k] is the beta value if k values are >= beta.
+    B_k = S1 // Kk  # shape (*, K)
+    remainder_k = S1 - (B_k * Kk)  # shape (*, K)
+
+    large_int = (2**63 - 1)
+    # Ptop_shifted is Ptop shifted right with a large value put first, i.e.
+    # instead of [top1, top2, top3, top4] we have [inf, top1, top2, top3]
+    Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int,
+                                         dtype=Ptop.dtype, device=Ptop.device),
+                              Ptop[...,:K-1]), dim=-1)
+
+
+    # is_ok corresponds to: "(k==0 or R[M-k] > B_k) and R[M-1-k] <= B_k" in NOTES.md.
+    # It is true only for the "correct" k for each batch element, the one that
+    # corresponds to the number of values greater than B_k.
+    is_ok = (torch.logical_and(Ptop_shifted > B_k, Ptop <= B_k))  # shape: (*, K)
+
+    # `indexes` are the values of k.
+    B, indexes = torch.max(B_k * is_ok, dim=-1)  # shape: (*,)
+
+    delta_P = (torch.minimum(Ptop, B.unsqueeze(-1)) - Ptop) - (remainder_k * is_ok)
+
+    err = Psum + delta_P.sum(dim=-1) - B * K
+    assert torch.all(err == 0)
+    assert torch.all(torch.sum(is_ok, dim=-1) == 1)
+
+    return B, delta_P
+
+def compute_shifted_samples(combined_cumsums_mod: Tensor,
+                            delta_P: Tensor,
+                            samples: Tensor) -> Tensor:
+    """
+    Modifies randomly sampled values by adding values to correct for "disallowed regions",
+    i.e. parts of probability space that we skip because they correspond to a probability
+    mass greater than beta [or because they correspond to small padding for roundoff].
+
+      combined_cumsums_mod: Modified cumulative sums, which when they were "combined_cumsums"
+                 could be thought of as points in probability space, but which when they become
+                 "modified" are reduced to account for "disallowed regions" that
+                 we cannot sample.  The shape is (*, K) where `*` is the batch dimension
+                 and K is the maximum number of "disallowed regions".
+      delta_P: negative values that correspond to the amount of probability mass we
+                 removed for each "disallowed region", i.e. the size of those
+                 regions, as a negative number.  The shape is (*, K).
+      samples: The samples that we have to modify by adding values corresponding to
+                 the widths of the appropriate disallowed regions.  The shape is (*, K);
+                 but note, this K is not the "same K" as in the other args' shapes.
+    Returns: shifted_samples, which will be the same shape as `samples`, but possibly
+                 with larger values, i.e. shifted_samples >= samples
+    """
+    samples = samples.unsqueeze(-1)
+    combined_cumsums_mod = combined_cumsums_mod.unsqueeze(-2)
+    delta_P = delta_P.unsqueeze(-2)
+
+    # is_ge is of shape (*, K, K); it is True if sample k1 is >= combined_cumsum k2,
+    # meaning we need to add the corresponding delta_P.
+    is_ge = (samples >= combined_cumsums_mod)
+
+    shifted_samples = samples - (is_ge * delta_P).sum(dim=-1, keepdim=True)
+    shifted_samples = shifted_samples.squeeze(-1)
+    return shifted_samples
+
+def check_shifted_samples(combined_cumsums: Tensor,
+                          delta_P: Tensor,
+                          shifted_samples: Tensor,
+                          prod_cumsum: Tensor):
+    """
+    Checks samples as modified by `compute_shifted_samples`: specifically, checks
+    that they are not in the "disallowed regions" that we are supposed to skip over.
+
+      combined_cumsums: Cumulative sums which can be thought of as the start of
+                "disallowed regions" in probability space.  Shape is (*, K)
+      delta_P: the negative of the size of "disallowed regions".  Shape is (*, K)
+      shifted_samples: The samples as modified by `compute_shifted_samples`.  None
+                of these should be within the "disallowed regions".  Shape is (*, K);
+                but note, this K does not have a correspondence with the K in the
+                other two args' shapes.
+      prod_cumsum: The product of sums/normalizers of the different softmaxes, of
+                shape (*,); this can be thought of as the total size of the probability
+                space, including "disallowed regions".  This is to check that
+                `shifted_samples` are less than this value.
+    """
+    assert torch.all(torch.logical_and(shifted_samples >= 0,
+                                       shifted_samples < prod_cumsum.unsqueeze(-1)))
+
+    shifted_samples = shifted_samples.unsqueeze(-1)
+    combined_cumsums = combined_cumsums.unsqueeze(-2)
+    delta_P = delta_P.unsqueeze(-2)
+
+    disallowed_regions_start = combined_cumsums
+    disallowed_regions_end = combined_cumsums - delta_P  # delta_P is <= 0.
+
+    # in_disallowed_region is of shape (*, K, K)
+    in_disallowed_region = torch.logical_and(shifted_samples >= disallowed_regions_start,
+                                             shifted_samples < disallowed_regions_end)
+    assert torch.all(torch.logical_not(in_disallowed_region))
+
+
+
+def get_indexes_for_samples(P: Tensor,
+                            P_cumsum: Tensor,
+                            P_cumsum_exclusive: Tensor,
+                            shifted_samples: Tensor) -> Tensor:
+    """
+    From K `shifted_samples`, which are in the joint probability-space of N softmaxes
+    of size M, figure out which sample indexes they correspond to.
+    Args:
+      P: of shape (*, N, M), the original integerized probabilities we
+         are interested in the products over [i.e. over the N dimension],
+         e.g. N=2, M=128.
+      P_cumsum: of shape (*, N, M), this is the (inclusive) cumulative sum of
+         the original integerized probabilities P.  Conceptually, the entire
+         probability space is over all possible products, over the N
+         dimension, of different choices of m, arranged so that m-indexes for
+         the earlier n indexes vary fastest, like [000,100,200,010,110,210, ... ].
+      P_cumsum_exclusive: of shape (*, N, M), the exclusive-sum version of
+         P_cumsum, equivalent to P_cumsum - P.
+      shifted_samples: of shape (*, K), contains the random samples we want
+         to find indexes for; "shifted" means we have skipped over "disallowed regions"
+         corresponding to combinations of indexes that had too much probability mass.
+         Will satisfy:
+             0 <= shifted_samples < P_cumsum[...,-1].prod(dim=-1, keepdim=True)
+    Returns:
+       indexes: of shape (*, K, N), the N-tuples of indexes in {0,1...M-1}
+          corresponding to each of the K samples.
+    """
+
+    # P_sum_cumprod is the cumulative product of the total sum of the original
+    # integerized probabilities P, of shape (*, N)
+    P_sum_cumprod = torch.cumprod(P_cumsum[...,-1], dim=-1)
+    M = P.shape[-1]
+    N = P.shape[-2]
+
+    ans_indexes_shape = list(shifted_samples.shape) + [N]  # (*, K, N)
+    ans_indexes = torch.empty(*ans_indexes_shape, dtype=P.dtype,
+                              device=P.device)
+
+    cur_samples = shifted_samples  # (*, K)
+    for n in range(N-1, -1, -1):  # [N-1, N-2, ..., 0]
+        this_samples = cur_samples  # (*, K)
+        if n > 0:
+            # divide by the total product of probs over *previous* indexes n,
+            # so we can compare directly with P_cumsum.
+            this_samples = this_samples // P_sum_cumprod[...,n-1:n]
+        # right=True means we find
+        # P_cumsum[...,index-1] <= this_samples[...,k] < P_cumsum[...,index],
+        # which is what we want, as opposed to ... < ... <= (i.e. swap < and <=)
+        idx = ans_indexes[...,n] = torch.searchsorted(P_cumsum[...,n,:],  # (*, M)
+                                                      this_samples,  # (*, K)
+                                                      right=True)
+        this_P = torch.gather(P[...,n,:], dim=-1, index=idx)  # shape: (*, K)
+
+        if n == 0:
+            break
+
+        # get the cumsum corresponding to the indexes we just computed; we need
+        # the exclusive-sum here, since we subtract the start of the region
+        # corresponding to this index.
+        cur_cumsum = torch.gather(P_cumsum_exclusive[...,n,:], dim=-1, index=idx)
+        # account for the product of previous dims' total sums...
+        # TODO: multiply P_cumsum by P_sum_cumprod
+        cur_cumsum *= P_sum_cumprod[...,n-1:n]
+        # Get the remainder after subtracting the indexes we just worked out;
+        # this will be used to get previous indexes, i.e. for lower n.
+        remainder = cur_samples - cur_cumsum
+        # Also divide by this_P, since all probability masses corresponding
+        # to this index we just worked out will be scaled by this amount.
+        remainder = remainder // this_P
+        cur_samples = remainder
+
+    return ans_indexes
+
+def get_weights_for_samples(P: Tensor,
+                            P_sum_product: Tensor,
+                            B: Tensor,
+                            indexes: Tensor,
+                            dtype: torch.dtype) -> Tensor:
+    """
+    Return output weights for the K samples we selected for each distribution.
+    The probability of selecting a particular sample whose probability is p_i
+    is min(1, p_i/beta), and the output weight for a sample (if we select it)
+    will be p_i divided by the probability with which we sampled it,
+    i.e. p_i / min(1, p_i/beta) = max(p_i, beta).  P and B are integerized
+    forms of p and beta; we have to divide by P_sum_product to get
+    the actual values.
+
+    Args:
+       P: integerized probabilities for the individual distributions
+          in our product-of-distributions, of shape
+          (*, N, M), where * is the batch dimension(s), N is the
+          number of distributions in the product (e.g. 2 or 3), and
+          M is the size of each distribution (e.g. 128).
+       P_sum_product: of shape (*,), the result of taking the sum of
+          P over the M dimension and then the product over the N
+          dimension.
+       B: of shape (*,), the integerized value of beta
+          (B/P_sum_product == beta).  We sample each item with
+          probability min(1, prob_of_item/beta), with beta
+          chosen such that the sum of those probabilities is
+          exactly K.
+       indexes: the indexes of the chosen samples, of
+          shape (*, K, N).  K is the number of samples;
+          each sample is an N-tuple of indexes.
+       dtype: the desired data-type of the returned probabilities.
+    Returns:
+       the weights for the chosen indexes, of shape (*, K); these
+       will sum to one along the K axis.
+    """
+    if dtype == torch.float16:
+        return get_weights_for_samples(P, P_sum_product,
+                                       B, indexes, torch.float32).to(dtype)
+    assert dtype in [torch.float32, torch.float64]
+
+    # probs: of shape (*, N, K), the integer probabilities for
+    # the individual distributions
+    probs = torch.gather(P, dim=-1, index=indexes.transpose(-2, -1))
+
+    # multiply probs across the N axis to get products of shape (*, K)
+    probs = probs.prod(dim=-2)
+
+    # P_sum_product: (*,)
+    P_sum_product = P_sum_product.to(dtype=dtype)
+    # beta: (*,)
+    beta = B.to(dtype=dtype) / P_sum_product
+    p = probs.to(dtype=dtype) / P_sum_product.unsqueeze(-1)
+    # ans: shape (*, K)
+    ans = torch.maximum(p, beta.unsqueeze(-1))
+    # ans_sum: shape (*,)
+    ans_sum = ans.sum(dim=-1)
+    assert torch.all((ans_sum - 1.0).abs() < 0.01)
+    return ans
+
+
+_max_bits = 54  # used in sample_combined_forward and sample_combined_backward;
+                # see the comment in sample_combined_forward.
+
+def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
+    """
+    Sample from a distribution that is the product of softmaxes.  We will sample
+    K *distinct* samples.  This entails using sampling weights of the form min(1, p/beta)
+    for a computed beta.
+    Args:
+      p: A Tensor of shape (*, N, M): either normalized log-probs (if input_is_log==True),
+         or normalized probabilities; normalized along the M axis.  M must be
+         a power of 2, and N must be in [1,2,3,4].
+      K: An integer, the number of samples required, with 0 < K < M
+      input_is_log: True if p represents normalized log-probs, False if it represents
+         probabilities.
+
+    Returns: (weights, indexes)
+      weights: of shape (*, K), gives the weight associated with each sample,
+         which will equal max(p, beta) for a beta specific to the batch element,
+         i.e. to the product of the distributions (0 < beta <= 1/K).  The
+         weights will sum to 1 along the K axis.
+      indexes: of shape (*, K, N): for each of the K samples from a distribution it contains
+         an N-tuple of indexes saying which combination of indexes from the
+         component distributions were sampled.
+    """
+    p = p.detach()  # call sample_combined() if you need derivatives.
+    N = p.shape[-2]
+    M = p.shape[-1]
+    assert M & (M-1) == 0  # required for the random reordering to work (see
+                           # rand_perm); this ensures odd numbers are
+                           # coprime to M.
+    dtype = p.dtype
+    assert N > 0 and N <= 4
+    # Allocating 54 bits for the product of distributions means that, for instance,
+    # with 3 distributions we can have 18 bits per distribution.  The reason
+    # we don't go closer to 64 is that to choose random numbers we
+    # do: `b = torch.randint((2**63 - 1), B.shape) % B`, and for this to actually
+    # be uniformly distributed we need 2**63 - 1 to be substantially larger than
+    # the total probability mass.  However it's not super-critical that this
+    # gap be very large because in any case we randomize the order of indexes
+    # before the sampling procedure.
+    num_bits_per_sample = _max_bits // N
+
+    if input_is_log:
+        p = p.exp()
+
+    # rand_perm is in {1,3,..M-1}; it is of shape (*, N, 1); we'll
+    # use it to pseudo-randomly reorder each distribution.
+    rand_perm = torch.randint(M//2, p.shape[:-1] + (1,), device=p.device) * 2 + 1
+    # Note: we could implement this more efficiently with a special kernel.
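+    # (Explanatory note, added for the reader: M is a power of 2 and rand_perm
+    # is odd, so rand_perm is coprime to M and i -> (rand_perm * i) % M is a
+    # bijection on {0..M-1}.  That makes the reordering below a true
+    # permutation, which we invert at the end when returning indexes.)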
+    rand_perm_indexes = (rand_perm * torch.arange(M, device=p.device)) % M
+    # reorder the elements of p; we'll correct for the reordering later when
+    # we return indexes.
+    p = torch.gather(p, dim=-1, index=rand_perm_indexes)
+
+
+    # the + 1 is because we need all elements of P to be nonzero (this will avoid
+    # some nasty edge cases)
+    P = (p * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
+    values, indexes = compute_k_largest(P, K)
+    prod_values, prod_indexes = compute_products(values, indexes)
+
+    # combined_values, combined_indexes: (B, K); these are the top-K
+    # most-probable combinations of (integerized) probabilities and their
+    # indexes, from largest to smallest probability
+    combined_values, combined_indexes = compute_k_largest(prod_values, K)
+
+    # let combined_indexes contain the original N-tuples
+    combined_indexes_shape = list(combined_indexes.shape) + [N]
+    # combined_indexes: (B, K, N)
+    combined_indexes = torch.gather(prod_indexes, dim=-2,
+                                    index=combined_indexes.unsqueeze(-1).expand(combined_indexes_shape))
+
+    P_cumsum = torch.cumsum(P, dim=-1)  # (B, N, M)
+    P_cumsum_cat = torch.cat((torch.zeros(*P_cumsum.shape[:-1], 1, dtype=P_cumsum.dtype,
+                                          device=P_cumsum.device),
+                              P_cumsum), dim=-1)
+    P_cumsum_exclusive = P_cumsum_cat[...,:-1]
+    P_cumsum = P_cumsum_cat[...,1:]
+
+    # P_sum is the total sum of the individual softmaxes/distributions.
+    # Shape: (*, N)
+    P_sum = P_cumsum[..., M-1]
+    # P_sum_cumprod is the inclusive cumulative product of P_sum, taken
+    # over the N axis.  Shape: (*, N)
+    P_sum_cumprod = torch.cumprod(P_sum, dim=-1)
+    # P_sum_product is the product of all the P_sum values.  Shape: (B,)
+    P_sum_product = P_sum_cumprod[...,-1]
+    # P_prev_sum_cumprod, of shape (*, N), is the exclusive-product version of
+    # P_sum_cumprod, i.e. it contains the product of the P_sum values over the
+    # *previous* indexes n, i.e. over n_prev < n.  We divide by P_sum to make
+    # it an exclusive, not an inclusive, product.
+    P_prev_sum_cumprod = P_sum_cumprod // P_sum
+
+
+    P_cumsum_cat_scaled = P_cumsum_cat * P_prev_sum_cumprod.unsqueeze(-1)
+    P_cumsum_exclusive_scaled = P_cumsum_cat_scaled[...,:-1]
+    P_cumsum_scaled = P_cumsum_cat_scaled[...,1:]
+
+    # combined_cumsums: (B, K)
+    combined_cumsums = get_combined_cumsums(P,
+                                            P_cumsum_exclusive_scaled,
+                                            combined_indexes)
+
+    B, delta_P = compute_beta_prods(P_sum_product, combined_values)
+
+
+    # reorder combined_cumsums from smallest to largest, which we'll require
+    # when interpolating the "skipped regions" into the random numbers.
+    combined_cumsums, reorder_indexes = torch.sort(combined_cumsums, dim=-1)
+    # also reorder delta_P [so that delta_P and combined_cumsums are reordered
+    # in the same way]
+    delta_P = torch.gather(delta_P, dim=-1, index=reorder_indexes)
+
+    # delta_P_exclusive, of shape (*, K), is the exclusive cumulative sum of
+    # delta_P, containing negative values.
+    delta_P_cumsum = torch.cumsum(delta_P, dim=-1)
+    delta_P_exclusive = delta_P_cumsum - delta_P
+
+    # combined_cumsums_mod is combined_cumsums modified by adding the sum
+    # of previous delta_P's (which will be negative).  This compensates for
+    # the fact that the random numbers in `samples` are in a compressed
+    # space where we "skip over" regions of size -delta_P.
+    #
+    # These are the cutoffs for subtracting the delta_P's
+    # from the sampled values.
+    combined_cumsums_mod = combined_cumsums + delta_P_exclusive
+
+
+    # CAUTION: if the product of sums is too large, these random values
+    # will not be sufficiently random!!  We need to leave some headroom.
+    # rand is random in {0, 1, ..., B-1}
+    rand = torch.randint((2**63 - 1), B.shape) % B
+    # rand, rand + B, rand + 2B, ...., rand + (K-1)B
+    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
+
+    shifted_samples = compute_shifted_samples(combined_cumsums_mod,
+                                              delta_P,
+                                              samples)
+
+    # TODO: could remove the next call
+    check_shifted_samples(combined_cumsums, delta_P,
+                          shifted_samples, P_sum_product)
+
+    indexes = get_indexes_for_samples(P, P_cumsum,
+                                      P_cumsum_exclusive,
+                                      shifted_samples)
+
+    weights = get_weights_for_samples(P, P_sum_product, B, indexes,
+                                      dtype=p.dtype)
+
+    # undo the pseudo-random reordering of each distribution
+    indexes = (indexes * rand_perm.transpose(-2, -1)) % M
+
+    return weights, indexes
+
+def sample_combined_backward(p: Tensor, input_is_log: bool, indexes: Tensor,
+                             weights: Tensor, weights_grad: Tensor) -> Tensor:
+    """
+    Backward for sample_combined(); see sample_combined_forward() for detailed docs on
+    the forward pass.  Notice that we don't use Torch's inbuilt autograd for this;
+    that would not give us the answer we want.
+
+    View the output of the forward pass as a sparse vector q.  You can view the
+    forward pass as implementing: q = z p, where z is a sparse vector whose
+    *expected* value is [1,1,..].  Because the expected value of z does not change
+    with p, we treat z as being independent of p, even though actually
+    the detailed distribution of z does depend on p.  So the backprop in non-log
+    space would just be:
+          p_grad = z * output_grad
+    where z is the sparse vector we multiplied by in the forward pass.  Since
+    we can express z as just q / p, this becomes:
+          p_grad = q / p * output_grad
+    where q is the sparse output of the forward pass.  In log-space, this is just
+    equivalent to log_p_grad = log_output_grad.
+    In non-log space, division by p could lead to infinite output if p is zero;
+    in the forward pass we smoothed p by adding 2**-(num_bits_per_sample), and
+    if you work it out, the backprop rule correcting for this would just become
+          p_grad = q / (p + 2**-(num_bits_per_sample)) * output_grad
+
+    Args:
+       p: the probabilities as used in the forward pass, of shape (*, N, M)
+       input_is_log: if False, p should be probabilities; if True, p should
+            be normalized log-probs, e.g. the output of log_softmax.
+       indexes: the `indexes` output of sample_combined_forward, of shape (*, K, N)
+       weights: the `weights` output of sample_combined_forward, of shape (*, K)
+       weights_grad: the loss-function gradient w.r.t. the output weights, of shape
+            (*, K)
+    """
+    K = weights.shape[-1]
+    N = indexes.shape[-1]
+
+    log_p_grad = torch.zeros_like(p)  # (*, N, M)
+    # log_weights_grad is the derivative w.r.t. log(weights).
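+    # (Chain rule: d(loss)/d(log w) = d(loss)/dw * dw/d(log w) = weights_grad * w,
+    # which is why we multiply by `weights` here.)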
+    log_weights_grad = weights_grad * weights
+    # expanded_log_weights_grad: (*, N, K),
+    # duplicated along the N dimension
+    expanded_log_weights_grad = log_weights_grad.unsqueeze(-2).expand(*weights.shape[:-1],
+                                                                      N, K)
+    log_p_grad.scatter_add_(dim=-1, index=indexes.transpose(-2, -1), src=expanded_log_weights_grad)
+
+    if not input_is_log:
+        if p.dtype == torch.float16:
+            raise ValueError("For float16 input you have to use log-space for input probabilities, "
+                             "i.e. input_is_log=True")
+        num_bits_per_sample = _max_bits // N
+        # 2**-num_bits_per_sample is very small, so don't worry about renormalizing p.
+        # This is just to stop division by zero.
+        p_smoothed = p + (2.0**-num_bits_per_sample)
+        log_p_grad.divide_(p_smoothed)
+    return log_p_grad
+
+class SampleCombinedFunction(torch.autograd.Function):
+    # please see sample_combined() or sample_combined_forward() or
+    # sample_combined_backward() for documentation
+    @staticmethod
+    def forward(ctx, p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
+        with torch.no_grad():
+            weights, indexes = sample_combined_forward(p, K, input_is_log)
+        ctx.save_for_backward(p, indexes, weights)
+        ctx.input_is_log = input_is_log
+        return weights, indexes
+
+    @staticmethod
+    def backward(ctx, weights_grad: Optional[Tensor], indexes_grad: Optional[Tensor]) -> Tuple[Tensor, None, None]:
+        p, indexes, weights = ctx.saved_tensors
+        p_grad = sample_combined_backward(p, ctx.input_is_log, indexes,
+                                          weights, weights_grad)
+        return p_grad, None, None
+
+
+def sample_combined(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
+    """
+    Sample from a distribution that is the product of softmaxes.  We will sample
+    K *distinct* samples.  This entails using sampling weights of the form min(1, p/beta)
+    for a computed beta.
+    Args:
+      p: A Tensor of shape (*, N, M): either normalized log-probs (if input_is_log==True),
+         or normalized probabilities; normalized along the M axis.  M must be
+         a power of 2, and N must be in [1,2,3,4].
+      K: An integer, the number of samples required, with 0 < K < M
+      input_is_log: True if p represents normalized log-probs, False if it represents
+         probabilities.
+
+    Returns: (weights, indexes)
+      weights: of shape (*, K), gives the weight associated with each sample,
+         which will equal max(p, beta) for a beta specific to the batch element,
+         i.e. to the product of the distributions (0 < beta <= 1/K).  The
+         weights will sum to 1 along the K axis.
+      indexes: of shape (*, K, N): for each of the K samples from a distribution it contains
+         an N-tuple of indexes saying which combination of indexes from the
+         component distributions were sampled.
+    """
+    return SampleCombinedFunction.apply(p, K, input_is_log)
+
+
+
+def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
+    """
+    Forward function for soft sampling.
+    Args:
+      p: Tensor of shape (*, M)
+      K: number of samples, 1 <= K < M
+      input_is_log: if true, p must be normalized log-probs (i.e. they sum to one
+          after exp()); if false, p must be probabilities in [0..1] that sum to one.
+    Returns: (indexes, y), where:
+        indexes: shape (*, K), a LongTensor containing elements in [0..M-1], distinct
+           along the K axis
+        y: shape (*, K), a Tensor containing values in [0..1], which sum to 1 along the
+           K axis.
+
+    Search for "def soft_sample" in NOTES.md to understand this.
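+
+    (Mechanism, as a reading aid for the code below: we integerize p, cap each
+    value at beta via min(P, B) so that the capped values sum to exactly K*B,
+    take an inclusive cumsum S, and draw one random offset b in {0..B-1}; index
+    k is selected exactly when (S + b) // B > (S_prev + b) // B, i.e. when its
+    interval crosses a multiple of B, which happens for exactly K indexes.)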
+ """ + if input_is_log: + p = p.exp() + M = p.shape[-1] + assert M & (M-1) == 0 + two31 = 2 ** 31 # TEMP for testing, should be 2**31 + # to(dtype=this rounds toward 0, which is good enough + P = (p*two31 + 1).to(dtype=torch.long) + B = compute_beta(P, K) + beta = B / two31 + t = torch.randint(M//2, p.shape[:-1] + (1,)) # shape: *, 1 + s = t * 2 + 1 + #s = torch.ones_like(t) + + # turns out we don't need inv_s. + inv_s = (s ** (M//2 - 1)) % M + assert torch.all((s * inv_s) % M == 1) # if this fails, check that M is a power of 2 + + # R = pseudo-random re-ordering of p. + R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M)) % M), + B) + # S = inclusive-sum of R + S = torch.cumsum(R, dim=-1) + + # Let b be a random integer drawn uniformly from {0, 1, ..., B-1}. + b = torch.randint((2**63 - 1), B.shape) % B + + S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1), S[...,:-1]), dim=-1) + + k_prev = (S_prev + b) // B + k_cur = (S + b) // B + # if S_prev >= b and k_cur > k_prev:.. don't need S_prev >= b because rounded down. + is_ok = (k_cur > k_prev) + + # sort so the "false" goes first and the "true" goes in last K indexes. + values, indices = is_ok.sort(dim=-1) + i = indices[...,M-K:M] + i = (i * s) % M # Reverse the pseudo-random reordering + y = torch.maximum(torch.gather(p, dim=-1, index=i), beta) + assert torch.all(is_ok.sum(dim=-1) == K) + assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.01) + + + +def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: + std = 0.1 + a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M ** N, D)) + nn.init.uniform_(ans, -a, a) + return ans + +def join_indexes(indexes: Tensor, M: int) -> Tensor: + """ + Combines N-tuples of indexes into single indexes that can be used for + lookup in the knowledge base. Args: + indexes: tensor of torch.int64 of shape (*, K, N), with elements in + {0..M-1} + M: the size of the original softmaxes, is upper bound on elements + in indexes + Returns: + joined_indexes: of shape (*, K), joined_indexes[...,k] equals + joined_indexes[...,0,k] + joined_indexes[...,1,k]*(M**1) ... + joined_indexes[...,1,k]*(M**(N-1))] + """ + N = indexes.shape[-1] + n_powers = M ** torch.arange(N, device=indexes.device) # [ 1, M, ..., M**(N-1) ] + return (indexes * n_powers).sum(dim=-1) + + +def weighted_matrix_lookup(weights: Tensor, + indexes: Tensor, + knowledge_base: Tensor) -> Tensor: + """ + Weighted combination of specified rows of a matrix. + weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. + indexes: Tensor of shape (*, K), with elements in [0..C-1] + knowledge_base: Tensor of shape (C-1, D), whose rows we'll be looking up + Returns: + tensor of shape (*, D), containing weighted sums of rows of + `knowledge_base` + """ + lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) + D = knowledge_base.shape[-1] + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) + assert list(ans.shape) == list(weights.shape[:-2]) + [D] + return ans + + +class WeightedMatrixLookupFunction(torch.autograd.Function): + """ + Weighted matrix lookup, memory efficient version that redoes the computation in the + backward pass... this is not really optimal but the autograd for this operation is + complicated. + + See weighted_matrix_lookup() for documentation. 
+ """ + @staticmethod + def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + ctx.save_for_backward(weights.detach(), indexes.detach(), + knowledge_base.detach()) + return weighted_matrix_lookup(weights, indexes, knowledge_base) + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]: + weights, indexes, knowledge_base = ctx.saved_tensors + weights.requires_grad = True + knowledge_base.requires_grad = True + with torch.enable_grad(): + ans = weighted_matrix_lookup(weights, indexes, knowledge_base) + ans.backward(gradient=ans_grad) + return weights.grad, None, knowledge_base.grad + + +class KnowledgeBaseLookup(nn.Module): + """ + Create knowledge-base lookup module. (The knowledge-base parameter, which is + large, is shared between these modules). + Args: + M: int, softmax size, e.g. in [32..128] + N: int, number of softmaxes, in [2..3] + D: int, embedding dimension in knowledge base, e.g. 256 + K: number of samples (affects speed/accuracy tradeoff), e.g. 16. + embedding_dim: the dimension to project from and to, e.g. the + d_model of the conformer. + """ + def __init__(self, M: int, N: int, D: int, + K: int, embedding_dim: int, + knowledge_base: nn.Parameter): + super(KnowledgeBaseLookup, self).__init__() + self.knowledge_base = knowledge_base # shared! + self.in_proj = ScaledLinear(embedding_dim, M * N) + # initial_scale = 4.0 because the knowlege_base activations are + # quite small -- if we use our optimizer they'll have stddev <= 0.1. + self.out_proj = ScaledLinear(D, embedding_dim, + initial_scale = 10.0) + self.M = M + self.N = N + self.K = K + + def forward(self, x: Tensor) -> Tensor: + """ + Forward function that does knowledge-base lookup. + Args: + x: input, of shape (*, E) where E is embedding_dim + as passed to constructor + y: output of knowledge-base lookup, of shape (*, E) + + # TODO: later we can try multiplying by a projection of x or something like that. + """ + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + weights, indexes, = sample_combined(x, self.K, input_is_log=True) + indexes = join_indexes(indexes, self.M) + x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) + return x + + +def _test_compute_beta(): + # use a small M-- 8 here-- because it's more likely to + # choose k != 0 in compute_beta(), giving a more complete test. + a = torch.randint(low=1, high=65535, size=(9, 16)) + K = 4 + beta = compute_beta(a, K) # it checks its own answer.. + print("beta = ", beta) + + +def _test_soft_sample(): + l = 2 * torch.randn(6, 64) + p = torch.softmax(l, dim=-1) + soft_sample_forward(p, K=4, input_is_log=False) + +def _test_combined(): + N = 2 + K = 4 + M = 8 + + P = ((5 * torch.randn(2, N, M)).softmax(dim=-1) * 16 + 1).to(dtype=torch.int64) + + print("P = ", P) + values, indexes = compute_k_largest(P, K) + print("largest values = ", values) + print("largest indexes = ", indexes) + prod_values, prod_indexes = compute_products(values, indexes) + assert prod_values.shape == prod_indexes.shape[:-1] + print("prod_values = ", prod_values) + print("prod_indexes = ", prod_indexes) + + # combined_values, combined_indexes: (B, K) these are the top-K + # most-probable combinations of (integerized_ probabilities and their + # indexes, from best to worst. 
+    combined_values, combined_indexes = compute_k_largest(prod_values, K)
+
+    combined_indexes_shape = list(combined_indexes.shape) + [N]
+    # combined_indexes: (B, K, N)
+    combined_indexes = torch.gather(prod_indexes, dim=-2,
+                                    index=combined_indexes.unsqueeze(-1).expand(combined_indexes_shape))
+
+    print("combined_values = ", combined_values)
+    print("combined_indexes = ", combined_indexes)
+
+
+    P_cumsum = torch.cumsum(P, dim=-1)  # (B, N, M)
+    P_cumsum_cat = torch.cat((torch.zeros(*P_cumsum.shape[:-1], 1, dtype=P_cumsum.dtype,
+                                          device=P_cumsum.device),
+                              P_cumsum), dim=-1)
+    P_cumsum_exclusive = P_cumsum_cat[...,:-1]
+    P_cumsum = P_cumsum_cat[...,1:]
+
+    # P_sum is the total sum of the individual softmaxes/distributions.
+    # Shape: (*, N)
+    P_sum = P_cumsum[..., M-1]
+    # P_sum_cumprod is the inclusive cumulative product of P_sum, taken
+    # over the N axis.  Shape: (*, N)
+    P_sum_cumprod = torch.cumprod(P_sum, dim=-1)
+    # P_sum_product is the product of all the P_sum values.  Shape: (B,)
+    P_sum_product = P_sum_cumprod[...,-1]
+    print("P_sum_product = ", P_sum_product)
+    # P_prev_sum_cumprod, of shape (*, N), is the exclusive-product version of
+    # P_sum_cumprod, i.e. it contains the product over previous elements of P_sum.
+    P_prev_sum_cumprod = P_sum_cumprod // P_sum
+
+
+    P_cumsum_cat_scaled = P_cumsum_cat * P_prev_sum_cumprod.unsqueeze(-1)
+    P_cumsum_exclusive_scaled = P_cumsum_cat_scaled[...,:-1]
+    P_cumsum_scaled = P_cumsum_cat_scaled[...,1:]
+
+    # combined_cumsums: (B, K)
+    combined_cumsums = get_combined_cumsums(P,
+                                            P_cumsum_exclusive_scaled,
+                                            combined_indexes)
+    print("combined_cumsums = ", combined_cumsums)
+    print("combined_cumsums + combined_values = ", combined_cumsums + combined_values)
+
+
+    assert torch.all(P_sum_product.unsqueeze(-1) > combined_cumsums)
+
+    assert torch.all(P_sum_product.unsqueeze(-1) >= combined_cumsums + combined_values)
+
+    B, delta_P = compute_beta_prods(P_sum_product, combined_values)
+
+    assert torch.all(combined_values + delta_P > 0)
+
+
+    # reorder combined_cumsums from smallest to largest, which we'll require
+    # when interpolating the "skipped regions" into the random numbers.
+    combined_cumsums, reorder_indexes = torch.sort(combined_cumsums, dim=-1)
+    # also reorder delta_P [so that delta_P and combined_cumsums are reordered
+    # in the same way]
+    delta_P = torch.gather(delta_P, dim=-1, index=reorder_indexes)
+
+    print("combined_cumsums, reordered, = ", combined_cumsums)
+    print("delta_P, reordered, = ", delta_P)
+
+    # delta_P_exclusive, of shape (*, K), is the exclusive cumulative sum of
+    # delta_P, containing negative values.
+    delta_P_cumsum = torch.cumsum(delta_P, dim=-1)
+    delta_P_exclusive = delta_P_cumsum - delta_P
+    print("delta_P_exclusive = ", delta_P_exclusive)
+
+    # combined_cumsums_mod is combined_cumsums modified by adding the sum
+    # of previous delta_P's (which will be negative).  This compensates for
+    # the fact that the random numbers in `samples` are in a compressed
+    # space where we "skip over" regions of size -delta_P.
+    #
+    # These are the cutoffs for subtracting the delta_P's
+    # from the sampled values.
+    combined_cumsums_mod = combined_cumsums + delta_P_exclusive
+    print("combined_cumsums_mod = ", combined_cumsums_mod)
+
+
+    # CAUTION: if the product of sums is too large, these random values
+    # will not be sufficiently random!!  We need to leave some headroom.
+    # rand is random in {0, 1, ..., B-1}
+    rand = torch.randint((2**63 - 1), B.shape) % B
+    # rand, rand + B, rand + 2B, ...., rand + (K-1)B
+    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
+    print("rand = ", rand)
+    print("samples = ", samples)
+
+    shifted_samples = compute_shifted_samples(combined_cumsums_mod,
+                                              delta_P,
+                                              samples)
+    print("shifted_samples = ", shifted_samples)
+
+    check_shifted_samples(combined_cumsums,
+                          delta_P,
+                          shifted_samples,
+                          P_sum_product)
+
+    indexes = get_indexes_for_samples(P, P_cumsum,
+                                      P_cumsum_exclusive,
+                                      shifted_samples)
+
+    weights = get_weights_for_samples(P, P_sum_product, B, indexes,
+                                      dtype=torch.float32)
+    print("weights = ", weights)
+
+def _test_sample_combined():
+    for N in [2, 3]:
+        K = 4
+        M = 8
+
+        p = torch.randn(2, N, M).log_softmax(dim=-1)
+
+        print("N = ", N, ", test_combined2: p = ", p.exp())
+        weights, indexes = sample_combined_forward(p, K, True)
+        print("test_combined2: p = ", p.exp())
+        print("weights = ", weights)
+        print("indexes = ", indexes)
+
+        print("test_combined2: p (2nd time) = ", p.exp())
+        p = p.detach()
+        p.requires_grad = True
+        weights, indexes = sample_combined(p, K, True)
+        print("weights2 = ", weights)
+        print("indexes2 = ", indexes)
+
+        weights.sum().backward()
+        print("p grad = ", p.grad)
+
+
+def _test_sample_combined_mean():
+    for N in [2, 3]:
+        K = 4
+        M = 8
+
+        p = torch.randn(2, N, M).log_softmax(dim=-1)
+
+        avg_p = torch.zeros_like(p)
+        num_samples = 1000
+        for _ in range(num_samples):
+
+            # weights: (B, K)
+            # indexes: (B, K, N)
+            weights, indexes = sample_combined_forward(p, K, True)
+
+            sampled_p = torch.zeros_like(p)
+            weights_expanded = weights.unsqueeze(-2).expand(*weights.shape[:-1], N, K)
+            sampled_p.scatter_add_(dim=-1, index=indexes.transpose(-2, -1),
+                                   src=weights_expanded)
+            avg_p += sampled_p * (1.0/num_samples)
+        print("sample_combined_mean(): N = ", N, ", p = ", p.exp())
+        print("avg_p = ", avg_p)
+
+def _test_knowledge_base_lookup():
+    K = 16
+    N = 2
+    M = 128
+    D = 256
+    E = 384
+
+    knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
+    m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
+
+    B = 30
+    T = 4
+    x = torch.randn(B, T, E)
+    x.requires_grad = True
+    y = m(x)
+    assert y.shape == x.shape
+    y.sum().backward()  # make sure backward doesn't crash..
+    print("y = ", y)
+    print("x.grad = ", x.grad)
+    print("knowledge_base.grad norm = ", knowledge_base.grad.norm())
+
+if __name__ == '__main__':
+    _test_sample_combined()
+    _test_sample_combined_mean()
+    _test_combined()
+    _test_compute_beta()
+    _test_soft_sample()
+    _test_knowledge_base_lookup()
+    #test_normalizer()
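+
+# Illustrative usage sketch (commentary only, not exercised by the tests above):
+# a KnowledgeBaseLookup can be dropped in wherever an embedding of dimension
+# embedding_dim flows through a model.  The embedding_dim of 512 below is an
+# arbitrary example value; M, N, D, K match _test_knowledge_base_lookup().
+#
+#   knowledge_base = create_knowledge_base(M=128, N=2, D=256)  # shared nn.Parameter
+#   lookup = KnowledgeBaseLookup(M=128, N=2, D=256, K=16,
+#                                embedding_dim=512,
+#                                knowledge_base=knowledge_base)
+#   y = lookup(torch.randn(30, 4, 512))   # -> shape (30, 4, 512)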