parent 51377387a5
commit cb59d0ee5f
@@ -26,32 +26,6 @@ from encoder_interface import EncoderInterface
from icefall.utils import add_sos


class AdapterHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use L2 as a loss.
    '''

    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Hook to compute DeepInversion's feature distribution regularization.
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)

        # Force the mean and variance to match between the two distributions;
        # other measures might work better, e.g. KL divergence.
        r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
            module.running_mean.data - mean, 2)

        self.r_feature = r_feature
        # Must have no return value, so the module's output is left unchanged.

    def close(self):
        self.hook.remove()

class Transducer(nn.Module):
    """It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"
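Below is a minimal usage sketch, not part of the commit, showing how AdapterHook might be attached. It assumes the hooks are registered on BatchNorm2d modules, whose 4-D inputs and running statistics match the [0, 2, 3] reduction in hook_fn; the toy model, input shape, and loss aggregation are illustrative assumptions only.

import torch
import torch.nn as nn

# Hypothetical model; in practice the hooks would sit on the normalization
# layers of the acoustic encoder.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
)

# Register one AdapterHook per BatchNorm2d layer.
hooks = [AdapterHook(m) for m in model.modules() if isinstance(m, nn.BatchNorm2d)]

x = torch.randn(4, 3, 32, 32)
_ = model(x)  # the forward pass fills in hook.r_feature for every hook

# Sum the per-layer statistics losses into one regularization term,
# to be weighted against the task loss during adaptation.
r_loss = sum(h.r_feature for h in hooks)
r_loss.backward()

# Remove the hooks once they are no longer needed.
for h in hooks:
    h.close()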