Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
Modify ActivationBalancer for speed (#612)

* add a probability to apply ActivationBalancer
* minor fix
* minor fix

Parent: 1c07d2fb37
Commit: aa58c2ee02
@@ -12,7 +12,6 @@ cd egs/librispeech/ASR
 repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
 git lfs install
 git clone $repo

 log "Downloading pre-trained model from $repo_url"
 git clone $repo_url
@@ -932,7 +932,7 @@ class RelPositionMultiheadAttention(nn.Module):
         value: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -1059,7 +1059,7 @@ class RelPositionMultiheadAttention(nn.Module):
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
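The default for need_weights flips from True to False in both signatures, so callers that do not explicitly ask for attention maps no longer pay for computing them. A minimal sketch of the effect, using the standard torch.nn.MultiheadAttention rather than icefall's RelPositionMultiheadAttention:

```python
import torch

# Illustrative only: stock PyTorch attention, not icefall's relative-position variant.
attn = torch.nn.MultiheadAttention(embed_dim=256, num_heads=4)
q = k = v = torch.randn(10, 2, 256)  # (seq_len, batch, embed_dim)

# With need_weights=False the module skips returning the averaged attention map
# and gives back None in its place, avoiding that extra work per call.
out, weights = attn(q, k, v, need_weights=False)
print(out.shape, weights)  # torch.Size([10, 2, 256]) None
```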
@@ -16,6 +16,7 @@

 import collections
+import random
 from itertools import repeat
 from typing import Optional, Tuple
@@ -636,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value per channel, which
           we allow, before we start to modify the derivatives to prevent
           this.
+       balance_prob: the probability to apply the ActivationBalancer.
     """

     def __init__(
@@ -646,6 +648,7 @@ class ActivationBalancer(torch.nn.Module):
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        balance_prob: float = 0.25,
     ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
@@ -654,9 +657,11 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert 0 < balance_prob <= 1, balance_prob
+        self.balance_prob = balance_prob

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or is_jit_tracing():
+        if random.random() >= self.balance_prob:
             return x
         else:
             return ActivationBalancerFunction.apply(
@@ -664,7 +669,7 @@ class ActivationBalancer(torch.nn.Module):
                 self.channel_dim,
                 self.min_positive,
                 self.max_positive,
-                self.max_factor,
+                self.max_factor / self.balance_prob,
                 self.min_abs,
                 self.max_abs,
             )
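The core change: the balancer now fires only with probability balance_prob (default 0.25), and its correction strength is divided by balance_prob, so the expected correction per step stays roughly the same while ActivationBalancer does real work on only a fraction of the forward passes. A small sketch of that compensation argument (illustrative only, not icefall code):

```python
import random

def expected_correction(max_factor: float, balance_prob: float, trials: int = 100_000) -> float:
    """Apply a correction of strength max_factor / balance_prob with probability
    balance_prob and report the average strength per step."""
    total = 0.0
    for _ in range(trials):
        if random.random() < balance_prob:
            total += max_factor / balance_prob
    return total / trials

# Both print roughly 0.01: the thinned-out balancer matches the always-on one
# in expectation, but only runs on about 25% of the forward passes.
print(expected_correction(max_factor=0.01, balance_prob=1.0))
print(expected_correction(max_factor=0.01, balance_prob=0.25))
```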
@@ -30,6 +30,7 @@ from typing import List
 import torch
 import torch.nn as nn
 from scaling import (
+    ActivationBalancer,
     BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
@@ -294,6 +295,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
+        elif isinstance(m, ActivationBalancer):
+            d[name] = nn.Identity()

     for k, v in d.items():
         if "." in k:
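For export, ActivationBalancer is now swapped for nn.Identity: the module is a pure pass-through in the forward direction and only shapes gradients during training, so dropping it at conversion time removes per-call overhead without changing inference outputs. A minimal sketch of that replacement pattern (the generic traversal below is an illustration, not the convert_scaled_to_non_scaled implementation):

```python
import torch.nn as nn

def replace_with_identity(model: nn.Module, targets=(nn.Dropout,)) -> nn.Module:
    """Swap training-only modules (nn.Dropout here as a stand-in) for nn.Identity
    in place, mirroring how ActivationBalancer is stripped at export time."""
    for name, child in model.named_children():
        if isinstance(child, targets):
            setattr(model, name, nn.Identity())
        else:
            replace_with_identity(child, targets)
    return model

# Usage sketch: the converted model computes the same outputs in eval mode,
# just without the training-only wrapper modules.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1), nn.ReLU())
print(replace_with_identity(model))
```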