diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 81d7708f9..f803ee9b6 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -440,8 +440,19 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+
+        self.in_proj_floor_scale = 10.0  # so it learns fast enough.
+        with torch.no_grad():
+            in_proj_floor = torch.Tensor(3 * embed_dim)
+            # key and query get a floor value quite close to zero.
+            in_proj_floor[:2 * embed_dim] = -0.2 / self.in_proj_floor_scale
+            # value gets a very low floor, may be close to having no effect.
+            in_proj_floor[2 * embed_dim:] = -1.5 / self.in_proj_floor_scale
+        self.in_proj_floor = nn.Parameter(in_proj_floor)
+
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
 
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
@@ -526,6 +537,7 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask,
+            in_proj_floor=self.in_proj_floor * self.in_proj_floor_scale,
         )
 
     def rel_shift(self, x: Tensor) -> Tensor:
@@ -570,6 +582,7 @@ class RelPositionMultiheadAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
+        in_proj_floor: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -629,9 +642,12 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
+            _qkv = nn.functional.linear(
                 query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            )
+            if in_proj_floor is not None:
+                _qkv = torch.maximum(_qkv, in_proj_floor)
+            q, k, v = _qkv.chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -643,6 +659,10 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                q = torch.maximum(q, _f)
+
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -650,7 +670,11 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+            _kv = nn.functional.linear(key, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                _kv = torch.maximum(_kv, _f)
+            k, v = _kv.chunk(2, dim=-1)
 
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -661,6 +685,10 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                q = torch.maximum(q, _f)
+
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -670,6 +698,9 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
             k = nn.functional.linear(key, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                k = torch.maximum(k, _f)
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -679,6 +710,10 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                v = torch.maximum(v, _f)
+
 
         if attn_mask is not None:
             assert (
@@ -918,3 +953,13 @@ class Swish(torch.nn.Module):
 
 def identity(x):
     return x
+
+
+if __name__ == '__main__':
+    feature_dim = 50
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
+    f = c(torch.randn(batch_size, seq_len, feature_dim),
+          torch.full((batch_size,), seq_len, dtype=torch.int64))
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index c2c6552a9..003b03a2e 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -82,6 +82,7 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
+        y = y.to(torch.int64)
         embedding_out = self.embedding(y)
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
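For reference, a minimal standalone sketch of the learnable-floor idea this patch applies to the in_proj output: the floor is stored pre-divided by a scale factor (so it effectively learns faster under a shared learning rate) and applied elementwise with torch.maximum. The FlooredLinear name, dimensions, and init values below are illustrative only, not part of the patch.

import torch
import torch.nn as nn


class FlooredLinear(nn.Module):
    # Illustrative sketch: a linear layer whose output is clamped from below by a
    # learnable floor.  The floor is stored divided by floor_scale and multiplied
    # back when applied, mirroring in_proj_floor / in_proj_floor_scale above.
    def __init__(self, in_dim: int, out_dim: int,
                 floor_init: float = -0.2, floor_scale: float = 10.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=True)
        self.floor_scale = floor_scale
        self.floor = nn.Parameter(
            torch.full((out_dim,), floor_init / floor_scale)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        # Elementwise lower bound; gradients reach self.floor wherever it is active.
        return torch.maximum(y, self.floor * self.floor_scale)


if __name__ == "__main__":
    m = FlooredLinear(16, 32)
    print(m(torch.randn(4, 16)).shape)  # torch.Size([4, 32])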