From f2a45eb38d1605c3400c9e9ef9b9eddd1126163f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 7 Feb 2022 19:02:47 +0800
Subject: [PATCH] Remove learnable offset, use relu instead.

See https://github.com/k2-fsa/icefall/pull/199
---
 .../ASR/transducer_stateless/conformer.py     | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 81d7708f9..d1a28ccd9 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -629,9 +629,11 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = (
+                nn.functional.linear(query, in_proj_weight, in_proj_bias)
+                .relu()
+                .chunk(3, dim=-1)
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -642,7 +644,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = nn.functional.linear(query, _w, _b).relu()
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -650,7 +652,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+            k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)
 
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -660,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = nn.functional.linear(query, _w, _b).relu()
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -669,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
+            k = nn.functional.linear(key, _w, _b).relu()
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -678,7 +680,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
+            v = nn.functional.linear(value, _w, _b).relu()
 
         if attn_mask is not None:
             assert (
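
Every hunk above applies the same change: the output of the in-projection linear layer is passed through relu() before being split into (or used as) q, k and v, so the projected queries, keys and values become non-negative; per the subject line, this relu takes the place of the previously learnable offset. The following standalone sketch shows the pattern for the self-attention branch; the tensor shapes are made up for illustration and only the projection/chunk pattern comes from the patch.

import torch
import torch.nn as nn

# Illustrative shapes only, not taken from the patch.
seq_len, batch, embed_dim = 10, 2, 512

query = torch.randn(seq_len, batch, embed_dim)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# Before this patch: plain linear projection, then split into q, k, v.
q_old, k_old, v_old = nn.functional.linear(
    query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)

# After this patch: clamp the projection at zero with relu() before
# chunking, so q, k, v are all non-negative.
q, k, v = (
    nn.functional.linear(query, in_proj_weight, in_proj_bias)
    .relu()
    .chunk(3, dim=-1)
)

assert (q >= 0).all() and (k >= 0).all() and (v >= 0).all()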