Modify ActivationBalancer for speed (#612)

* add a probability to apply ActivationBalancer

* minor fix

* minor fix
Author: Zengwei Yao
Date:   2022-10-13 15:14:28 +08:00 (committed by GitHub)
Parent: 1c07d2fb37
Commit: aa58c2ee02
4 changed files with 12 additions and 5 deletions


@@ -12,7 +12,6 @@ cd egs/librispeech/ASR
 repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
 git lfs install
-git clone $repo
 log "Downloading pre-trained model from $repo_url"
 git clone $repo_url


@@ -932,7 +932,7 @@ class RelPositionMultiheadAttention(nn.Module):
         value: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -1059,7 +1059,7 @@ class RelPositionMultiheadAttention(nn.Module):
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
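Both hunks above flip the default of need_weights from True to False, so the attention layers no longer build and average the per-head attention-weight matrix on every forward call; callers that actually want the weights can still pass need_weights=True explicitly. The effect mirrors stock PyTorch attention. A small illustration using plain torch.nn.MultiheadAttention (not icefall's RelPositionMultiheadAttention):

import torch
import torch.nn as nn

# With need_weights=False the module skips averaging/returning the attention
# weights and hands back None in their place, saving memory and compute.
mha = nn.MultiheadAttention(embed_dim=256, num_heads=4)
q = k = v = torch.randn(10, 2, 256)  # (seq_len, batch, embed_dim)

out, weights = mha(q, k, v, need_weights=True)
print(weights.shape)  # torch.Size([2, 10, 10]): weights averaged over heads

out, weights = mha(q, k, v, need_weights=False)
print(weights)  # None: the weights are never assembled for the caller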


@@ -16,6 +16,7 @@
 import collections
+import random
 from itertools import repeat
 from typing import Optional, Tuple
@@ -636,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value per channel, which
           we allow, before we start to modify the derivatives to prevent
           this.
+       balance_prob: the probability to apply the ActivationBalancer.
     """

     def __init__(
@@ -646,6 +648,7 @@ class ActivationBalancer(torch.nn.Module):
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        balance_prob: float = 0.25,
     ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
@@ -654,9 +657,11 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert 0 < balance_prob <= 1, balance_prob
+        self.balance_prob = balance_prob

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or is_jit_tracing():
+        if random.random() >= self.balance_prob:
             return x
         else:
             return ActivationBalancerFunction.apply(
@@ -664,7 +669,7 @@ class ActivationBalancer(torch.nn.Module):
                 self.channel_dim,
                 self.min_positive,
                 self.max_positive,
-                self.max_factor,
+                self.max_factor / self.balance_prob,
                 self.min_abs,
                 self.max_abs,
             )
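Taken together, the scaling.py hunks make the balancer stochastic: each call applies ActivationBalancerFunction only with probability balance_prob (default 0.25), and when it is applied the max_factor is divided by balance_prob, so the expected correction per step stays at max_factor (0.25 * (0.01 / 0.25) = 0.01) while roughly three quarters of calls skip the extra autograd work entirely. Note that the old torch.jit.is_scripting() early return is replaced by the random draw. A minimal sketch of the pattern with hypothetical names (not the real ActivationBalancer):

import random
import torch

class StochasticGradShaper(torch.nn.Module):
    # Hypothetical illustration of the trick used above: run a gradient-shaping
    # op only with probability `prob`, and divide its strength by `prob` so that
    # E[strength] = prob * (max_factor / prob) = max_factor, i.e. the average
    # correction matches the always-on version at roughly `prob` of the cost.
    def __init__(self, op, max_factor: float = 0.01, prob: float = 0.25):
        super().__init__()
        assert 0 < prob <= 1, prob
        self.op = op
        self.max_factor = max_factor
        self.prob = prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if random.random() >= self.prob:
            return x  # skipped this step: no extra forward/backward work
        return self.op(x, self.max_factor / self.prob)  # compensated strength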


@@ -30,6 +30,7 @@ from typing import List
 import torch
 import torch.nn as nn
 from scaling import (
+    ActivationBalancer,
     BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
@@ -294,6 +295,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
+        elif isinstance(m, ActivationBalancer):
+            d[name] = nn.Identity()

     for k, v in d.items():
         if "." in k: