fix style

parent fd261eca3a
commit 3cedbe3678
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import copy
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -19,10 +19,10 @@ import collections
 from itertools import repeat
 from typing import Optional, Tuple
 
-from torch import Tensor, _VF
 import torch
 import torch.backends.cudnn.rnn as rnn
 import torch.nn as nn
+from torch import _VF, Tensor
 
 
 def _ntuple(n):
@@ -155,7 +155,7 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         scales = (
-            torch.mean(x**2, dim=self.channel_dim, keepdim=True)
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
         ) ** -0.5
         return x * scales
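For reference, the expression reformatted in this hunk is the whole of BasicNorm's forward computation: an RMS-style normalization over the channel dimension, with a learned epsilon stored in log space (hence self.eps.exp()). A minimal self-contained sketch of the same arithmetic; the tensor shape and the eps value are assumed for illustration:

import torch

x = torch.randn(4, 32)    # (frames, num_channels)
eps = torch.tensor(0.25)  # stands in for self.eps, a parameter learned in log space

scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
y = x * scales            # each frame's RMS is pulled toward 1
assert y.shape == x.shape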
@@ -208,12 +208,12 @@ class ScaledLinear(nn.Linear):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in**-0.5  # 1/sqrt(fan_in)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
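The initialization arithmetic in this hunk (repeated verbatim in the ScaledConv1d and ScaledConv2d hunks below) checks out numerically: a uniform draw on (-a, a) with a = sqrt(3) * std has standard deviation exactly std, and adding log(scale / std) to the log-space weight_scale rescales the effective weights to the usual 1/sqrt(fan_in). A quick verification sketch, not project code; fan_in is an assumed value:

import torch

std = 0.1                # initial_speed == 1.0
a = (3 ** 0.5) * std
w = torch.empty(100_000).uniform_(-a, a)
print(w.std())           # ~0.1, since Var(U(-a, a)) = a**2 / 3 = std**2

fan_in = 256
scale = fan_in ** -0.5   # 1/sqrt(fan_in)
corr = torch.tensor(scale / std).log().exp()
print((w * corr).std())  # ~scale: the log-space weight_scale absorbs the ratio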
@@ -257,12 +257,12 @@ class ScaledConv1d(nn.Conv1d):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in**-0.5  # 1/sqrt(fan_in)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
@@ -326,12 +326,12 @@ class ScaledConv2d(nn.Conv2d):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in**-0.5  # 1/sqrt(fan_in)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
@@ -408,9 +408,9 @@ class ScaledLSTM(nn.LSTM):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
         fan_in = self.input_size
-        scale = fan_in**-0.5
+        scale = fan_in ** -0.5
         v = scale / std
         for idx, name in enumerate(self._flat_weights_names):
             if "weight" in name:
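The loop at the end of this hunk iterates over nn.LSTM's internal flat-weight list and applies the rescaled init only to entries whose names contain "weight", leaving the biases alone. A small sketch of what that iteration sees; _flat_weights_names is private PyTorch API, and the layer sizes are assumed:

import torch.nn as nn

lstm = nn.LSTM(input_size=80, hidden_size=512)
for idx, name in enumerate(lstm._flat_weights_names):
    if "weight" in name:
        print(idx, name)  # weight_ih_l0, weight_hh_l0; bias_* entries are skipped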
@@ -864,8 +864,8 @@ def _test_basic_norm():
     y = m(x)
 
     assert y.shape == x.shape
-    x_rms = (x**2).mean().sqrt()
-    y_rms = (y**2).mean().sqrt()
+    x_rms = (x ** 2).mean().sqrt()
+    y_rms = (y ** 2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
     assert y_rms < x_rms
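A note on why the final assertion holds regardless of this reformatting: each frame of y is x divided by sqrt(mean(x ** 2) + exp(eps)), and exp(eps) > 0 makes that divisor strictly larger than the frame's RMS, so for the test's standard-normal input (RMS close to 1) the output RMS lands below the input RMS. In miniature, with eps assumed to be 0.25:

import torch

m, e = torch.tensor(1.0), torch.tensor(0.25).exp()  # m = per-frame mean square
print((m / (m + e)).sqrt())  # ~0.66 < 1.0, i.e. y_rms < x_rms at unit input RMS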