diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 00ec14408..eae046e67 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -464,6 +464,7 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
@@ -649,9 +650,12 @@ class RelPositionMultiheadAttention(nn.Module):
 
         scaling = float(head_dim) ** -0.5
 
+        def linear(x, w, b):
+            return self.in_balancer(nn.functional.linear(x, w, b))
+
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
+            q, k, v = linear(
                 query, in_proj_weight, in_proj_bias
             ).chunk(3, dim=-1)
 
@@ -664,7 +668,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = linear(query, _w, _b)
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -673,7 +677,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+            k, v = linear(key, _w, _b).chunk(2, dim=-1)
 
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -683,7 +687,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = linear(query, _w, _b)
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -692,7 +696,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
+            k = linear(key, _w, _b)
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -701,7 +705,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
+            v = linear(value, _w, _b)
 
         if attn_mask is not None:
             assert (
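
For context, the diff routes every q/k/v input projection of `RelPositionMultiheadAttention` through a single `ActivationBalancer(channel_dim=-1, max_abs=5.0)` via a local `linear()` helper. The sketch below shows that pattern in isolation; `ToyBalancer` and `ToySelfAttentionProj` are hypothetical names used only so the sketch runs without icefall's `scaling` module, and `ToyBalancer` is an identity stand-in for the real `ActivationBalancer` (which additionally shapes gradients during training).

```python
# Minimal sketch of the pattern introduced by the diff: the q/k/v input
# projections are wrapped by one shared balancer so the projected
# activations stay bounded (channel_dim=-1, max_abs=5.0, as in the diff).
import torch
import torch.nn as nn


class ToyBalancer(nn.Module):
    """Stand-in for icefall's ActivationBalancer: identity in the forward
    pass here; the real module also constrains activations via gradients."""

    def __init__(self, channel_dim: int = -1, max_abs: float = 5.0):
        super().__init__()
        self.channel_dim = channel_dim
        self.max_abs = max_abs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class ToySelfAttentionProj(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
        # Added by the diff: one balancer shared by all input projections.
        self.in_balancer = ToyBalancer(channel_dim=-1, max_abs=5.0)

    def forward(self, x: torch.Tensor):
        # Mirrors the local `linear()` helper in the diff: the balancer is
        # applied right after nn.functional.linear, before chunking into q/k/v.
        def linear(x, w, b):
            return self.in_balancer(nn.functional.linear(x, w, b))

        q, k, v = linear(x, self.in_proj.weight, self.in_proj.bias).chunk(3, dim=-1)
        return q, k, v


if __name__ == "__main__":
    proj = ToySelfAttentionProj(embed_dim=8)
    q, k, v = proj(torch.randn(4, 2, 8))  # (time, batch, embed_dim)
    print(q.shape, k.shape, v.shape)  # each: torch.Size([4, 2, 8])
```

Wrapping the shared `in_proj` output once, rather than balancing q, k, and v separately, keeps the change to a single module and a three-line helper, and it covers all three projection paths (self-attention, encoder-decoder attention, and the fully separate q/k/v case) that the diff touches.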