mirror of https://github.com/k2-fsa/icefall.git
fix style

commit 3cedbe3678
parent fd261eca3a
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import copy
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -19,10 +19,10 @@ import collections
 from itertools import repeat
 from typing import Optional, Tuple
 
-from torch import Tensor, _VF
 import torch
 import torch.backends.cudnn.rnn as rnn
 import torch.nn as nn
+from torch import _VF, Tensor
 
 
 def _ntuple(n):
@@ -155,7 +155,7 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         scales = (
-            torch.mean(x**2, dim=self.channel_dim, keepdim=True)
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
         ) ** -0.5
         return x * scales
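Note: for context, the expression being reformatted here is BasicNorm's RMS-style normalizer. As the hunk shows, there is no mean subtraction; only the RMS over the channel dimension is rescaled, with a learnable epsilon stored in log space. A minimal standalone sketch of the same computation (the input shape is hypothetical):

    import torch

    # Sketch only: mimics the scales computed in BasicNorm.forward above.
    x = torch.randn(2, 100, 512)   # hypothetical (batch, time, channels)
    channel_dim = -1
    eps = torch.tensor(-2.0)       # kept in log space, like self.eps above

    scales = (
        torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps.exp()
    ) ** -0.5
    y = x * scales                 # per-frame RMS over channels is now ~1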
@@ -208,12 +208,12 @@ class ScaledLinear(nn.Linear):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in**-0.5  # 1/sqrt(fan_in)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
@@ -257,12 +257,12 @@ class ScaledConv1d(nn.Conv1d):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in**-0.5  # 1/sqrt(fan_in)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
@@ -326,12 +326,12 @@ class ScaledConv2d(nn.Conv2d):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
        nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in**-0.5  # 1/sqrt(fan_in)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
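Note: the three hunks above touch the identical `_reset_parameters` pattern in ScaledLinear, ScaledConv1d, and ScaledConv2d. A hedged sketch of what that arithmetic does, using plain tensors rather than the real modules: the weights start with a deliberately small std of 0.1 / initial_speed, and the log-space `weight_scale` is then shifted so the effective std matches the usual 1/sqrt(fan_in).

    import torch

    # Sketch, not the icefall implementation: standalone version of the
    # _reset_parameters arithmetic reformatted in the hunks above.
    initial_speed = 1.0
    weight = torch.empty(256, 512)   # hypothetical (out_features, in_features)
    weight_scale = torch.zeros(())   # log-space scale, exp(0) == 1 initially

    std = 0.1 / initial_speed
    a = (3 ** 0.5) * std             # uniform(-a, a) has standard deviation std
    torch.nn.init.uniform_(weight, -a, a)

    fan_in = weight.shape[1]         # Linear case; convs also multiply kernel size
    scale = fan_in ** -0.5           # 1/sqrt(fan_in), Glorot-style magnitude
    with torch.no_grad():
        weight_scale += torch.tensor(scale / std).log()

    # The module's forward would then use weight * weight_scale.exp(), whose
    # std is std * (scale / std) == scale, i.e. 1/sqrt(fan_in).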
@@ -408,9 +408,9 @@ class ScaledLSTM(nn.LSTM):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3**0.5) * std
+        a = (3 ** 0.5) * std
         fan_in = self.input_size
-        scale = fan_in**-0.5
+        scale = fan_in ** -0.5
         v = scale / std
         for idx, name in enumerate(self._flat_weights_names):
             if "weight" in name:
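Note: ScaledLSTM's variant of the same init (here fan_in is just input_size) walks nn.LSTM's flat parameter list by name. The hunk is truncated before the body of the `if "weight" in name:` branch, so only the lookup itself is illustrated below, as a quick sketch with a stock nn.LSTM:

    import torch.nn as nn

    # nn.LSTM stores its parameters as flat tensors whose names live in
    # _flat_weights_names (an internal but real attribute); this is what
    # the loop in the hunk above enumerates.
    lstm = nn.LSTM(input_size=8, hidden_size=4)
    print(lstm._flat_weights_names)
    # ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']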
@@ -864,8 +864,8 @@ def _test_basic_norm():
     y = m(x)
 
     assert y.shape == x.shape
-    x_rms = (x**2).mean().sqrt()
-    y_rms = (y**2).mean().sqrt()
+    x_rms = (x ** 2).mean().sqrt()
+    y_rms = (y ** 2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
     assert y_rms < x_rms
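Note: the reason `_test_basic_norm` can assert `y_rms < x_rms` is that BasicNorm divides by sqrt(mean(x ** 2) + exp(eps)), which is strictly larger than the plain RMS, so the output RMS always lands just below 1. A deterministic mini-check of that inequality (sketch only; a constant input is chosen so the assertion provably holds):

    import torch

    # With exp(eps) > 0 the scale factor is strictly below 1/rms(x), so
    # rms(y) < 1 <= rms(x) whenever rms(x) >= 1.
    x = torch.full((1000,), 2.0)   # rms(x) == 2 by construction
    eps = torch.tensor(-5.0)       # exp(eps) is small but positive
    y = x * (torch.mean(x ** 2) + eps.exp()) ** -0.5
    assert (y ** 2).mean().sqrt() < (x ** 2).mean().sqrt()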