From 3cedbe367804dac5d30c49b14af00b2d0d49f0f1 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 17 Jul 2022 21:40:29 +0800 Subject: [PATCH] fix style --- .../ASR/lstm_transducer_stateless/lstm.py | 2 +- .../pruned_transducer_stateless2/scaling.py | 24 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index 781176ea9..47b2c7b2b 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -15,7 +15,7 @@ # limitations under the License. import copy -from typing import Optional, Tuple +from typing import Tuple import torch from encoder_interface import EncoderInterface diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 018e2827b..65c71ab2e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -19,10 +19,10 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch import Tensor, _VF import torch import torch.backends.cudnn.rnn as rnn import torch.nn as nn +from torch import _VF, Tensor def _ntuple(n): @@ -155,7 +155,7 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -208,12 +208,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -257,12 +257,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -326,12 +326,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -408,9 +408,9 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std fan_in = self.input_size - scale = fan_in**-0.5 + scale = fan_in ** -0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -864,8 +864,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms