#!/usr/bin/env python3

# This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py,
# its git history is there.

import torch
from torch import Tensor
from torch import nn
from typing import Tuple, Optional
from scaling import ScaledLinear
import random

# The main export of this file is the function sample_combined().

# This file is not part of the implementation; it exists to test that the
# algorithms described in NOTES.md are correct.


def compute_k_largest(X, K):
    """
    Returns, for each row of X, the values and indexes of the K largest elements,
    sorted from largest to smallest.
    Args:
      X: Tensor of any type, of shape (*, M) with M >= K
      K: an integer with 0 < K <= M
    Returns (values, indexes), with
       values: K most positive values of each row of X, shape (*, K)
       indexes: indexes [on last axis] of the K most positive values of each row of X,
             shape (*, K)
    """
    values, indexes = torch.sort(X, dim=-1, descending=True)
    return values[..., :K], indexes[..., :K]
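

# The helper below is a minimal usage sketch added for illustration; the name
# `_example_compute_k_largest` is hypothetical and nothing in this file calls it.
# It just checks that compute_k_largest() agrees with torch.topk() on one row.
def _example_compute_k_largest():
    X = torch.tensor([[3, 1, 4, 1, 5, 9, 2, 6]])
    values, indexes = compute_k_largest(X, K=3)
    topk_values, topk_indexes = torch.topk(X, k=3, dim=-1)
    assert torch.all(values == topk_values)
    assert torch.all(indexes == topk_indexes)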


def get_combined_cumsums(P,
                         P_cumsum_exclusive_scaled,
                         combined_indexes):
    """
    This is a function called while sampling from a distribution that's a product of
    N categorical distributions each of size M.

    Args:
      P: Tensor of int64 of shape (*, N, M), containing the individual integerized
         probabilities of classes.
      P_cumsum_exclusive_scaled: scaled exclusive-sum version of P_cumsum, which is the
         cumulative sum of P along the M dimension, equal to
            (P_cumsum - P) * prod_prev_totals
         where prod_prev_totals is the product of the largest, final elements of P_cumsum
         over previous n indexes.
      combined_indexes: A tensor of int64 of shape (*, K, N), containing the top-K combinations
         of indexes in {0,1,..,M-1} that have the most probability mass, from greatest to least.
         We are interested in the (exclusive) cumulative sum at these points, i.e. for each index
         in `combined_indexes` we are interested in the sum of all prior items.

    Returns:
      Returns a Tensor of int64 of shape (*, K), containing the cumulative sum of all
      combinations of indexes preceding each of the ones in `combined_indexes`.

      We assign probability mass to combinations of indexes over the N axes by
      multiplying these integerized probabilities, and we're interested in the cumulative
      sum of these products, assuming that index 0 varies fastest.  So the (inclusive) cumsum of the
      combination with indexes m0, m1, m2, would have a value given by:
          P_cumsum[..., 0, m0] + P_cumsum[..., 1, m1] * sum0 + P_cumsum[..., 2, m2] * sum0 * sum1
      where sum0 is the total sum from cumsum0 (its last element), and so on.
    """
    M = P.shape[-1]
    N = P.shape[-2]
    K = combined_indexes.shape[-2]
    assert combined_indexes.shape[-1] == N
    assert combined_indexes.shape[:-2] == P.shape[:-2]

    # ans: shape (*, K)
    ans = torch.zeros(*combined_indexes.shape[:-1], dtype=P.dtype, device=P.device)

    # P_cumsum_selected_scaled, of shape (*, N, K), contains the individual looked-up
    # exclusive-cumulative-sum values, i.e. the cumulative sum within the
    # individual softmax/distribution, of all preceding items;
    # these are pre-scaled by the product of total sum [P_sum] over previous
    # n indexes.
    P_cumsum_selected_scaled = P_cumsum_exclusive_scaled.gather(dim=-1, index=combined_indexes.transpose(-2, -1))

    # P_selected, of shape (*, N, K), contains the individual probability values
    # [corresponding to the indexes we want for the cumulative sum]
    P_selected = P.gather(dim=-1, index=combined_indexes.transpose(-2, -1))

    P_selected_cumprod = torch.cumprod(P_selected, dim=-2)
    # P_selected_laterprod, of shape (*, N, K), contains the product of
    # P values for *later* n.
    P_selected_laterprod = P_selected_cumprod[..., N-1:N, :] // P_selected_cumprod

    # answer is sum over the N dimension, multiplying the
    # indexes for n>0 by P_prev_sum_product, i.e. the product over previous
    # sums.  [Earlier indexes are considered to vary fastest; this was easiest
    # to implement.]
    # Shape: (*, K)
    ans = (P_cumsum_selected_scaled * P_selected_laterprod).sum(dim=-2)
    return ans
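

# Brute-force cross-check sketch, added for illustration; the helper name
# `_example_get_combined_cumsums` is hypothetical and is not called anywhere.
# For N=2 it re-derives the exclusive cumulative sum of one combination by
# explicit summation (index 0 varying fastest) and compares with the function.
def _example_get_combined_cumsums():
    torch.manual_seed(0)
    N, M = 2, 4
    P = torch.randint(low=1, high=10, size=(1, N, M), dtype=torch.int64)
    P_cumsum = torch.cumsum(P, dim=-1)
    P_sum = P_cumsum[..., -1]                           # (1, N)
    prev_prod = torch.cumprod(P_sum, dim=-1) // P_sum   # exclusive product over previous n
    P_cumsum_exclusive_scaled = (P_cumsum - P) * prev_prod.unsqueeze(-1)
    combined_indexes = torch.tensor([[[2, 1]]])         # (*, K, N) = (1, 1, 2), arbitrary
    ans = get_combined_cumsums(P, P_cumsum_exclusive_scaled, combined_indexes)
    # sum the probability mass of all combinations (j0, j1) that come before (2, 1)
    expected = 0
    for j1 in range(M):
        for j0 in range(M):
            if (j1, j0) < (1, 2):
                expected += int(P[0, 0, j0]) * int(P[0, 1, j1])
    assert int(ans[0, 0]) == expected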


def compute_products(values, indexes):
    """
    This is intended to be called on the outputs of compute_k_largest().  It computes the
    products of different combinations of `values`, as follows:

     values: Tensor of shape (*, N, K)
     indexes: Tensor of shape (*, N, K)

    The K refers to the K-best, e.g. K=4, which will have been computed by
    compute_k_largest.  `values` contains the K largest elements per row, of a source
    tensor.  We are computing all products of these, over the N axis.

    Returns: (prod_values, prod_indexes), where:
       prod_values: Tensor of shape (*, K**N) containing the products of elements of `values`,
                  treating the dimensions in `*` as batch dimensions and taking products
                  along the N axis.
       prod_indexes: Tensor of shape (*, K**N, N) containing the indexes of the original
                  elements that we took products of.
    """
    assert values.shape == indexes.shape
    K = values.shape[-1]
    N = values.shape[-2]

    # assume (*) == (B,) and N==3 for example shapes.
    # e.g. (B, 1, 1, 1)
    unit_shape = list(values.shape[:-2]) + ([1] * N)
    # e.g. (B, K, K, K)
    full_shape = list(values.shape[:-2]) + ([K] * N)
    # e.g. (B, K, K, K, N)
    indexes_shape = list(values.shape[:-2]) + ([K] * N) + [N]

    prod_values = 1
    prod_indexes = torch.empty(*indexes_shape, dtype=indexes.dtype,
                               device=indexes.device)

    for n in range(N):
        shape = list(unit_shape)  # copy it
        shape[-N + n] = K  # e.g. if n==1, shape might be (B, 1, K, 1)
        this_values = values.select(dim=-2, index=n).reshape(shape)
        this_src_indexes = indexes.select(dim=-2, index=n).reshape(shape)
        this_dest_indexes = prod_indexes.select(dim=-1, index=n)  # e.g. (B, K, K, K)

        this_dest_indexes[:] = this_src_indexes  # will broadcast
        prod_values = prod_values * this_values  # will broadcast

    values_shape = list(values.shape[:-2]) + [K**N]
    indexes_shape = values_shape + [N]
    return prod_values.reshape(values_shape), prod_indexes.reshape(indexes_shape)
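

# Cross-check sketch, added for illustration; `_example_compute_products` is a
# hypothetical helper that is not called anywhere.  For N=2, every entry of
# prod_values should be the product of the two source elements named by the
# corresponding row of prod_indexes.
def _example_compute_products():
    torch.manual_seed(0)
    B, N, M, K = 1, 2, 6, 3
    X = torch.randint(low=1, high=10, size=(B, N, M))
    values, indexes = compute_k_largest(X, K)
    prod_values, prod_indexes = compute_products(values, indexes)
    assert prod_values.shape == (B, K * K)
    assert prod_indexes.shape == (B, K * K, N)
    for k in range(K * K):
        i0, i1 = prod_indexes[0, k]
        assert prod_values[0, k] == X[0, 0, i0] * X[0, 1, i1]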


def compute_beta(P, K):
    """
    See: ComputeBeta function [practical version] in NOTES.md.
    Args:
        P: a tensor of shape (*, M), in practice containing integers in {1,2,..2**31+1},
           but can be any integers >0 as far as this function is concerned, provided the
           cumsum does not overflow.
        K: an integer 0 < K < M
    Returns a tensor of integers B of shape (*, 1) such that:
        sum(min(P, B)) == K*B
    [It will subtract a number in {0,1,..K-1} from one element of each row of P
    to make this sum exact.]
    """
    M = P.shape[-1]
    R, R_indexes = torch.sort(P, dim=-1)  # (*, M)
    Q = torch.cumsum(R, dim=-1)
    # Reference pseudocode was:
    # for k in 0,1,...K-1, in any order:
    #    # B_k is the value of B if k indexes take the l.h.s. of the "min" expression in min(B, P)
    #    B_k = (Q[M-1-k] + K - k - 1) / (K - k)  # the "+ K - k - 1" is to ensure we round up
    #    if R[M-1-k] >= B_k and R[M-2-k] <= B_k:
    #        return B_k

    temp = torch.arange(K+1, dtype=R.dtype, device=R.device)
    # Kk, of shape (K,), contains [1, 2, ..., K], representing K-k for k = [K-1, K-2, ..., 0]
    Kk = temp[1:K+1]
    # Kk1, of shape (K,), contains [0, 1, ..., K-1], representing K-k-1 for k = [K-1, K-2, ..., 0]
    Kk1 = temp[0:K]

    Q_part = Q[..., M-K:M]  # represents: Q[...,M-1-k] for k = K-1,K-2,...,1,0

    B_k = Q_part // Kk                 # shape (*, K)
    remainder_k = Q_part - (B_k * Kk)  # shape (*, K)

    large_int = (2**32 - 1)
    R_part1 = torch.cat((R[..., M-K+1:M], torch.full((*R.shape[:-1], 1), large_int,
                                                     device=R.device)), dim=-1)
    R_part2 = R[..., M-K:M]

    # is_ok corresponds to: "(k==0 or R[M-k] > B_k) and R[M-1-k] <= B_k" in NOTES.md
    is_ok = (torch.logical_and(R_part1 > B_k, R_part2 <= B_k))  # shape: (*, K)

    assert torch.all(torch.max(is_ok, dim=-1)[0] == 1)
    B, indexes = torch.max(B_k * is_ok, dim=-1, keepdim=True)  # shape: (*, 1)
    remainder = torch.gather(remainder_k, dim=-1, index=indexes)

    remainder = torch.max(remainder_k * is_ok, dim=-1, keepdim=True)[0]  # shape: (*, 1)
    index = torch.max(R_indexes[..., M-K:M] * is_ok, dim=-1, keepdim=True)[0]
    P_index = torch.gather(R_indexes[..., M-K:M], dim=-1, index=indexes)
    P_val = torch.gather(P, dim=-1, index=P_index)
    P_val -= remainder
    P.scatter_(dim=-1, index=P_index, src=P_val)

    P_min = torch.minimum(P, B)

    P_min_sum = P_min.sum(dim=-1, keepdim=True)
    assert torch.all(K * B == P_min_sum)
    return B
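

# Usage sketch, added for illustration; `_example_compute_beta` is hypothetical
# and is not called anywhere (see also _test_compute_beta() below).  Note that
# compute_beta() may modify P in place (subtracting up to K-1 from one element
# per row) so that the returned B satisfies sum(min(P, B)) == K*B exactly.
def _example_compute_beta():
    torch.manual_seed(0)
    P = torch.randint(low=1, high=1000, size=(5, 16))
    K = 4
    B = compute_beta(P, K)  # shape (5, 1); P may be adjusted in place
    assert torch.all(torch.minimum(P, B).sum(dim=-1, keepdim=True) == K * B)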


def compute_beta_prods(Psum, Ptop):
    """
    Version of compute_beta() with a different interface, which is intended to work with
    products of softmaxes.  We are still assuming an integerized representation.

    Args:
      Psum: Tensor of shape (*,), treated as the batch dimension, which contains,
           as torch.int64, the total integerized probability mass taken as a product
           over all the dimensions, e.g. for a tensor of shape (*, N, M) containing
           integerized probabilities, we'd sum along the M dimension and take a product
           along the N dimension.
      Ptop: Tensor of shape (*, K), containing the probabilities for the top-K
           possible outputs (each possible output is a combination of N indexes in
           [0..M-1]).  The sum of Ptop must be less than Psum.

    Returns: (B, delta_P)
       B: Tensor of shape (*) containing integers B satisfying:
             sum(min(P, B)) == K*B                 (eqn:b1)
          ... where conceptually, P is a matrix of shape (*, K**N)
          that we do not materialize.
          What this condition amounts to in terms of the args of this function
          is that:
             Psum + delta_P.sum(-1) == B*K
          [Caution: the exact equality in (eqn:b1) is only true
          once we subtract a small number in [0..K-1] from the next-largest
          element of P that is not >B, to correct for rounding error;
          this is accounted for in delta_P.]
       delta_P: of shape (*, K), this contains the change, if any, that we have
          to make to the top-K elements of the distribution before sampling.
          Satisfies delta_P <= 0.  This combines two things: the
          differences (min(P[i], B) - P[i]); and the values in [-(K-1)..0]
          that we add to the largest item that's less than B, to account
          for rounding effects.
    """
    K = Ptop.shape[-1]
    assert Psum.shape == Ptop.shape[:-1]

    Ptop_cum = torch.cumsum(Ptop, dim=-1)  # cumsum of Ptop, i.e. inclusive-sum.  Shape (*, K)

    # add a zero first element per row, so Ptop_cum_shift[...,0] is all-zeros and
    # Ptop_cum_shift[...,1] contains the top-1.  The idea is that
    # Ptop_cum_shift[...,k] contains the sum of the top k items.
    Ptop_cum_shift = torch.cat((torch.zeros(*Ptop.shape[:-1], 1, dtype=Ptop.dtype,
                                            device=Ptop.device),
                                Ptop_cum[..., :K-1]), dim=-1)
    # S1[...,k] contains, for each batch element, the sum of all but the k largest
    # items.  It corresponds to s-1 in the math of NOTES.md, see "ComputeBeta function
    # [mathematical version]".
    # Shape is (*, K)
    S1 = Psum.unsqueeze(-1) - Ptop_cum_shift

    temp = torch.arange(K, -1, -1, device=Psum.device)  # [K, K-1, ..., 0]
    # Kk, of shape (K,), contains [K, K-1, ..., 1], representing K-k for k = [0, 1, ..., K-1]
    Kk = temp[0:K]
    # Kk1, of shape (K,), contains [K-1, K-2, ..., 0], representing K-k-1 for k = [0, 1, ..., K-1]
    Kk1 = temp[1:K+1]

    # The following corresponds to:
    #    beta = (1 - s_k) / (K-k)
    # in NOTES.md.  This is integer division; we are rounding down.
    # B_k[...,k] is the beta value if k values are >= beta.
    B_k = S1 // Kk                 # shape (*, K)
    remainder_k = S1 - (B_k * Kk)  # shape (*, K)

    large_int = (2**63 - 1)
    # Ptop_shifted is Ptop shifted right with a large value put first, i.e.
    # instead of [top1, top2, top3, top4] we have [inf, top1, top2, top3]
    Ptop_shifted = torch.cat((torch.full((*Ptop.shape[:-1], 1), large_int,
                                         device=Ptop.device),
                              Ptop[..., :K-1]), dim=-1)

    # is_ok corresponds to: "(k==0 or R[M-k] > B_k) and R[M-1-k] <= B_k" in NOTES.md.
    # It is true only for the "correct" k for each batch element, which corresponds
    # to the number of values greater than B_k.
    is_ok = (torch.logical_and(Ptop_shifted > B_k, Ptop <= B_k))  # shape: (*, K)

    # `indexes` are the values of k.
    B, indexes = torch.max(B_k * is_ok, dim=-1)  # shape: (*,)

    delta_P = (torch.minimum(Ptop, B.unsqueeze(-1)) - Ptop) - (remainder_k * is_ok)

    err = Psum + delta_P.sum(dim=-1) - B * K
    assert torch.all(err == 0)
    assert torch.all(torch.sum(is_ok, dim=-1) == 1)

    return B, delta_P
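

# Small worked example, added for illustration; `_example_compute_beta_prods`
# is a hypothetical helper that is not called anywhere.  It shows the invariant
# relating the returned B and delta_P to Psum: Psum + delta_P.sum(-1) == B*K.
def _example_compute_beta_prods():
    Psum = torch.tensor([1000], dtype=torch.int64)
    Ptop = torch.tensor([[300, 200, 100, 50]], dtype=torch.int64)  # top-K masses, descending
    B, delta_P = compute_beta_prods(Psum, Ptop)  # here B == 233, delta_P == [-67, -1, 0, 0]
    K = Ptop.shape[-1]
    assert torch.all(delta_P <= 0)
    assert torch.all(Psum + delta_P.sum(dim=-1) == B * K)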


def compute_shifted_samples(combined_cumsums_mod: Tensor,
                            delta_P: Tensor,
                            samples: Tensor) -> Tensor:
    """
    Modifies randomly sampled values by adding values to correct for "disallowed regions",
    i.e. parts of probability space that we skip because they correspond to a probability
    mass greater than beta [or because they correspond to small padding for roundoff].

      combined_cumsums_mod: Modified cumulative sums which, when they were "combined_cumsums",
                 could be thought of as points in probability space, but which, once
                 "modified", are reduced to account for "disallowed regions" that
                 we cannot sample.  The shape is (*, K) where `*` is the batch dimension
                 and K is the maximum number of "disallowed regions".
      delta_P: negative values that correspond to the amount of probability mass we
                 removed for each "disallowed region", i.e. the size of those
                 regions, as a negative number.  The shape is (*, K).
      samples: The samples that we have to modify by adding values corresponding to
                 the widths of the appropriate disallowed regions.  The shape is (*, K);
                 but this K is not the "same K".
    Returns: shifted_samples, which will be the same shape as `samples`, but possibly
                 with larger values, i.e. shifted_samples >= samples
    """
    samples = samples.unsqueeze(-1)
    combined_cumsums_mod = combined_cumsums_mod.unsqueeze(-2)
    delta_P = delta_P.unsqueeze(-2)

    # of shape (*, K, K), is_ge is True if sample k1 is >= combined_cumsum k2,
    # meaning we need to add the corresponding delta_P.
    is_ge = (samples >= combined_cumsums_mod)

    shifted_samples = samples - (is_ge * delta_P).sum(dim=-1, keepdim=True)
    shifted_samples = shifted_samples.squeeze(-1)
    return shifted_samples
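

# Tiny numeric sketch, added for illustration; `_example_compute_shifted_samples`
# is hypothetical and is not called anywhere.  With one disallowed region of
# width 10 whose (modified) start is 100, samples at or past 100 are shifted up by 10.
def _example_compute_shifted_samples():
    combined_cumsums_mod = torch.tensor([[100]])
    delta_P = torch.tensor([[-10]])
    samples = torch.tensor([[50, 100, 150]])
    shifted = compute_shifted_samples(combined_cumsums_mod, delta_P, samples)
    assert torch.all(shifted == torch.tensor([[50, 110, 160]]))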


def check_shifted_samples(combined_cumsums: Tensor,
                          delta_P: Tensor,
                          shifted_samples: Tensor,
                          prod_cumsum: Tensor):
    """
    Checks samples as modified by `compute_shifted_samples`: specifically, checks
    that they are not in the "disallowed regions" that we are supposed to skip over.

     combined_cumsums: Cumulative sums which can be thought of as the start of
                "disallowed regions" in probability space.  Shape is (*, K).
     delta_P: the negative of the size of "disallowed regions".  Shape is (*, K).
     shifted_samples: The samples as modified by `compute_shifted_samples`.  None
                of these should be within the "disallowed regions".  Shape is (*, K);
                but note, this K does not have a correspondence with the K in the
                other two args' shapes.
     prod_cumsum: The product of sums/normalizers of the different softmaxes, of
                shape (*,); this can be thought of as the total size of the probability
                space, including "disallowed regions".  This is to check that
                `shifted_samples` are less than this value.
    """
    assert torch.all(torch.logical_and(shifted_samples >= 0,
                                       shifted_samples < prod_cumsum.unsqueeze(-1)))

    shifted_samples = shifted_samples.unsqueeze(-1)
    combined_cumsums = combined_cumsums.unsqueeze(-2)
    delta_P = delta_P.unsqueeze(-2)

    disallowed_regions_start = combined_cumsums
    disallowed_regions_end = combined_cumsums - delta_P  # delta_P is <= 0.

    # in_disallowed_region is of shape (*, K, K)
    in_disallowed_region = torch.logical_and(shifted_samples >= disallowed_regions_start,
                                             shifted_samples < disallowed_regions_end)
    assert torch.all(torch.logical_not(in_disallowed_region))


def get_indexes_for_samples(P: Tensor,
                            P_cumsum: Tensor,
                            P_cumsum_exclusive: Tensor,
                            shifted_samples: Tensor) -> Tensor:
    """
    From K `shifted_samples` which are in the joint probability-space of N softmaxes
    of size M, figure out which sample indexes they correspond to.
    Args:
     P:  of shape (*, N, M), the original integerized probabilities we
         are interested in the products over [i.e. over the N dimension],
         e.g. N=2, M=128.
     P_cumsum:  of shape (*, N, M), this is the (inclusive) cumulative sum of
         the original integerized probabilities P.  Conceptually, the entire
         probability space is over all possible products, over the N
         dimension, of different choices of m, arranged so that m-indexes for
         the earlier n indexes vary fastest, like [000,100,200,010,110,210, ... ].
     P_cumsum_exclusive:  of shape (*, N, M), the exclusive-sum version of
         P_cumsum, equivalent to P_cumsum - P.
     shifted_samples:  of shape (*, K), contains the random samples we want
         to find indexes for.  "Shifted" means we have skipped over "disallowed regions"
         corresponding to combinations of indexes that had too much probability mass.
         Will satisfy:
            0 <= shifted_samples < P_cumsum[...,-1].prod(dim=-1, keepdim=True)
    Returns:
       indexes: of shape (*, K, N), the N-tuples of indexes in {0,1...M-1}
          corresponding to each of the K samples.
    """

    # P_sum_cumprod is the cumulative product of the total sum of the original
    # integerized probabilities P, of shape (*, N)
    P_sum_cumprod = torch.cumprod(P_cumsum[..., -1], dim=-1)
    M = P.shape[-1]
    N = P.shape[-2]

    ans_indexes_shape = list(shifted_samples.shape) + [N]  # (*, K, N)
    ans_indexes = torch.empty(*ans_indexes_shape, dtype=P.dtype,
                              device=P.device)

    cur_samples = shifted_samples  # (*, K)
    for n in range(N-1, -1, -1):  # [N-1, N-2, ..., 0]
        this_samples = cur_samples  # (*, K)
        if n > 0:
            # divide by the total product of probs over *previous* indexes n,
            # so we can compare directly with P_cumsum.
            this_samples = this_samples // P_sum_cumprod[..., n-1:n]
        # right=True means we find
        # P_cumsum[...,index-1] <= this_samples[...,k] < P_cumsum[...,index],
        # which is what we want, as opposed to ... < ... <= (i.e. swap < and <=)
        idx = ans_indexes[..., n] = torch.searchsorted(P_cumsum[..., n, :],  # (*, M)
                                                       this_samples,  # (*, K)
                                                       right=True)
        this_P = torch.gather(P[..., n, :], dim=-1, index=idx)  # shape: (*, K)

        if n == 0:
            break

        # get the cumsum corresponding to the indexes we just computed; we need
        # to subtract the start of the region corresponding to this index.
        # need exclusive-sum here..
        cur_cumsum = torch.gather(P_cumsum_exclusive[..., n, :], dim=-1, index=idx)
        # account for the product of previous dims' total sums...
        # TODO: multiply P_cumsum by P_sum_cumprod
        cur_cumsum *= P_sum_cumprod[..., n-1:n]
        # Get the remainder after subtracting the indexes we just worked out;
        # this will be used to get previous indexes, i.e. for lower n.
        remainder = cur_samples - cur_cumsum
        # Also divide by this_P, since all probability masses corresponding
        # to the index we just worked out will be scaled by this amount.
        remainder = remainder // this_P
        cur_samples = remainder

    return ans_indexes
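

# Sanity sketch for the simplest case, added for illustration; the helper
# `_example_get_indexes_for_samples` is hypothetical and is not called anywhere.
# With N=1 the answer reduces to a searchsorted() over the inclusive cumsum.
def _example_get_indexes_for_samples():
    P = torch.tensor([[[3, 1, 4, 2]]])           # (*, N, M) = (1, 1, 4)
    P_cumsum = torch.cumsum(P, dim=-1)           # [[[3, 4, 8, 10]]]
    P_cumsum_exclusive = P_cumsum - P
    shifted_samples = torch.tensor([[0, 3, 9]])  # (*, K) = (1, 3), all < total sum 10
    indexes = get_indexes_for_samples(P, P_cumsum, P_cumsum_exclusive, shifted_samples)
    assert torch.all(indexes.squeeze(-1) == torch.tensor([[0, 1, 3]]))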


def get_weights_for_samples(P: Tensor,
                            P_sum_product: Tensor,
                            B: Tensor,
                            indexes: Tensor,
                            dtype: torch.dtype) -> Tensor:
    """
    Return output weights for the K samples we selected for each distribution.
    The probability of selecting a particular sample with probability p_i
    is: min(1, p_i/beta), and the output weight for a sample (if we select it)
    will be p_i divided by the probability with which we sampled it,
    i.e. p_i / min(1, p_i/beta) = max(p_i, beta).  P and B are integerized
    forms of p and beta; we have to divide by P_sum_product to get
    the actual values.

    Args:
      P: integerized probabilities for the individual distributions
        in our product-of-distributions, of shape
        (*, N, M), where * is the batch dimension(s), N is the
        number of distributions in the product (e.g. 2 or 3), and
        M is the size of each distribution (e.g. 128).
     P_sum_product: of shape (*,), the result of taking the sum of
        P over the M dimension and then the product over the N
        dimension.
     B: of shape (*,), the integerized value of beta
        (B/P_sum_product == beta).  We sample each item with
        probability min(1, prob_of_item/beta), with beta
        chosen such that the sum of those probabilities is
        exactly K.
     indexes: the indexes of the chosen samples, of
        shape (*, K, N).  K is the number of samples;
        each sample is an N-tuple of indexes.
     dtype: the desired data-type of the returned weights.
    Returns:
       Returns the weights for the chosen indexes, of
       shape (*, K); these will sum to one along the K axis.
    """
    if dtype == torch.float16:
        return get_weights_for_samples(P, P_sum_product,
                                       B, indexes, torch.float32).to(dtype)
    assert dtype in [torch.float32, torch.float64]

    # probs: of shape (*, N, K), the integer probabilities for
    # the individual distributions
    probs = torch.gather(P, dim=-1, index=indexes.transpose(-2, -1))

    # multiply probs across the N axis to get products of shape (*, K)
    probs = probs.prod(dim=-2)

    # P_sum_product: (*,)
    P_sum_product = P_sum_product.to(dtype=dtype)
    # beta: (*,)
    beta = B.to(dtype=dtype) / P_sum_product
    p = probs.to(dtype=dtype) / P_sum_product.unsqueeze(-1)
    # ans: shape (*, K)
    ans = torch.maximum(p, beta.unsqueeze(-1))
    # ans_sum: shape (*,)
    ans_sum = ans.sum(dim=-1)
    assert torch.all((ans_sum - 1.0).abs() < 0.01)
    return ans


_max_bits = 54  # used in sample_combined_forward and sample_combined_backward,
                # see comment in sample_combined_forward.


def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
    """
    Sample from a distribution that is the product of softmaxes.  We will sample
    K *distinct* samples.  This entails using sampling weights of the form min(1, p/beta)
    for a computed beta.
    Args:
      p: A Tensor of shape (*, N, M): either normalized log-probs (if input_is_log==True),
         or normalized probabilities; normalized along the M axis.  M must be
         a power of 2, and N must be in [1,2,3,4].
      K: An integer, the number of samples required, with 0 < K < M
      input_is_log: True if p represents normalized log-probs, False if it represents
         probabilities.

    Returns: (weights, indexes)
      weights: of shape (*, K), gives the weight associated with each sample,
         which will equal max(p, beta) for a beta specific to the batch element,
         i.e. to the product of the distributions (0 < beta <= 1/K).  The
         weights will sum to 1 along the K axis.
      indexes: of shape (*, K, N), for each of K samples from a distribution it contains
         an N-tuple of indexes saying which combination of indexes from the
         component distributions were sampled.
    """
    p = p.detach()  # call sample_combined() if you need derivatives.
    N = p.shape[-2]
    M = p.shape[-1]
    assert M & (M-1) == 0  # required for the random reordering to work (see
                           # rand_perm); this ensures odd numbers would be
                           # coprime to M.
    dtype = p.dtype
    assert N > 0 and N <= 4
    # allocating 54 bits for the product of distributions means that, for instance,
    # with 3 distributions we can have 18 bits per distribution.  The reason
    # we don't go closer to 64 is that to choose random numbers we
    # do: `b = torch.randint((2**63 - 1), B.shape) % B`, and for this to actually
    # be uniformly distributed we need 2**63 - 1 to be substantially larger than
    # the total probability mass.  However it's not super-critical that this
    # gap be very large because in any case we randomize the order of indexes
    # before the sampling procedure.
    num_bits_per_sample = _max_bits // N

    if input_is_log:
        p = p.exp()

    # rand_perm is in {1,3,..M-1}, it is of shape (*, N, 1); we'll
    # use it to pseudo-randomly reorder each distribution.
    rand_perm = torch.randint(M//2, p.shape[:-1] + (1,), device=p.device) * 2 + 1
    # Note: we could implement this more efficiently with a special kernel.
    rand_perm_indexes = (rand_perm * torch.arange(M, device=p.device)) % M
    # reorder the elements of p; we'll correct for the reordering later when
    # we return indexes.
    p = torch.gather(p, dim=-1, index=rand_perm_indexes)

    # the + 1 is because we need all elements of P to be nonzero (this will avoid
    # some nasty edge cases)
    P = (p * (2**(num_bits_per_sample)) + 1).to(dtype=torch.int64)
    values, indexes = compute_k_largest(P, K)
    prod_values, prod_indexes = compute_products(values, indexes)

    # combined_values, combined_indexes: (B, K); these are the top-K
    # most-probable combinations of (integerized) probabilities and their
    # indexes, from largest to smallest probability
    combined_values, combined_indexes = compute_k_largest(prod_values, K)

    # let combined_indexes contain the original N-tuples
    combined_indexes_shape = list(combined_indexes.shape) + [N]
    # combined_indexes: (B, K, N)
    combined_indexes = torch.gather(prod_indexes, dim=-2,
                                    index=combined_indexes.unsqueeze(-1).expand(combined_indexes_shape))

    P_cumsum = torch.cumsum(P, dim=-1)  # (B, N, M)
    P_cumsum_cat = torch.cat((torch.zeros(*P_cumsum.shape[:-1], 1, dtype=P_cumsum.dtype,
                                          device=P_cumsum.device),
                              P_cumsum), dim=-1)
    P_cumsum_exclusive = P_cumsum_cat[..., :-1]
    P_cumsum = P_cumsum_cat[..., 1:]

    # P_sum is the total sum of the individual softmaxes/distributions.
    # Shape: (*, N)
    P_sum = P_cumsum[..., M-1]

    # P_sum_cumprod is the inclusive cumulative product of P_sum over the N axis.
    # Shape: (*, N)
    P_sum_cumprod = torch.cumprod(P_sum, dim=-1)
    # P_sum_product, of shape (*,), is the product of all the P_sum values.
    P_sum_product = P_sum_cumprod[..., -1]
    # P_prev_sum_cumprod, of shape (*, N), contains the product of the P_sum
    # values for the *previous* indexes n, i.e. over n_prev < n.  We divide by
    # P_sum to make it an exclusive, not an inclusive, product.
    P_prev_sum_cumprod = P_sum_cumprod // P_sum

    P_cumsum_cat_scaled = P_cumsum_cat * P_prev_sum_cumprod.unsqueeze(-1)
    P_cumsum_exclusive_scaled = P_cumsum_cat_scaled[..., :-1]
    P_cumsum_scaled = P_cumsum_cat_scaled[..., 1:]

    # combined_cumsums: (B, K)
    combined_cumsums = get_combined_cumsums(P,
                                            P_cumsum_exclusive_scaled,
                                            combined_indexes)

    B, delta_P = compute_beta_prods(P_sum_product, combined_values)

    # reorder combined_cumsums from smallest to largest, which we'll require
    # when interpolating the "skipped regions" into the random numbers.
    combined_cumsums, reorder_indexes = torch.sort(combined_cumsums, dim=-1)
    # also reorder delta_P [so that delta_P and combined_cumsums are reordered
    # in the same way]
    delta_P = torch.gather(delta_P, dim=-1, index=reorder_indexes)

    # delta_P_exclusive, of shape (*, K), is the exclusive cumulative sum of
    # delta_P, containing negative values.
    delta_P_cumsum = torch.cumsum(delta_P, dim=-1)
    delta_P_exclusive = delta_P_cumsum - delta_P

    # combined_cumsums_mod is combined_cumsums modified by adding the sum
    # of previous delta_P's (which will be negative).  This compensates for
    # the fact that the random numbers in `samples` are in a compressed
    # space where we "skip over" regions of size -delta_P.
    #
    # These are the cutoffs for subtracting the delta_P's
    # from `samples`.
    combined_cumsums_mod = combined_cumsums + delta_P_exclusive

    # CAUTION: if the product of sums is too large, this `rand`
    # will not be sufficiently random!!  We need to leave some headroom.
    # rand values are random in {0, 1, ..., B-1}
    rand = torch.randint((2**63 - 1), B.shape, device=B.device) % B
    # rand, rand + B, rand + 2B, ...., rand + (K-1)B
    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)

    shifted_samples = compute_shifted_samples(combined_cumsums_mod,
                                              delta_P,
                                              samples)

    # TODO: could remove the next call
    check_shifted_samples(combined_cumsums, delta_P,
                          shifted_samples, P_sum_product)

    indexes = get_indexes_for_samples(P, P_cumsum,
                                      P_cumsum_exclusive,
                                      shifted_samples)

    weights = get_weights_for_samples(P, P_sum_product, B, indexes,
                                      dtype=p.dtype)

    # undo the pseudo-random reordering of each distribution, so the indexes
    # refer to the original (un-permuted) p.
    indexes = (indexes * rand_perm.transpose(-2, -1)) % M

    return weights, indexes
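

# Shape-level usage sketch, added for illustration; `_example_sample_combined_forward`
# is a hypothetical helper that is not called anywhere (fuller tests appear at the
# bottom of this file).
def _example_sample_combined_forward():
    torch.manual_seed(0)
    N, M, K = 2, 8, 4
    p = torch.randn(3, N, M).log_softmax(dim=-1)
    weights, indexes = sample_combined_forward(p, K, input_is_log=True)
    assert weights.shape == (3, K) and indexes.shape == (3, K, N)
    # the weights form a normalized importance-weighted distribution over the K samples
    assert torch.all((weights.sum(dim=-1) - 1.0).abs() < 0.01)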


def sample_combined_backward(p: Tensor, input_is_log: bool, indexes: Tensor,
                             weights: Tensor, weights_grad: Tensor) -> Tensor:
    """
    Backward for sample_combined(); see sample_combined_forward() for detailed docs on
    the forward pass.  Notice that we don't use Torch's inbuilt autograd for this;
    that would not give us the answer we want.

    View the output of the forward pass as a sparse vector q.  You can view the
    forward pass as implementing: q = z p, where z is a sparse vector whose
    *expected* value is [1,1,..].  Because the expected value of z does not change
    with p, we treat z as being independent of p, even though actually
    the detailed distribution of z does depend on p.  So the backprop in non-log
    space would just be:
          p_grad = z * output_grad
    where z is the sparse vector we multiplied by in the forward pass.  Since
    we can express z as just q / p, this becomes:
          p_grad = q / p * output_grad
    where q is the sparse output of the forward pass.  In log-space, this is just
    equivalent to log_p_grad = log_output_grad.
    In non-log space, division by p could lead to infinite output if p is zero;
    in the forward pass we smoothed p by adding 2**-(num_bits_per_sample), and
    if you work it out, the backprop rule correcting for this would just become
          p_grad = q / (p + 2**-(num_bits_per_sample)) * output_grad

    Args:
       p: the probabilities as used in the forward pass, of shape (*, N, M)
       input_is_log: if False, p should be probabilities; if True, p should
           be normalized log-probs, e.g. the output of log_softmax.
       indexes: the `indexes` output of sample_combined_forward, of shape (*, K, N)
       weights: the `weights` output of sample_combined_forward, of shape (*, K)
       weights_grad: the loss-function gradient w.r.t the output weights, of shape
           (*, K)
    """
    K = weights.shape[-1]
    N = indexes.shape[-1]

    log_p_grad = torch.zeros_like(p)  # (*, N, M)
    # log_weights_grad is the derivative w.r.t. log(weights).
    log_weights_grad = weights_grad * weights
    # expanded_log_weights_grad: (*, N, K),
    # duplicated along the N dimension
    expanded_log_weights_grad = log_weights_grad.unsqueeze(-2).expand(*weights.shape[:-1],
                                                                      N, K)
    log_p_grad.scatter_add_(dim=-1, index=indexes.transpose(-2, -1), src=expanded_log_weights_grad)

    if not input_is_log:
        if p.dtype == torch.float16:
            raise ValueError("For float16 input you have to use log-space for input "
                             "probabilities, i.e. input_is_log=True")
        num_bits_per_sample = _max_bits // N
        # 2**-num_bits_per_sample is very small, so don't worry about renormalizing p.
        # This is just to stop division by zero.
        p_smoothed = p + (2.0**-num_bits_per_sample)
        log_p_grad.divide_(p_smoothed)
        return log_p_grad
    return log_p_grad


class SampleCombinedFunction(torch.autograd.Function):
    # please see sample_combined() or sample_combined_forward() or
    # sample_combined_backward() for documentation
    @staticmethod
    def forward(ctx, p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            weights, indexes = sample_combined_forward(p, K, input_is_log)
            ctx.save_for_backward(p, indexes, weights)
        ctx.input_is_log = input_is_log
        return weights, indexes

    @staticmethod
    def backward(ctx, weights_grad: Optional[Tensor], indexes_grad: Optional[Tensor]) -> Tuple[Tensor, None, None]:
        p, indexes, weights = ctx.saved_tensors
        p_grad = sample_combined_backward(p, ctx.input_is_log, indexes,
                                          weights, weights_grad)
        return p_grad, None, None


def sample_combined(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
    """
    Sample from a distribution that is the product of softmaxes.  We will sample
    K *distinct* samples.  This entails using sampling weights of the form min(1, p/beta)
    for a computed beta.
    Args:
      p: A Tensor of shape (*, N, M): either normalized log-probs (if input_is_log==True),
         or normalized probabilities; normalized along the M axis.  M must be
         a power of 2, and N must be in [1,2,3,4].
      K: An integer, the number of samples required, with 0 < K < M
      input_is_log: True if p represents normalized log-probs, False if it represents
         probabilities.

    Returns: (weights, indexes)
      weights: of shape (*, K), gives the weight associated with each sample,
         which will equal max(p, beta) for a beta specific to the batch element,
         i.e. to the product of the distributions (0 < beta <= 1/K).  The
         weights will sum to 1 along the K axis.
      indexes: of shape (*, K, N), for each of K samples from a distribution it contains
         an N-tuple of indexes saying which combination of indexes from the
         component distributions were sampled.
    """
    return SampleCombinedFunction.apply(p, K, input_is_log)


def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor]:
    """
    Forward function for soft sampling.
    Args:
      p: Tensor of shape (*, M)
      K: number of samples, 1 <= K < M
      input_is_log: if true, p must be normalized log-probs (that sum to one after exp());
                    if false, p must be probabilities in [0..1] that sum to one.
    Returns: (indexes, y), where:
        indexes: shape (*, K), a LongTensor containing elements in [0..M-1], distinct
           along the K axis
        y: shape (*, K), a Tensor containing values in [0..1], which sum to 1 along the
           K axis.

    Search for "def soft_sample" in NOTES.md to understand this.
    """
    if input_is_log:
        p = p.exp()
    M = p.shape[-1]
    assert M & (M-1) == 0
    two31 = 2 ** 31
    # to(dtype=torch.long) rounds toward 0, which is good enough
    P = (p*two31 + 1).to(dtype=torch.long)
    B = compute_beta(P, K)
    beta = B / two31
    t = torch.randint(M//2, p.shape[:-1] + (1,),
                      device=P.device)  # shape: *, 1
    s = t * 2 + 1
    # s = torch.ones_like(t)

    # turns out we don't need inv_s.
    inv_s = (s ** (M//2 - 1)) % M
    assert torch.all((s * inv_s) % M == 1)  # if this fails, check that M is a power of 2

    # R = pseudo-random re-ordering of p.
    R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M, device=P.device)) % M),
                      B)
    # S = inclusive-sum of R
    S = torch.cumsum(R, dim=-1)

    # Let b be a random integer drawn uniformly from {0, 1, ..., B-1}.
    b = torch.randint((2**63 - 1), B.shape, device=B.device) % B

    S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, device=S.device), S[..., :-1]), dim=-1)

    k_prev = (S_prev + b) // B
    k_cur = (S + b) // B
    # if S_prev >= b and k_cur > k_prev:..  don't need S_prev >= b because rounded down.
    is_ok = (k_cur > k_prev)

    # sort so the "false" goes first and the "true" goes in the last K indexes.
    values, indices = is_ok.sort(dim=-1)
    i = indices[..., M-K:M]
    i = (i * s) % M  # Reverse the pseudo-random reordering
    y = torch.maximum(torch.gather(p, dim=-1, index=i), beta)
    assert torch.all(is_ok.sum(dim=-1) == K)
    assert torch.all((y.sum(dim=-1) - 1.0).abs() < 0.01)
    return i, y


def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter:
    std = 0.1
    a = (3 ** 0.5) * std  # this sqrt(3) thing is intended to get a standard
                          # deviation of 0.1 from a uniform distribution
    ans = nn.Parameter(torch.ones(M ** N, D))
    nn.init.uniform_(ans, -a, a)
    return ans


def join_indexes(indexes: Tensor, M: int) -> Tensor:
    """
    Combines N-tuples of indexes into single indexes that can be used for
    lookup in the knowledge base.  Args:
      indexes: tensor of torch.int64 of shape (*, K, N), with elements in
         {0..M-1}
      M: the size of the original softmaxes, is an upper bound on the elements
         in indexes
    Returns:
       joined_indexes: of shape (*, K), where joined_indexes[...,k] equals
         indexes[...,k,0] + indexes[...,k,1]*(M**1) + ... + indexes[...,k,N-1]*(M**(N-1))
    """
    N = indexes.shape[-1]
    n_powers = M ** torch.arange(N, device=indexes.device)  # [ 1, M, ..., M**(N-1) ]
    return (indexes * n_powers).sum(dim=-1)
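

# Tiny worked example, added for illustration; `_example_join_indexes` is a
# hypothetical helper that is not called anywhere.
def _example_join_indexes():
    indexes = torch.tensor([[[3, 1, 2]]])  # (*, K, N) = (1, 1, 3)
    # with M=4: 3 + 1*4 + 2*16 == 39
    assert int(join_indexes(indexes, M=4)) == 3 + 1 * 4 + 2 * 16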


def weighted_matrix_lookup(weights: Tensor,
                           indexes: Tensor,
                           knowledge_base: Tensor) -> Tensor:
    """
    Weighted combination of specified rows of a matrix.
      weights: Tensor of shape (*, K), can contain any value but probably in [0..1].
      indexes: Tensor of shape (*, K), with elements in [0..C-1]
      knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up
    Returns:
      tensor of shape (*, D), containing weighted sums of rows of
      `knowledge_base`
    """
    lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
    D = knowledge_base.shape[-1]
    weights = weights.unsqueeze(-2)  # (*, 1, K)
    lookup = lookup.reshape(*indexes.shape, D)  # (*, K, D)
    ans = torch.matmul(weights, lookup)  # ans: (*, 1, D)
    ans = ans.squeeze(-2)
    assert list(ans.shape) == list(weights.shape[:-2]) + [D]
    return ans
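

# Cross-check sketch, added for illustration; `_example_weighted_matrix_lookup`
# is a hypothetical helper that is not called anywhere.  The function is
# equivalent to gathering rows and taking a weighted sum over the K axis.
def _example_weighted_matrix_lookup():
    torch.manual_seed(0)
    C, D, K = 10, 5, 3
    knowledge_base = torch.randn(C, D)
    weights = torch.randn(2, K)
    indexes = torch.randint(C, (2, K))
    ans = weighted_matrix_lookup(weights, indexes, knowledge_base)
    expected = (weights.unsqueeze(-1) * knowledge_base[indexes]).sum(dim=-2)
    assert torch.allclose(ans, expected, atol=1e-5)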


class WeightedMatrixLookupFunction(torch.autograd.Function):
    """
    Weighted matrix lookup, memory-efficient version that redoes the computation in the
    backward pass...  this is not really optimal, but the autograd for this operation is
    complicated.

    See weighted_matrix_lookup() for documentation.
    """
    @staticmethod
    def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
        ctx.save_for_backward(weights.detach(), indexes.detach(),
                              knowledge_base.detach())
        return weighted_matrix_lookup(weights, indexes, knowledge_base)

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
        weights, indexes, knowledge_base = ctx.saved_tensors
        weights.requires_grad = True
        knowledge_base.requires_grad = True
        with torch.enable_grad():
            ans = weighted_matrix_lookup(weights, indexes, knowledge_base)
            ans.backward(gradient=ans_grad)
        return weights.grad, None, knowledge_base.grad


class KnowledgeBaseLookup(nn.Module):
    """
    Create knowledge-base lookup module.  (The knowledge-base parameter, which is
    large, is shared between these modules).
    Args:
       M: int, softmax size, e.g. in [32..128]
       N: int, number of softmaxes, in [2..3]
       D: int, embedding dimension in knowledge base, e.g. 256
       K: number of samples (affects speed/accuracy tradeoff), e.g. 16.
       embedding_dim: the dimension to project from and to, e.g. the
          d_model of the conformer.
    """
    def __init__(self, M: int, N: int, D: int,
                 K: int, embedding_dim: int,
                 knowledge_base: nn.Parameter):
        super(KnowledgeBaseLookup, self).__init__()
        self.knowledge_base = knowledge_base  # shared!
        self.in_proj = ScaledLinear(embedding_dim, M * N,
                                    initial_scale=5.0)
        # out_proj uses a larger-than-default initial_scale because the
        # knowledge_base activations are quite small -- if we use our
        # optimizer they'll have stddev <= 0.1.
        self.out_proj = ScaledLinear(D, embedding_dim,
                                     initial_scale=10.0)
        self.M = M
        self.N = N
        self.K = K

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward function that does knowledge-base lookup.
        Args:
             x: input, of shape (*, E) where E is embedding_dim
                as passed to constructor
        Returns:
             y: output of knowledge-base lookup, of shape (*, E)

        # TODO: later we can try multiplying by a projection of x or something like that.
        """
        x = self.in_proj(x)  # now (*, M*N)
        x = x.reshape(*x.shape[:-1], self.N, self.M)  # now (*, N, M)
        x = x.log_softmax(dim=-1)  # now normalized logprobs, shape (*, N, M)
        if random.random() < 0.01:
            entropy = (x * x.exp()).sum(dim=-1).mean()
            print("Entropy = ", entropy)
        weights, indexes = sample_combined(x, self.K, input_is_log=True)
        indexes = join_indexes(indexes, self.M)
        x = WeightedMatrixLookupFunction.apply(weights, indexes, self.knowledge_base)  # now (*, D)
        x = self.out_proj(x)  # now (*, embedding_dim)
        return x
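

# Minimal CPU usage sketch, added for illustration; `_example_knowledge_base_lookup`
# is a hypothetical helper that is not called anywhere (the sizes are assumptions,
# chosen small; see _test_knowledge_base_lookup() below for the fuller test).
def _example_knowledge_base_lookup():
    M, N, D, K, E = 8, 2, 16, 4, 32
    kb = create_knowledge_base(M, N, D)
    m = KnowledgeBaseLookup(M, N, D, K, E, kb)
    x = torch.randn(5, E)
    y = m(x)  # sampling + knowledge-base lookup + projection back to E
    assert y.shape == (5, E)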


def _test_compute_beta():
    # use a small M -- 16 here -- because it's more likely to
    # choose k != 0 in compute_beta(), giving a more complete test.
    a = torch.randint(low=1, high=65535, size=(9, 16))
    K = 4
    beta = compute_beta(a, K)  # it checks its own answer..
    print("beta = ", beta)


def _test_soft_sample():
    l = 2 * torch.randn(6, 64)
    p = torch.softmax(l, dim=-1)
    soft_sample_forward(p, K=4, input_is_log=False)


def _test_combined():
    N = 2
    K = 4
    M = 8

    P = ((5 * torch.randn(2, N, M)).softmax(dim=-1) * 16 + 1).to(dtype=torch.int64)

    print("P = ", P)
    values, indexes = compute_k_largest(P, K)
    print("largest values = ", values)
    print("largest indexes = ", indexes)
    prod_values, prod_indexes = compute_products(values, indexes)
    assert prod_values.shape == prod_indexes.shape[:-1]
    print("prod_values = ", prod_values)
    print("prod_indexes = ", prod_indexes)

    # combined_values, combined_indexes: (B, K); these are the top-K
    # most-probable combinations of (integerized) probabilities and their
    # indexes, from best to worst.
    combined_values, combined_indexes = compute_k_largest(prod_values, K)

    combined_indexes_shape = list(combined_indexes.shape) + [N]
    # combined_indexes: (B, K, N)
    combined_indexes = torch.gather(prod_indexes, dim=-2,
                                    index=combined_indexes.unsqueeze(-1).expand(combined_indexes_shape))

    print("combined_values = ", combined_values)
    print("combined_indexes = ", combined_indexes)

    P_cumsum = torch.cumsum(P, dim=-1)  # (B, N, M)
    P_cumsum_cat = torch.cat((torch.zeros(*P_cumsum.shape[:-1], 1, dtype=P_cumsum.dtype,
                                          device=P_cumsum.device),
                              P_cumsum), dim=-1)
    P_cumsum_exclusive = P_cumsum_cat[..., :-1]
    P_cumsum = P_cumsum_cat[..., 1:]

    # P_sum is the total sum of the individual softmaxes/distributions.
    # Shape: (*, N)
    P_sum = P_cumsum[..., M-1]

    # P_sum_cumprod is the inclusive cumulative product of P_sum over the N axis.
    # Shape: (*, N)
    P_sum_cumprod = torch.cumprod(P_sum, dim=-1)
    # P_sum_product, of shape (B,), is the product of all the P_sum values.
    P_sum_product = P_sum_cumprod[..., -1]
    print("P_sum_product = ", P_sum_product)
    # P_prev_sum_cumprod is the exclusive-product version of P_sum_cumprod, i.e.
    # contains the product over previous elements of P_sum.  Shape: (*, N)
    P_prev_sum_cumprod = P_sum_cumprod // P_sum

    P_cumsum_cat_scaled = P_cumsum_cat * P_prev_sum_cumprod.unsqueeze(-1)
    P_cumsum_exclusive_scaled = P_cumsum_cat_scaled[..., :-1]
    P_cumsum_scaled = P_cumsum_cat_scaled[..., 1:]

    # combined_cumsums: (B, K)
    combined_cumsums = get_combined_cumsums(P,
                                            P_cumsum_exclusive_scaled,
                                            combined_indexes)
    print("combined_cumsums = ", combined_cumsums)
    print("combined_cumsums + combined_values = ", combined_cumsums + combined_values)

    assert torch.all(P_sum_product.unsqueeze(-1) > combined_cumsums)

    assert torch.all(P_sum_product.unsqueeze(-1) >= combined_cumsums + combined_values)

    B, delta_P = compute_beta_prods(P_sum_product, combined_values)

    assert torch.all(combined_values + delta_P > 0)

    # reorder combined_cumsums from smallest to largest, which we'll require
    # when interpolating the "skipped regions" into the random numbers.
    combined_cumsums, reorder_indexes = torch.sort(combined_cumsums, dim=-1)
    # also reorder delta_P [so that delta_P and combined_cumsums are reordered
    # in the same way]
    delta_P = torch.gather(delta_P, dim=-1, index=reorder_indexes)

    print("combined_cumsums, reordered, = ", combined_cumsums)
    print("delta_P, reordered, = ", delta_P)

    # delta_P_exclusive, of shape (*, K), is the exclusive cumulative sum of
    # delta_P, containing negative values.
    delta_P_cumsum = torch.cumsum(delta_P, dim=-1)
    delta_P_exclusive = delta_P_cumsum - delta_P
    print("delta_P_exclusive = ", delta_P_exclusive)

    # combined_cumsums_mod is combined_cumsums modified by adding the sum
    # of previous delta_P's (which will be negative).  This compensates for
    # the fact that the random numbers in `samples` are in a compressed
    # space where we "skip over" regions of size -delta_P.
    #
    # These are the cutoffs for subtracting the delta_P's
    # from `samples`.
    combined_cumsums_mod = combined_cumsums + delta_P_exclusive
    print("combined_cumsums_mod = ", combined_cumsums_mod)

    # CAUTION: if the product of sums is too large, this `rand`
    # will not be sufficiently random!!  We need to leave some headroom.
    # rand values are random in {0, 1, ..., B-1}
    rand = torch.randint((2**63 - 1), B.shape) % B
    # rand, rand + B, rand + 2B, ...., rand + (K-1)B
    samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
    print("rand = ", rand)
    print("samples = ", samples)

    shifted_samples = compute_shifted_samples(combined_cumsums_mod,
                                              delta_P,
                                              samples)
    print("shifted_samples = ", shifted_samples)

    check_shifted_samples(combined_cumsums,
                          delta_P,
                          shifted_samples,
                          P_sum_product)

    indexes = get_indexes_for_samples(P, P_cumsum,
                                      P_cumsum_exclusive,
                                      shifted_samples)

    weights = get_weights_for_samples(P, P_sum_product, B, indexes,
                                      dtype=torch.float32)
    print("weights = ", weights)


def _test_sample_combined():
    for N in [2, 3]:
        K = 4
        M = 8

        p = torch.randn(2, N, M).log_softmax(dim=-1)

        print("N = ", N, ", test_combined2: p = ", p.exp())
        weights, indexes = sample_combined_forward(p, K, True)
        print("test_combined2: p = ", p.exp())
        print("weights = ", weights)
        print("indexes = ", indexes)

        print("test_combined2: p (2nd time) = ", p.exp())
        p = p.detach()
        p.requires_grad = True
        weights, indexes = sample_combined(p, K, True)
        print("weights2 = ", weights)
        print("indexes2 = ", indexes)

        weights.sum().backward()
        print("p grad = ", p.grad)


def _test_sample_combined_mean():
    for N in [2, 3]:
        K = 4
        M = 8

        p = torch.randn(2, N, M).log_softmax(dim=-1)

        avg_p = torch.zeros_like(p)
        num_samples = 1000
        for _ in range(num_samples):
            # weights: (B, K)
            # indexes: (B, K, N)
            weights, indexes = sample_combined_forward(p, K, True)

            sampled_p = torch.zeros_like(p)
            weights_expanded = weights.unsqueeze(-2).expand(*weights.shape[:-1], N, K)
            sampled_p.scatter_add_(dim=-1, index=indexes.transpose(-2, -1),
                                   src=weights_expanded)
            avg_p += sampled_p * (1.0/num_samples)
        print("sample_combined_mean(): N = ", N, ", p = ", p.exp())
        print("avg_p = ", avg_p)


def _test_knowledge_base_lookup():
    K = 16
    N = 2
    M = 128
    D = 256
    E = 384

    knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
    m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)

    B = 30
    T = 4
    x = torch.randn(B, T, E)
    x.requires_grad = True
    y = m(x)
    assert y.shape == x.shape
    y.sum().backward()  # make sure backward doesn't crash..
    print("y = ", y)
    print("x.grad = ", x.grad)
    print("knowledge_base.grad norm = ", knowledge_base.grad.norm())

    device = torch.device('cuda')
    train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(11) ]
    from optim import Eve
    optimizer = Eve(m.parameters(), lr=0.005)
    m = m.to(device)

    for epoch in range(100):
        for n, (x, y) in enumerate(train_pairs):
            y_out = m(x)
            loss = ((y_out - y)**2).mean()
            if n % 10 == 0:
                print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


if __name__ == '__main__':
    _test_sample_combined()
    _test_sample_combined_mean()
    _test_combined()
    _test_compute_beta()
    _test_soft_sample()
    _test_knowledge_base_lookup()
    # test_normalizer()