Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-16 12:42:20 +00:00)
Some updates to tests; still figuring out issues.
commit da3c9c7594 (parent 461cb7da6d)
@@ -676,6 +676,23 @@ class ReverseGrad(torch.autograd.Function):
     def backward(ctx, x_grad):
         return -x_grad
 
+
+class DebugGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, name):
+        ctx.save_for_backward(x)
+        ctx.name = name
+        return x
+    @staticmethod
+    def backward(ctx, x_grad):
+        x, = ctx.saved_tensors
+        x_grad_sum = x_grad.sum().to('cpu').item()
+        x_grad_x_sum = (x_grad * x).sum().to('cpu').item()
+        print(f"For {ctx.name}, x_grad_sum = {x_grad_sum}, x_grad_x_sum = {x_grad_x_sum}")
+        return x_grad, None
+
+
+
 class DiscreteBottleneck(nn.Module):
     """
     This layer forces its input through an information bottleneck via
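
Note: the new DebugGrad is an identity in the forward pass that prints gradient statistics on the way back. A minimal usage sketch (standalone and hypothetical, not part of this commit; it assumes the DebugGrad defined above is in scope):

    import torch

    x = torch.randn(4, 8, requires_grad=True)
    y = DebugGrad.apply(x, "x")    # forward pass: y is just x
    loss = (y ** 2).sum()
    loss.backward()                # backward prints x_grad_sum and x_grad_x_sum for "x"
    # The gradient itself is passed through unchanged, so x.grad still equals 2 * x.
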
@@ -804,7 +821,7 @@ class DiscreteBottleneck(nn.Module):
         This is unnecessary if straight_through_scale == 1.0, since in that
         case it would not affect the backpropagated derivatives.
         """
-        x = self.norm_in(x) * 5
+        x = self.norm_in(x) * 5  # * 5 gives lower entropy.. starts training faster..
         x = self.linear1(x)
         x = x + self.class_offsets
 
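
Note: the "* 5 gives lower entropy" comment refers to the fact that scaling logits up makes the downstream softmax sharper. A toy illustration of that effect (standalone sketch; the tensor and sizes are made up, not taken from this file):

    import torch

    logits = torch.randn(16)

    def entropy(p):
        return -(p * (p + 1.0e-20).log()).sum().item()

    print(entropy(logits.softmax(dim=0)))        # typically somewhat below log(16) ~= 2.77 nats
    print(entropy((logits * 5).softmax(dim=0)))  # noticeably lower: the distribution is more peaked
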
@@ -815,7 +832,7 @@ class DiscreteBottleneck(nn.Module):
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
-        if random.random() < 0.01:
+        if random.random() < 0.001:
             softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
             logsoftmax_temp = (softmax_temp + 1.0e-20).log()
             negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
@@ -824,7 +841,7 @@ class DiscreteBottleneck(nn.Module):
             global_log_softmax = (global_softmax + 1.0e-20).log()
             global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
 
-            print("Entropy = ", -negentropy, ", averaged negentropy = ", global_negentropy)
+            print("Entropy = ", -negentropy.to('cpu').item(), ", averaged entropy = ", -global_negentropy.to('cpu').item())
 
 
 
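
Note on the diagnostic print above: negentropy measures the average per-frame sharpness of each group's softmax, while global_negentropy appears to measure the entropy of a frame-averaged distribution (its exact definition sits outside this hunk). A standalone sketch of that reading, with illustrative shapes and names:

    import torch

    S, N, num_groups, classes_per_group = 300, 10, 8, 32   # illustrative sizes
    softmax = torch.randn(S, N, num_groups, classes_per_group).softmax(dim=3)

    # Average per-frame negentropy (how peaked each frame's distribution is).
    negentropy = (softmax * (softmax + 1.0e-20).log()).sum(-1).mean()

    # Entropy of the distribution averaged over frames; if this is much larger than the
    # per-frame entropy, different frames are using different classes.
    global_softmax = softmax.mean(dim=(0, 1))
    global_negentropy = (global_softmax * (global_softmax + 1.0e-20).log()).sum(-1).mean()

    print("Entropy = ", -negentropy.item(), ", averaged entropy = ", -global_negentropy.item())
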
@@ -845,6 +862,10 @@ class DiscreteBottleneck(nn.Module):
             self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
 
         embedding = self.norm_out(self.linear2(x))
 
+        #if random.random() < 0.01:
+        #    return (embedding, DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"))
+        #else:
         return (embedding, sampled, softmax)
 
+
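
Note: the class_offsets update above (a context line, not changed by this commit) nudges up the logit offset of any class whose running probability has fallen below prob_floor, so rarely-used classes can recover. A toy sketch of that mechanism with made-up values:

    import torch

    prob_floor, prob_boost = 0.01, 1.0e-03        # illustrative values
    class_offsets = torch.zeros(8)                # added to the logits elsewhere
    class_probs = torch.tensor([0.30, 0.25, 0.20, 0.15, 0.05, 0.04, 0.005, 0.005])

    class_offsets += (class_probs < prob_floor) * prob_boost
    print(class_offsets)   # only the last two (under-used) classes receive a +0.001 boost
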
@@ -885,7 +906,7 @@ class DiscreteBottleneck(nn.Module):
             sampled = ReverseGrad.apply(sampled)
         if softmax is None:
             softmax = sampled
-        else:
+        elif reverse_grad:
             softmax = ReverseGrad.apply(softmax)
 
         logprobs = self.pred_linear(x)
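
Note: the else -> elif reverse_grad change means softmax only has its gradient reversed when the caller asks for it. For reference, ReverseGrad (whose backward appears in the first hunk) is a gradient-reversal layer; a minimal standalone version of the idea looks like this (sketch only, the file's actual definition may differ in details):

    import torch

    class ReverseGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x                 # identity in the forward pass

        @staticmethod
        def backward(ctx, x_grad):
            return -x_grad           # negate the gradient on the way back

    x = torch.randn(3, requires_grad=True)
    y = ReverseGrad.apply(x)
    y.sum().backward()
    print(x.grad)                    # tensor([-1., -1., -1.]): the gradient was reversed
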
@@ -1825,11 +1846,11 @@ def _test_bidirectional_conformer():
 
 def _test_discrete_bottleneck():
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-    dim = 128
+    dim = 256
     tot_classes = 256
-    num_groups = 1
-    interp_prob = 0.8
-    straight_through_scale = 1.0 # will change
+    num_groups = 8
+    interp_prob = 1.0
+    straight_through_scale = 0.0 # will change
     need_softmax = True
 
     b = DiscreteBottleneck(dim, tot_classes, num_groups,
@@ -1843,29 +1864,42 @@ def _test_discrete_bottleneck():
     model.train()
 
     optim = torch.optim.Adam(params=model.parameters(),
-                             lr=1.0e-03)
+                             lr=3.0e-04)
 
 
+    scale = 0.7  # determines the feature correlation.. should be between 0 and 1.
+    # https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
+    # scale corresponds to rho^2, rho being sqrt(scale).
+    mutual_information = dim * -0.5 * math.log(1.0 - scale)
+    print("mutual_information = ", mutual_information)
+
     for epoch in range(10):
         torch.save(model.state_dict(), f'epoch-{epoch}.pt')
-        for i in range(1000):
+        for i in range(2000):
             # TODO: also test padding_mask
             T = 300
             N = 10
 
             feats = torch.randn(T, N, dim, device=device)
 
+            feats2 = (feats * scale ** 0.5) + ((1.0 - scale) ** 0.5 * torch.randn(T, N, dim, device=device))
+
+            #print(f"norm(feats) ={feats.norm()} vs. norm(feats2) = {feats2.norm()}")
 
             bn_memory, sampled, softmax = b(feats)
 
-            predictor = from_feats_predictor(feats) # Could also use `bn_memory`, perhaps. But using
-            # predictor because it contains everything, will give
-            # us a bound on max information..
+            # using feats2 instead of feats will limit the mutual information
+            # to the MI between feats and feats2, which we computed and printed
+            # above as mutual_information.
+            predictor = from_feats_predictor(feats2)
             prob = b.compute_prob(predictor, sampled, softmax)
 
             if True:
                 sampled_reversed = ReverseGrad.apply(sampled)
                 predictor_reversed = self_predictor(sampled_reversed)
-                predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                        predictor_reversed[:-1,:,:]), dim=0)
+                #predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                #                                        predictor_reversed[:-1,:,:]), dim=0)
+                predictor_reversed_shifted = predictor_reversed
 
                 self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
                                            reverse_grad=True)
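
Note on the new test setup: feats2 is a noisy copy of feats with per-dimension correlation rho = sqrt(scale), so the printed mutual_information is the Gaussian MI bound dim * -0.5 * log(1 - scale) from the linked Wikipedia formula. A quick standalone check of the numbers used in this diff (dim = 256, scale = 0.7; only a sketch, not part of the commit):

    import math
    import torch

    dim, scale = 256, 0.7
    T, N = 300, 10

    feats = torch.randn(T, N, dim)
    feats2 = (feats * scale ** 0.5) + ((1.0 - scale) ** 0.5 * torch.randn(T, N, dim))

    # Per-dimension MI of two unit-variance Gaussians with correlation rho = sqrt(scale)
    # is -0.5 * log(1 - rho^2) nats; summed over dim independent dimensions.
    mutual_information = dim * -0.5 * math.log(1.0 - scale)
    print(mutual_information)   # ~154.1 nats: roughly the most the predictor can learn about feats from feats2
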