Implement subtracted momentum [0.33,0.66], and print name in Whiten module.
This commit is contained in:
parent 1a2632d0a2
commit 6c5763fbb3
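
The "subtracted momentum" change replaces each residual update of the form src = src + x in ZipformerEncoderLayer with an update that also subtracts a scaled running average of earlier submodule outputs. A minimal, self-contained sketch of that update rule (the names add_to_src, momentum_alpha and momentum_rate mirror the diff below; the toy tensors and loop are purely illustrative):

import torch

momentum_alpha = 0.66   # EMA decay; the commit title also mentions 0.33
# -0.5 sets "how strong" the negative momentum is (see the comment in the diff)
momentum_rate = -0.5 * (1.0 / (1 - momentum_alpha))


def add_to_src(src, momentum, x):
    # add the new submodule output x, subtract a scaled running
    # average of the outputs of earlier submodules
    src = src + x + momentum_rate * momentum
    momentum = momentum * momentum_alpha + x
    return src, momentum


src, momentum = torch.zeros(4), 0.0
for _ in range(6):              # pretend each iteration is one submodule
    x = torch.ones(4)           # stand-in for a submodule's output
    src, momentum = add_to_src(src, momentum, x)
print(src)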
@@ -561,11 +561,13 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 x: Tensor,
                 num_groups: int,
                 whitening_limit: float,
-                grad_scale: float) -> Tensor:
+                grad_scale: float,
+                name: Optional[str]) -> Tensor:
         ctx.save_for_backward(x)
         ctx.num_groups = num_groups
         ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
+        ctx.name = name
         return x

     @staticmethod
@@ -580,7 +582,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 metric = _whitening_metric(x_detached, ctx.num_groups)

                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+                    logging.info(f"Whitening: name={ctx.name}, num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                                  f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")

                 (metric - ctx.whitening_limit).relu().backward()
@@ -588,7 +590,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
                                           (penalty_grad.norm() + 1.0e-20))
                 penalty_grad = penalty_grad * scale
-        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None



@@ -630,7 +632,7 @@ class Whiten(nn.Module):
         (self.min_prob, self.max_prob) = prob
         assert 0 < self.min_prob < self.max_prob <= 1
         self.prob = self.max_prob
-
+        self.name = None  # will be set in training loop
         self.grad_scale = grad_scale

     def forward(self,
@@ -666,7 +668,8 @@ class Whiten(nn.Module):
             return WhiteningPenaltyFunction.apply(x,
                                                   self.num_groups,
                                                   self.whitening_limit,
-                                                  self.grad_scale)
+                                                  self.grad_scale,
+                                                  self.name)


 class WithLoss(torch.autograd.Function):
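
Since WhiteningPenaltyFunction is a custom torch.autograd.Function, its backward() must return one value per forward() argument, which is why adding the name argument also adds a trailing None above. A minimal sketch of that pattern (this toy class and the name string are made up, not the icefall code):

import torch
from typing import Optional


class PassThroughWithName(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, name: Optional[str]) -> torch.Tensor:
        ctx.name = name          # stash the name so backward() can log it
        return x

    @staticmethod
    def backward(ctx, x_grad: torch.Tensor):
        print(f"backward of {ctx.name}")
        return x_grad, None      # one entry per forward() argument


x = torch.randn(3, requires_grad=True)
PassThroughWithName.apply(x, "encoder.layers.0.whiten").sum().backward()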
@@ -98,8 +98,8 @@ def set_batch_count(
     for name, module in model.named_modules():
         if hasattr(module, 'batch_count'):
             module.batch_count = batch_count
-            if hasattr(module, 'name'):
-                module.name = name
+        if hasattr(module, 'name'):
+            module.name = name


 def add_model_arguments(parser: argparse.ArgumentParser):
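
set_batch_count() relies on nn.Module.named_modules() to hand each submodule its fully qualified path, so the Whiten modules above know what name to log. A tiny illustrative example (the Named class is a hypothetical stand-in, not the real Whiten):

import torch.nn as nn


class Named(nn.Module):
    # stand-in for a module that wants to know its own path, like Whiten
    def __init__(self):
        super().__init__()
        self.name = None  # will be set in training loop


model = nn.Sequential(nn.Linear(4, 4), Named())
for name, module in model.named_modules():
    if hasattr(module, 'name'):
        module.name = name
print(model[1].name)  # -> "1"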
@@ -464,6 +464,12 @@ class ZipformerEncoderLayer(nn.Module):

         src_orig = src

+        momentum_alpha = 0.66
+        # the -0.5 below is "how strong" to make the negative momentum.
+        momentum_rate = -0.5 * (1.0 / (1 - momentum_alpha))
+        momentum = 0.0
+
+
         # dropout rate for non-feedforward submodules
         dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
         # multi-headed self-attention module
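
A quick, illustrative check of what the -0.5 factor means: each later submodule subtracts momentum_rate * momentum (via the add_to_src helper in the next hunk), and momentum decays by momentum_alpha per step, so the total weight removed from any single output is momentum_rate * (1 + momentum_alpha + momentum_alpha**2 + ...) = momentum_rate / (1 - momentum_alpha) = -0.5; in the limit each submodule output keeps weight 0.5 in the residual stream:

momentum_alpha = 0.66
momentum_rate = -0.5 * (1.0 / (1 - momentum_alpha))

total = 1.0      # the output is first added with weight 1 ...
weight = 1.0     # ... and enters the running average with weight 1
for _ in range(50):                  # later submodules subtract a decayed copy
    total += momentum_rate * weight
    weight *= momentum_alpha
print(total)     # -> approaches 1 - 0.5 = 0.5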
@@ -478,30 +484,40 @@ class ZipformerEncoderLayer(nn.Module):
                 key_padding_mask=src_key_padding_mask,
             )

+        def add_to_src(src, momentum, x):
+            src = src + x + momentum_rate * momentum
+            momentum = (momentum * momentum_alpha) + x
+            return src, momentum
+
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.nonlin_attention_module(src,
-                                                     attn_weights[0:1])
+            src, momentum = add_to_src(src, momentum,
+                                       self.nonlin_attention_module(src, attn_weights[0:1]))

-        src = src + self.feed_forward1(src)
+        src, momentum = add_to_src(src, momentum,
+                                   self.feed_forward1(src))

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, attn_weights[1:2])
+            src, momentum = add_to_src(src, momentum,
+                                       self.attention_squeeze1(src, attn_weights[1:2]))

         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn(
-                src, attn_weights)
+            src, momentum = add_to_src(src, momentum, self.self_attn(
+                src, attn_weights))

         if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
-            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+            src, momentum = add_to_src(src, momentum,
+                                       self.conv_module(src, src_key_padding_mask=src_key_padding_mask))

-        src = src + self.feed_forward2(src)
+        src, momentum = add_to_src(src, momentum,
+                                   self.feed_forward2(src))

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, attn_weights[2:3])
+            src, momentum = add_to_src(src, momentum,
+                                       self.attention_squeeze2(src, attn_weights[2:3]))

         src = self.norm_final(self.balancer(src))