Modify ActivationBalancer for speed (#612)

* add a probability to apply ActivationBalancer

* minor fix

* minor fix
Zengwei Yao 2022-10-13 15:14:28 +08:00 committed by GitHub
parent 1c07d2fb37
commit aa58c2ee02
4 changed files with 12 additions and 5 deletions
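
What the diff below does, at a glance: the ActivationBalancer's gradient correction now runs only on a random fraction of training steps (balance_prob, default 0.25), and max_factor is divided by balance_prob to compensate. In expectation the correction per step is unchanged, since 0.25 x (0.01 / 0.25) = 0.01, the same as applying max_factor = 0.01 on every step, while roughly three out of four forward passes skip the balancer entirely. The attention layers also stop returning averaged attention weights by default, and model export replaces the balancer with nn.Identity.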

File 1 of 4:

@@ -12,7 +12,6 @@ cd egs/librispeech/ASR
 repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
 git lfs install
-git clone $repo
 log "Downloading pre-trained model from $repo_url"
 git clone $repo_url

File 2 of 4:

@@ -932,7 +932,7 @@ class RelPositionMultiheadAttention(nn.Module):
         value: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -1059,7 +1059,7 @@ class RelPositionMultiheadAttention(nn.Module):
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
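
Both signatures above flip the default of need_weights from True to False, so callers that never inspect the attention weights no longer pay for computing and averaging them. The flag plays the same role as in stock PyTorch attention; the sketch below uses torch.nn.MultiheadAttention purely as an illustration, not icefall's RelPositionMultiheadAttention.

import torch
import torch.nn as nn

# Stock-PyTorch illustration (an analogy, not the icefall class): with
# need_weights=False the layer returns (output, None) and skips
# materializing the averaged attention-weight matrix.
mha = nn.MultiheadAttention(embed_dim=256, num_heads=4)
x = torch.randn(10, 2, 256)  # (seq_len, batch, embed_dim)

out, attn = mha(x, x, x, need_weights=False)
print(out.shape, attn)  # torch.Size([10, 2, 256]) None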

File 3 of 4:

@@ -16,6 +16,7 @@
 import collections
+import random
 from itertools import repeat
 from typing import Optional, Tuple
@@ -636,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value per channel, which
          we allow, before we start to modify the derivatives to prevent
          this.
+       balance_prob: the probability to apply the ActivationBalancer.
     """

     def __init__(
@@ -646,6 +648,7 @@ class ActivationBalancer(torch.nn.Module):
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        balance_prob: float = 0.25,
     ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
@@ -654,9 +657,11 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert 0 < balance_prob <= 1, balance_prob
+        self.balance_prob = balance_prob

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or is_jit_tracing():
+        if random.random() >= self.balance_prob:
             return x
         else:
             return ActivationBalancerFunction.apply(
@@ -664,7 +669,7 @@ class ActivationBalancer(torch.nn.Module):
                 self.channel_dim,
                 self.min_positive,
                 self.max_positive,
-                self.max_factor,
+                self.max_factor / self.balance_prob,
                 self.min_abs,
                 self.max_abs,
             )
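
The forward pass above is where the speed-up comes from: on most steps the module is now a plain identity, and when it does fire, max_factor is divided by balance_prob so the expected correction matches the old always-on behaviour. Below is a minimal, self-contained sketch of that skip-with-compensation pattern; the class and parameter names are hypothetical, and a simple gradient scaling stands in for the real per-channel balancing logic.

import random

import torch
from torch import Tensor


class _ScaleGrad(torch.autograd.Function):
    # Identity in the forward pass; scales the incoming gradient in backward.
    # This is a stand-in for the real balancing logic, not icefall's code.
    @staticmethod
    def forward(ctx, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        return grad * ctx.factor, None


class StochasticBalancerSketch(torch.nn.Module):
    # Hypothetical module illustrating the pattern: apply a gradient
    # correction only with probability apply_prob, and divide its strength
    # by apply_prob so the expected correction per step equals applying it
    # on every step.
    def __init__(self, factor: float = 0.01, apply_prob: float = 0.25):
        super().__init__()
        assert 0 < apply_prob <= 1, apply_prob
        self.factor = factor
        self.apply_prob = apply_prob

    def forward(self, x: Tensor) -> Tensor:
        if random.random() >= self.apply_prob:
            return x  # skipped: no extra compute on this step
        return _ScaleGrad.apply(x, 1.0 + self.factor / self.apply_prob)

On roughly 1 - balance_prob of the steps this costs nothing beyond a single random draw, which is where the training-speed gain comes from.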

File 4 of 4:

@@ -30,6 +30,7 @@ from typing import List
 import torch
 import torch.nn as nn
 from scaling import (
+    ActivationBalancer,
     BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
@@ -294,6 +295,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
+        elif isinstance(m, ActivationBalancer):
+            d[name] = nn.Identity()

     for k, v in d.items():
         if "." in k: