fix style

This commit is contained in:
yaozengwei 2022-07-17 21:40:29 +08:00
parent fd261eca3a
commit 3cedbe3678
2 changed files with 13 additions and 13 deletions

View File

@ -15,7 +15,7 @@
# limitations under the License.
import copy
from typing import Optional, Tuple
from typing import Tuple
import torch
from encoder_interface import EncoderInterface

View File

@ -19,10 +19,10 @@ import collections
from itertools import repeat
from typing import Optional, Tuple
from torch import Tensor, _VF
import torch
import torch.backends.cudnn.rnn as rnn
import torch.nn as nn
from torch import _VF, Tensor
def _ntuple(n):
@ -155,7 +155,7 @@ class BasicNorm(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x**2, dim=self.channel_dim, keepdim=True)
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+ self.eps.exp()
) ** -0.5
return x * scales
@ -208,12 +208,12 @@ class ScaledLinear(nn.Linear):
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3**0.5) * std
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in**-0.5 # 1/sqrt(fan_in)
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
@ -257,12 +257,12 @@ class ScaledConv1d(nn.Conv1d):
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3**0.5) * std
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in**-0.5 # 1/sqrt(fan_in)
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
@ -326,12 +326,12 @@ class ScaledConv2d(nn.Conv2d):
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3**0.5) * std
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in**-0.5 # 1/sqrt(fan_in)
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
@ -408,9 +408,9 @@ class ScaledLSTM(nn.LSTM):
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3**0.5) * std
a = (3 ** 0.5) * std
fan_in = self.input_size
scale = fan_in**-0.5
scale = fan_in ** -0.5
v = scale / std
for idx, name in enumerate(self._flat_weights_names):
if "weight" in name:
@ -864,8 +864,8 @@ def _test_basic_norm():
y = m(x)
assert y.shape == x.shape
x_rms = (x**2).mean().sqrt()
y_rms = (y**2).mean().sqrt()
x_rms = (x ** 2).mean().sqrt()
y_rms = (y ** 2).mean().sqrt()
print("x rms = ", x_rms)
print("y rms = ", y_rms)
assert y_rms < x_rms