Fix formula for smoothing (was applying more smoothing than intended, and in the opposite sense to intended), also revert max_rms from 2.0 to 4.0

This commit is contained in:
Daniel Povey 2022-07-14 06:06:02 +08:00
parent 4785245e5c
commit 7f6fe02db9

View File

@ -953,13 +953,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
batch_size = rms.shape[0]
size = rms.numel() // batch_size
rms = rms ** param_pow
smooth = (smooth0 +
(smooth1 - smooth0) * size / (size + rank))
# want expr to be of the form: smooth = alpha * size / (beta*rank + size)
# from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
# from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta),
# so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
mean = _mean(rms, exclude_dims=[0], keepdim=True)
rms += eps + smooth * mean
new_mean = (eps + (smooth + 1) * mean) # mean of modified rms.
ans = rms / new_mean
if True:
# Apply max_rms
max_rms = 4.0