mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-11 11:02:29 +00:00)

Remove xscale from pos_embedding

commit c82db4184a
parent 6561743d7b
@@ -449,7 +449,7 @@ class ScaledLinear(nn.Linear):
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)

     def get_weight(self):
         return self.weight * (self.weight_scale * self.scale_speed).exp()
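For context, here is a minimal self-contained sketch of how this weight_scale initialization fits together. The constructor signature, the normal_ init scheme, the scale_speed default, and the forward/bias handling are assumptions for illustration, not the repository's exact code; the point is that after this change the correction term divides by whatever std the weights were actually initialized with, instead of the hard-coded 0.05.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ScaledLinear(nn.Linear):
        # Sketch: effective weight = weight * exp(weight_scale * scale_speed).
        def __init__(self, *args, scale_speed: float = 10.0,
                     std: float = 0.05, **kwargs):
            super().__init__(*args, **kwargs)
            self.scale_speed = scale_speed
            self.weight_scale = nn.Parameter(torch.zeros(()))
            self._reset_parameters(std)

        def _reset_parameters(self, std: float):
            nn.init.normal_(self.weight, mean=0.0, std=std)
            fan_in = self.weight.shape[1] * self.weight[0][0].numel()
            scale = fan_in ** -0.5  # 1/sqrt(fan_in)
            with torch.no_grad():
                # After this commit: divide by the actual init std, not 0.05,
                # so exp(weight_scale * scale_speed) * std == 1/sqrt(fan_in)
                # at initialization.
                self.weight_scale += (torch.tensor(scale / std).log()
                                      / self.scale_speed)

        def get_weight(self):
            return self.weight * (self.weight_scale * self.scale_speed).exp()

        def forward(self, x):
            return F.linear(x, self.get_weight(), self.bias)

With, say, ScaledLinear(256, 512), get_weight() at init has standard deviation about 1/sqrt(256) regardless of which std was used to draw the raw weights.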
@@ -485,7 +485,7 @@ class ScaledConv1d(nn.Conv1d):
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)


     def get_weight(self):
@@ -527,7 +527,7 @@ class ScaledConv2d(nn.Conv2d):
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
-            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+            self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)


     def get_weight(self):
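The same one-line change lands in ScaledConv1d and ScaledConv2d because all three classes share the fan_in formula above: for a conv weight of shape (out_ch, in_ch, *kernel), weight.shape[1] * weight[0][0].numel() is in_channels times the number of kernel elements. A quick check with hypothetical layer sizes:

    import torch.nn as nn

    # Hypothetical shapes, just to verify the shared fan_in formula.
    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 5))
    fan_in = conv.weight.shape[1] * conv.weight[0][0].numel()
    assert fan_in == 3 * 3 * 5  # in_channels * kernel_h * kernel_w == 45
    print(fan_in ** -0.5)       # the 1/sqrt(fan_in) scale used in the hunks above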
@@ -327,7 +327,6 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
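What the removed line computed: xscale is sqrt(d_model), a constant gain that forward() used to apply to the input. For a typical (hypothetical) model dimension the factor is sizable, which is presumably why dropping it pairs naturally with the learnable scales in the Scaled* layers above:

    import math

    d_model = 512  # hypothetical model dimension
    xscale = math.sqrt(d_model)
    print(xscale)  # ~22.6: the fixed input gain this commit removes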
@@ -379,7 +378,6 @@ class RelPositionalEncoding(torch.nn.Module):

         """
         self.extend_pe(x)
-        x = x * self.xscale
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
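Putting this hunk in context, here is a standalone sketch of the forward logic after the change. The slice bounds past the hunk's cut-off and the two-value return follow the usual ESPnet-style RelPositionalEncoding and are assumptions here, not quoted from the commit:

    import torch

    def rel_pos_forward(x: torch.Tensor, pe: torch.Tensor,
                        dropout: torch.nn.Dropout):
        """Sketch of RelPositionalEncoding.forward after this commit.
        pe is assumed to hold 2*max_len-1 positions centered at index
        pe.size(1) // 2, as produced by an ESPnet-style extend_pe.
        The x = x * self.xscale multiply is gone; x passes through unscaled."""
        T = x.size(1)
        # Slice out relative positions -(T-1) .. +(T-1) around the center.
        pos_emb = pe[:, pe.size(1) // 2 - T + 1 : pe.size(1) // 2 + T]
        return dropout(x), dropout(pos_emb)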