diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index f803ee9b6..c06335905 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -440,19 +440,8 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-
-        self.in_proj_floor_scale = 10.0 # so it learns fast enough..
-        with torch.no_grad():
-            in_proj_floor = torch.Tensor(3 * embed_dim)
-            # key and query get a floor value quite close to zero.
-            in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale
-            # value gets very low floor, may be close to having no effectc.
-            in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale
-            self.in_proj_floor = nn.Parameter(in_proj_floor)
-
-
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
 
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
@@ -537,7 +526,6 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask,
-            in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale
         )
 
     def rel_shift(self, x: Tensor) -> Tensor:
@@ -582,7 +570,6 @@ class RelPositionMultiheadAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
-        in_proj_floor: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -642,12 +629,7 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            _qkv = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            )
-            if in_proj_floor is not None:
-                _qkv = torch.maximum(_qkv, in_proj_floor)
-            q, k, v = _qkv.chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -658,10 +640,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                q = torch.maximum(q, _f)
+            q = nn.functional.linear(query, _w, _b).relu()
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -670,11 +649,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            _kv = nn.functional.linear(key, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                _kv = torch.maximum(_kv, _f)
-            k, v = _kv.chunk(2, dim=-1)
+            k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)
 
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -684,10 +659,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                q = torch.maximum(q, _f)
+            q = nn.functional.linear(query, _w, _b).relu()
 
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -697,10 +669,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                k = torch.maximum(k, _f)
+            k = nn.functional.linear(key, _w, _b).relu()
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -709,10 +678,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                v = torch.maximum(v, _f)
+            v = nn.functional.linear(value, _w, _b).relu()
 
         if attn_mask is not None:
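
Note (a standalone sketch, not part of the patch): the removed code clamped the output of the shared q/k/v input projection from below with a learnable, scaled per-channel floor applied via torch.maximum, while the new code simply applies a ReLU to the projection, i.e. a fixed floor at zero and no extra parameters. The snippet below reproduces both behaviours in isolation; embed_dim, the toy input x and the stand-in name in_proj are illustrative assumptions, not values taken from the repository.

import torch
import torch.nn as nn

embed_dim = 4                                  # toy size, for illustration only
in_proj = nn.Linear(embed_dim, 3 * embed_dim)  # stand-in for self.in_proj
x = torch.randn(2, embed_dim)                  # hypothetical (batch, embed_dim) input

# Old behaviour (removed): learnable per-channel floor, scaled by 10 so it trains faster.
floor_scale = 10.0
floor = torch.empty(3 * embed_dim)
floor[: 2 * embed_dim] = -0.2 / floor_scale    # query/key floor close to zero
floor[2 * embed_dim :] = -1.5 / floor_scale    # value floor much lower
floor = nn.Parameter(floor)
q_old, k_old, v_old = torch.maximum(in_proj(x), floor * floor_scale).chunk(3, dim=-1)

# New behaviour (this patch): plain ReLU on the projection before splitting into q, k, v.
q_new, k_new, v_new = in_proj(x).relu().chunk(3, dim=-1)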