Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-10 10:32:17 +00:00
Modify ActivationBalancer for speed (#612)

* add a probability to apply ActivationBalancer
* minor fix
* minor fix

parent 1c07d2fb37
commit aa58c2ee02
@@ -12,7 +12,6 @@ cd egs/librispeech/ASR
 repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
 git lfs install
-git clone $repo
 
 log "Downloading pre-trained model from $repo_url"
 git clone $repo_url
 
@@ -932,7 +932,7 @@ class RelPositionMultiheadAttention(nn.Module):
         value: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -1059,7 +1059,7 @@ class RelPositionMultiheadAttention(nn.Module):
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
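Both hunks above flip the default of need_weights from True to False, in the public forward() signature and in the inner attention helper, so callers that do not explicitly request attention weights now skip the head-averaging step. A minimal sketch of the pattern this enables; finish_attention and its arguments are illustrative stand-ins for the tail of the real helper, not code from the diff:

from typing import Optional, Tuple

from torch import Tensor


def finish_attention(
    attn_output: Tensor,   # (tgt_len, bsz, embed_dim)
    attn_weights: Tensor,  # (bsz * num_heads, tgt_len, src_len)
    bsz: int,
    num_heads: int,
    need_weights: bool = False,  # the new default in this commit
) -> Tuple[Tensor, Optional[Tensor]]:
    # Only materialize head-averaged attention weights when a caller asks.
    if need_weights:
        tgt_len, src_len = attn_weights.shape[1], attn_weights.shape[2]
        avg = attn_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, avg.sum(dim=1) / num_heads
    # Default path: skip the view/sum entirely, saving work on every call.
    return attn_output, None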
@@ -16,6 +16,7 @@
 
 
 import collections
+import random
 from itertools import repeat
 from typing import Optional, Tuple
 
@@ -636,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value per channel, which
           we allow, before we start to modify the derivatives to prevent
           this.
+       balance_prob: the probability to apply the ActivationBalancer.
     """
 
     def __init__(
@@ -646,6 +648,7 @@ class ActivationBalancer(torch.nn.Module):
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        balance_prob: float = 0.25,
     ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
@@ -654,9 +657,11 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert 0 < balance_prob <= 1, balance_prob
+        self.balance_prob = balance_prob
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or is_jit_tracing():
+        if random.random() >= self.balance_prob:
             return x
         else:
             return ActivationBalancerFunction.apply(
@@ -664,7 +669,7 @@ class ActivationBalancer(torch.nn.Module):
                 self.channel_dim,
                 self.min_positive,
                 self.max_positive,
-                self.max_factor,
+                self.max_factor / self.balance_prob,
                 self.min_abs,
                 self.max_abs,
             )
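This is where the speed-up comes from: the balancer, which adds a custom autograd node on every forward call, now runs on only a balance_prob fraction of calls, and max_factor is divided by balance_prob so the expected correction strength is unchanged: E[factor] = balance_prob * (max_factor / balance_prob) = max_factor. A self-contained sketch of the pattern; _GradNudge and StochasticBalancer are toy stand-ins for ActivationBalancerFunction and ActivationBalancer, not icefall code:

import random

import torch
from torch import Tensor


class _GradNudge(torch.autograd.Function):
    # Toy stand-in: identity in forward, shrinks gradients by `factor`
    # in backward, as a placeholder for the real balancing logic.
    @staticmethod
    def forward(ctx, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad: Tensor):
        return grad * (1.0 - ctx.factor), None


class StochasticBalancer(torch.nn.Module):
    # Apply the gradient tweak on only `balance_prob` of calls, scaling
    # its strength by 1 / balance_prob so the average effect is the same.
    def __init__(self, max_factor: float = 0.01, balance_prob: float = 0.25):
        super().__init__()
        assert 0 < balance_prob <= 1, balance_prob
        self.max_factor = max_factor
        self.balance_prob = balance_prob

    def forward(self, x: Tensor) -> Tensor:
        if random.random() >= self.balance_prob:
            return x  # most calls are now a plain no-op
        return _GradNudge.apply(x, self.max_factor / self.balance_prob)

The forward value is identical on both branches, so the randomness only affects which calls pay for the extra autograd bookkeeping, not what the network computes.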
@@ -30,6 +30,7 @@ from typing import List
 import torch
 import torch.nn as nn
 from scaling import (
+    ActivationBalancer,
     BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
@@ -294,6 +295,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
+        elif isinstance(m, ActivationBalancer):
+            d[name] = nn.Identity()
 
     for k, v in d.items():
         if "." in k:
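ActivationBalancer only alters gradients in the backward pass; its forward returns x unchanged, so at export time it can be swapped for nn.Identity() with no effect on outputs. A hedged sketch of the replacement idiom, mirroring the d[name] / '.' in k bookkeeping above; replace_with_identity is a hypothetical helper, and the real convert_scaled_to_non_scaled also converts the Scaled* layers:

import torch.nn as nn


def replace_with_identity(model: nn.Module, cls: type) -> nn.Module:
    # Collect every submodule of type `cls`, then splice in nn.Identity().
    d = {
        name: nn.Identity()
        for name, m in model.named_modules()
        if isinstance(m, cls)
    }
    for k, v in d.items():
        if "." in k:
            # Nested module: find the parent, then overwrite the child.
            parent_name, leaf = k.rsplit(".", 1)
            setattr(model.get_submodule(parent_name), leaf, v)
        else:
            setattr(model, k, v)
    return model

For example, replace_with_identity(model, ActivationBalancer) before tracing or scripting would drop the balancers from the exported graph.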