diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 47f36dcf3..c727da341 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -32,7 +32,6 @@ class Conformer(Transformer): self, num_features: int, num_classes: int, - subsampling_factor: int = 4, d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, diff --git a/egs/librispeech/ASR/conformer_lm/transformer.py b/egs/librispeech/ASR/conformer_lm/transformer.py index 707eacd1b..4367808a8 100644 --- a/egs/librispeech/ASR/conformer_lm/transformer.py +++ b/egs/librispeech/ASR/conformer_lm/transformer.py @@ -12,10 +12,9 @@ from torch.nn.utils.rnn import pad_sequence Supervisions = Dict[str, torch.Tensor] -class Transformer(nn.Module): +class MaskedLmConformer(nn.Module): def __init__( self, - num_features: int, num_classes: int, d_model: int = 256, nhead: int = 4, @@ -23,17 +22,13 @@ class Transformer(nn.Module): num_encoder_layers: int = 12, num_decoder_layers: int = 6, dropout: float = 0.1, - normalize_before: bool = True, + cnn_module_kernel: int = 31, ) -> None: """ Args: - num_features: - The input dimension of the model. num_classes: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. + The input and output dimension of the model (inputs and outputs are + both discrete) d_model: Attention dimension. nhead: @@ -47,76 +42,45 @@ class Transformer(nn.Module): Number of decoder layers. dropout: Dropout in encoder/decoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - """ - super().__init__() + """ + super(MaskedLmConformer, self).__init__() - - self.num_features = num_features self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - # self.encoder_embed converts the input of shape [N, T, num_classes] - # to the shape [N, T//subsampling_factor, d_model]. - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model - - - #self.encoder_embed = [TODO...] - - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, + # self.embed is the embedding used for both the encoder and decoder. 
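+        # Note on scaling: the embedding weights are initialized as
+        # randn() * (1 / embed_scale) and the embedding output is multiplied by
+        # embed_scale == sqrt(d_model) in forward(), so the values fed to the
+        # network start out with roughly unit scale while the stored weights stay small.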
+        self.embed_scale = d_model ** 0.5
+        self.embed = nn.Embedding(
+            num_embeddings=self.num_classes, embedding_dim=d_model,
+            _weight=torch.randn(self.num_classes, d_model) * (1 / self.embed_scale)
         )
-        if normalize_before:
-            encoder_norm = nn.LayerNorm(d_model)
-        else:
-            encoder_norm = None
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

-        self.encoder = nn.TransformerEncoder(
-            encoder_layer=encoder_layer,
-            num_layers=num_encoder_layers,
-            norm=encoder_norm,
-        )
-
-        # TODO(fangjun): remove dropout
-        self.encoder_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        encoder_layer = MaskedLmConformerEncoderLayer(
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            cnn_module_kernel,
         )
+        self.encoder = MaskedLmConformerEncoder(encoder_layer, num_encoder_layers,
+                                                norm=nn.LayerNorm(d_model))

         if num_decoder_layers > 0:
             self.decoder_num_class = self.num_classes

-            self.decoder_embed = nn.Embedding(
-                num_embeddings=self.decoder_num_class, embedding_dim=d_model
-            )
-            self.decoder_pos = PositionalEncoding(d_model, dropout)
-
-            decoder_layer = TransformerDecoderLayer(
+            decoder_layer = TransformerDecoderLayerRelPos(
                 d_model=d_model,
                 nhead=nhead,
                 dim_feedforward=dim_feedforward,
                 dropout=dropout,
-                normalize_before=normalize_before,
             )

-            if normalize_before:
-                decoder_norm = nn.LayerNorm(d_model)
-            else:
-                decoder_norm = None
+            # Projects the embedding of `src`, to be added to `memory`
+            self.src_linear = torch.nn.Linear(d_model, d_model)

-            self.decoder = nn.TransformerDecoder(
+            decoder_norm = nn.LayerNorm(d_model)
+            self.decoder = TransformerDecoderRelPos(
                 decoder_layer=decoder_layer,
                 num_layers=num_decoder_layers,
                 norm=decoder_norm,
@@ -126,363 +90,198 @@ class Transformer(nn.Module):
                 d_model, self.decoder_num_class
             )

-            self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
-        else:
-            self.decoder_criterion = None

     def forward(
         self,
-        src_symbols: torch.Tensor,
-        src_padding_mask: torch.Tensor = None
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        masked_src_symbols: torch.Tensor,
+        key_padding_mask: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
-          src_symbols:
+          masked_src_symbols:
             The input symbols to be embedded (will actually have query positions
             masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64.
             I.e. shape (N, T)
-          src_padding_mask:
+          key_padding_mask:
             Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T), and
             dtype=torch.bool which has True in positions to be masked in attention layers
             and convolutions because they represent padding at the ends of sequences.
-          supervision:
-            Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-            (CAUTION: It contains length information, i.e., start and number of
-             frames, before subsampling)

         Returns:
-            Return a tuple containing 3 tensors:
-              - CTC output for ctc decoding. Its shape is [N, T, C]
-              - Encoder output with shape [T, N, C]. It can be used as key and
-                value for the decoder.
-              - Encoder output padding mask. It can be used as
-                memory_key_padding_mask for the decoder. Its shape is [N, T].
-                It is None if `supervision` is None.
+          Returns (encoded, pos_emb), where:
+             `encoded` is a Tensor containing the encoded data; it is of shape (T, N, C)
+             where C is the embedding_dim.
+ `pos_emb` is a Tensor containing the relative positional encoding, of + shape (1, 2*T-1, C) """ - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - - def run_encoder( - self, x: torch.Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is [N, T, C]. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute the encoder padding mask, which is used as memory key - padding mask for the decoder. - Returns: - Return a tuple with two tensors: - - The encoder output, with shape [T, N, C] - - encoder padding mask, with shape [N, T]. - The mask is None if `supervisions` is None. - It is used as memory key padding mask in the decoder. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) + x = self.embed(masked_src_symbols) * self.embed_scale # (N, T, C) + x, pos_emb = self.encoder_pos(x) # pos_emb: (1, 2*T-1, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - return x, mask + x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) - def ctc_output(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The output tensor from the transformer encoder. - Its shape is [T, N, C] - - Returns: - Return a tensor that can be used for CTC decoding. - Its shape is [N, T, C] - """ - x = self.encoder_output_layer(x) - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) - return x - - def decoder_forward( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape [T, N, C] - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id - - Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. - """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) - - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. 
- tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss + return x, pos_emb def decoder_nll( self, memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, + pos_emb: torch.Tensor, + src_symbols: torch.Tensor, + tgt_symbols: torch.Tensor, + key_padding_mask: torch.Tensor ) -> torch.Tensor: """ Args: memory: - It's the output of the encoder with shape [T, N, C] - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. + The output of the encoder, with shape (T, N, C) + pos_emb: + Relative positional embedding, of shape (1, 2*T-1, C), as + returned from the encoder + src_symbols: + The un-masked src symbols, a LongTensor of shape (N, T). + Can be used to predict the target + only in a left-to-right manner (otherwise it's cheating). + tgt_symbols: + Target symbols, a LongTensor of shape (N, T). + The same as src_symbols, but shifted by one (and also, + without symbol randomization, see randomize_proportion + in dataloader) + key_padding_mask: + A BoolTensor of shape (N, T), with True for positions + that correspond to padding at the end of source and + memory sequences. The same mask is used for self-attention + and cross-attention, since the padding is the same. + Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). + Returns a tensor of shape (N, T), containing the negative + log-probabilities for the target symbols at each position + in the target sequence. """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. + (T, N, C) = memory.shape - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + tgt_mask = generate_square_subsequent_mask(T, memory.device) - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + src = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) + src = src.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + src = memory + self.src_linear(src) # (T, N, C) - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. 
-        tgt_key_padding_mask[:, 0] = False
-
-        tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
-        tgt = self.decoder_pos(tgt)
-        tgt = tgt.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
-        pred_pad = self.decoder(
-            tgt=tgt,
+        # This is a little confusing, how "tgt" is set to src.  "src" is the
+        # symbol sequence without masking but with padding and randomization.
+        # "tgt" is like "src" but shifted by one.
+        pred = self.decoder(
+            tgt=src,
             memory=memory,
             tgt_mask=tgt_mask,
-            tgt_key_padding_mask=tgt_key_padding_mask,
-            memory_key_padding_mask=memory_key_padding_mask,
-        )  # (T, B, F)
-        pred_pad = pred_pad.permute(1, 0, 2)  # (T, B, F) -> (B, T, F)
-        pred_pad = self.decoder_output_layer(pred_pad)  # (B, T, F)
+            tgt_key_padding_mask=key_padding_mask,
+            memory_key_padding_mask=key_padding_mask,
+        )  # (T, N, C)
+
+        pred = pred.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred = self.decoder_output_layer(pred)  # (N, T, C)

+        # nll: negative log-likelihood
         nll = torch.nn.functional.cross_entropy(
-            pred_pad.view(-1, self.decoder_num_class),
-            ys_out_pad.view(-1),
-            ignore_index=-1,
+            pred.view(-1, self.decoder_num_class),
+            tgt_symbols.view(-1),
             reduction="none",
         )
-
-        nll = nll.view(pred_pad.shape[0], -1)
-
+        nll = nll.view(N, T)
         return nll


-class TransformerEncoderLayer(nn.Module):
-    """
-    Modified from torch.nn.TransformerEncoderLayer.
-    Add support of normalize_before,
-    i.e., use layer_norm before the first block.
+
+
+class TransformerDecoderRelPos(Module):
+    r"""TransformerDecoderRelPos is a stack of N decoder layers.
+    This is modified from nn.TransformerDecoder to support relative positional
+    encoding.

     Args:
-      d_model:
-        the number of expected features in the input (required).
-      nhead:
-        the number of heads in the multiheadattention models (required).
-      dim_feedforward:
-        the dimension of the feedforward network model (default=2048).
-      dropout:
-        the dropout value (default=0.1).
-      activation:
-        the activation function of intermediate layer, relu or
-        gelu (default=relu).
-      normalize_before:
-        whether to use layer_norm before the first block.
+        decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required).
+        num_layers: the number of sub-decoder-layers in the decoder (required).
+        norm: the layer normalization component (optional).
Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -class TransformerDecoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). 
-
-    Examples::
-        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> decoder_layer = TransformerDecoderLayerRelPos(d_model=512, nhead=8)
+        >>> transformer_decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6)
         >>> memory = torch.rand(10, 32, 512)
         >>> tgt = torch.rand(20, 32, 512)
-        >>> out = decoder_layer(tgt, memory)
+        >>> pos_emb = torch.rand(1, 2*20-1, 512)
+        >>> out = transformer_decoder(tgt, pos_emb, memory)
+    """
+    __constants__ = ['norm']
+
+    def __init__(self, decoder_layer, num_layers, norm=None):
+        super(TransformerDecoderRelPos, self).__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(self, x: Tensor,
+                pos_emb: Tensor,
+                memory: Tensor,
+                attn_mask: Optional[Tensor] = None,
+                key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        r"""Pass the inputs (and mask) through the decoder layer in turn.
+
+        Args:
+            x: the input embedding sequence to the decoder (required): shape = (T, N, C).
+               Will be an embedding of `src_symbols` in practice.
+            pos_emb:
+               A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with C==num_channels,
+               representing the relative positional encoding.
+            memory: the sequence from the last layer of the encoder (required):
+               shape = (T, N, C)
+            attn_mask: the mask for the `x` sequence's attention to itself,
+               of shape (T, T); in practice, will ensure that no
+               position can attend to later positions.  A torch.Tensor with dtype=torch.float
+               or dtype=torch.bool.
+            key_padding_mask: the key-padding mask for both the memory and x sequences,
+               a torch.Tensor with dtype=bool and shape (N, T): true for masked
+               positions after the ends of sequences.
+        """
+
+        for mod in self.layers:
+            x = mod(x, pos_emb, memory, x_mask=attn_mask,
+                    key_padding_mask=key_padding_mask)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
+
+
+class TransformerDecoderLayerRelPos(nn.Module):
+    """
+    Modified from torch.nn.TransformerDecoderLayer.
+    Add it to use normalize_before (hardcoded to True), i.e. use layer_norm before the first block;
+    to use relative positional encoding; and for some changes/simplifications in interface
+    because both sequences are the same length and have the same mask.
+
+    Args:
+      d_model:
+        the number of expected features in the input (required).
+      nhead:
+        the number of heads in the multiheadattention models (required).
+      dim_feedforward:
+        the dimension of the feedforward network model (default=2048).
+      dropout:
+        the dropout value (default=0.1).
+      activation:
+        the activation function of intermediate layer, relu or
+        gelu (default=relu).
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> pos_emb = torch.rand(1, 20*2+1, 512) + >>> out = decoder_layer(tgt, pos_emb, memory) """ def __init__( @@ -492,11 +291,10 @@ class TransformerDecoderLayer(nn.Module): dim_feedforward: int = 2048, dropout: float = 0.1, activation: str = "relu", - normalize_before: bool = True, ) -> None: super(TransformerDecoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) @@ -511,7 +309,6 @@ class TransformerDecoderLayer(nn.Module): self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before def __setstate__(self, state): if "activation" not in state: @@ -520,75 +317,57 @@ class TransformerDecoderLayer(nn.Module): def forward( self, - tgt: torch.Tensor, + x: torch.Tensor, + pos_emb: torch.Tensor, memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, + x_mask: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. Args: - tgt: - the sequence to the decoder layer (required). + x + The input embedding, to be added to by the forward function, of shape (T, N, C). + Attention within x will be left-to-right only (causal), thanks to x_mask. + pos_emb: + A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, + containing the relative positional encoding. memory: - the sequence from the last layer of the encoder (required). - tgt_mask: - the mask for the tgt sequence (optional). - memory_mask: - the mask for the memory sequence (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch (optional). - memory_key_padding_mask: - the mask for the memory keys per batch (optional). + the sequence from the last layer of the encoder (required). Shape = (T, N, C) + x_mask: + the mask for the x, to enforce causal (left to right) attention (optional). + Shape == (T, T); may be bool or float. The first T pertains to the output, + the second T to the input. + key_padding_mask: + the key-padding mask to use for both the x and memory sequences. Shep == (N, T); + may be bool (True==masked) or float (to be added to attention scores). - Shape: - tgt: (T, N, E). - memory: (S, N, E). - tgt_mask: (T, T). - memory_mask: (T, S). - tgt_key_padding_mask: (N, T). - memory_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number + Returns: + Returns 'x plus something', a torch.Tensor with dtype the same as x (e.g. float), + and shape (T, N, C). 
""" - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - tgt2 = self.self_attn( - tgt, - tgt, - tgt, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, + residual = x + x = self.norm1(x) + self_attn = self.self_attn(x, x, x, + key_padding_mask=key_padding_mask, + need_weights=False + attn_mask=x_mask, )[0] - tgt = residual + self.dropout1(tgt2) - if not self.normalize_before: - tgt = self.norm1(tgt) + x = residual + self.dropout1(self_attn) - residual = tgt - if self.normalize_before: - tgt = self.norm2(tgt) - tgt2 = self.src_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, + residual = x + x = self.norm2(x) + src_attn = self.src_attn(x, memory, memory, + key_padding_mask=key_padding_mask, + need_weights=False, )[0] - tgt = residual + self.dropout2(tgt2) - if not self.normalize_before: - tgt = self.norm2(tgt) + x = residual + self.dropout2(src_attn) - residual = tgt - if self.normalize_before: - tgt = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = residual + self.dropout3(tgt2) - if not self.normalize_before: - tgt = self.norm3(tgt) - return tgt + residual = x + x = self.norm3(x) + ff = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = residual + self.dropout3(ff) + return x def _get_activation_fn(activation: str): @@ -831,62 +610,6 @@ class LabelSmoothingLoss(nn.Module): return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom -def encoder_padding_mask( - max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[torch.Tensor]: - """Make mask tensor containing indexes of padded part. - - TODO:: - This function **assumes** that the model uses - a subsampling factor of 4. We should remove that - assumption later. - - Args: - max_len: - Maximum length of input features. - CAUTION: It is the length after subsampling. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. 
- """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask def decoder_padding_mask( @@ -911,7 +634,7 @@ def decoder_padding_mask( return ys_mask -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: +def generate_square_subsequent_mask(sz: int, device: torch.Device) -> torch.Tensor: """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). The mask can be used for masked self-attention. @@ -928,7 +651,7 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor: Returns: A square mask of dimension (sz, sz) """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = (torch.triu(torch.ones(sz, sz, device=torch.Device)) == 1).transpose(0, 1) mask = ( mask.float() .masked_fill(mask == 0, float("-inf")) @@ -975,3 +698,804 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: for utt in token_ids: ans.append(utt + [eos_id]) return ans + + + +class MaskedConvolutionModule(nn.Module): + """ + This is used in the MaskedLmConformerLayer. It is the same as the ConvolutionModule + of theConformer code, but with key_padding_mask supported to make the output independent + of the batching. + + Modified, ultimately, from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct a MaskedConvolutionModule object.""" + super(MaskedConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (T, N, C) == (#time, batch, channels). 
+ key_padding_mask: if supplied, a Tensor with dtype=torch.Bool and + shape (N, T), with True for positions that correspond to + padding (and should be zeroed in convolutions). + + Returns: + Tensor: Output tensor (T, N, C) + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # Logical-not key_padding_mask, unsqueeze to shape (N, 1, T) and convert + # to float. Then we can just multiply by it when we need to apply + # masking, i.e. prior to the convolution over time. + if key_padding_mask is not None: + x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(dtype=x.dtype) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) # (time, batch, channel) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + + +class MaskedLmConformerEncoderLayer(nn.Module): + """ + MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution + networks. It's a simplified version of the conformer code we were previously + using, with pre-normalization hard-coded, relative positional encoding, + LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask + applied also in the convolution layers. + + See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for + the basic conformer. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
+
+    Examples::
+        >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = encoder_layer(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        cnn_module_kernel: int = 31,
+    ) -> None:
+        super(MaskedLmConformerEncoderLayer, self).__init__()
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
+
+        self.feed_forward = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+
+        self.conv_module = MaskedConvolutionModule(d_model, cnn_module_kernel)
+
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+
+        self.ff_scale = 0.5
+
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+            x: the sequence to the encoder layer (required).
+            pos_emb: Positional embedding tensor (required).
+            attn_mask: the mask for the x sequence's attention to itself (optional);
+                of shape (T, T)
+            key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            x: (T, N, C) i.e. (seq_len, batch_size, num_channels)
+            pos_emb: (N, 2*T-1, C)
+            attn_mask: (T, T) or (N*num_heads, T, T), of dtype torch.bool or torch.float, where
+               the 1st T is interpreted as the target sequence (output) and the 2nd as the source
+               sequence (input).
+            key_padding_mask: (N, T), of dtype torch.bool
+
+            T is the sequence length, N is the batch size, C is the number of channels.
+        Return:
+            Returns x with something added to it, of shape (T, N, C)
+        """
+
+        # macaron style feed forward module
+        residual = x
+        x = self.norm_ff_macaron(x)
+        x = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(x)
+        )
+
+        # multi-headed self-attention module
+        residual = x
+        x = self.norm_mha(x)
+        self_attn = self.self_attn(x, x, x,
+                                   pos_emb=pos_emb,
+                                   attn_mask=attn_mask,
+                                   key_padding_mask=key_padding_mask,
+                                   need_weights=False
+                                   )[0]
+        x = residual + self.dropout(self_attn)

+        # convolution module
+        residual = x
+        x = self.norm_conv(x)
+
+        x = residual + self.dropout(self.conv_module(x, key_padding_mask=key_padding_mask))
+
+        # feed forward module
+        residual = x
+        x = self.norm_ff(x)
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+
+        x = self.norm_final(x)
+
+        return x
+
+
+def _get_clones(module, N):
+    return ModuleList([copy.deepcopy(module) for i in range(N)])
+
+class MaskedLmConformerEncoder(Module):
+    r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from
+    torch.nn.TransformerEncoder
+
+    Args:
+        encoder_layer: an instance of the MaskedLmConformerEncoderLayer() class (required).
+        num_layers: the number of sub-encoder-layers in the encoder (required).
+        norm: the layer normalization component (optional).
+
+    Examples::
+        >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> src, pos_emb = self.encoder_pos(src)
+        >>> out = conformer_encoder(src, pos_emb)
+    """
+    __constants__ = ['norm']
+
+    def __init__(self, encoder_layer: nn.Module, num_layers: int,
+                 norm: Optional[nn.Module] = None):
+        super(MaskedLmConformerEncoder, self).__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+
+    def forward(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+        Args
+            x: input of shape (T, N, C), i.e. (seq_len, batch, channels)
+            pos_emb: positional embedding tensor of shape (N, 2*T-1, C),
+            attn_mask (optional, likely not used): mask for self-attention of
+                x to itself, of shape (T, T)
+            key_padding_mask (optional): mask of shape (N, T), dtype must be bool.
+        Returns:
+            Returns a tensor with the same shape as x, i.e. (T, N, C).
+        """
+        for mod in self.layers:
+            x = mod(
+                x, pos_emb,
+                attn_mask=attn_mask,
+                key_padding_mask=key_padding_mask,
+            )
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
+        """Construct a PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means the position of the query vector and `j` means the
+        # position of the key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (1, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: if true, return (output, attn_output_weights); else, (output, None). + + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, C is + the embedding dimension (number of channels). + - key: :math:`(S, N, C)`, where S is the input sequence length. + - value: :math:`(S, N, C)` + - pos_emb: :math:`(N, 2*T-1, C)`. Note: this assumes T == S, which it will be, but + still we use different letters because S relates to the input position, T to the + output posision. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the input sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(T, S)` where T is the output sequence length, S is the input sequence length. 
+ 3D mask :math:`(N*num_heads, T, S)` where N is the batch size, where T is the output sequence length, + S is the input sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Return: + (output, attn_output_weights) if need_weights==True, else (output, None), where: + + - output: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, + C is the embedding/channel dimension. + - attn_output_weights: :math:`(N, T, S)` where N is the batch size, + T is the output sequence length, S is the input sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
+ + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + #if not self.is_espnet_structure: + # q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + #if not self.is_espnet_structure: + # attn_output_weights = ( + # matrix_ac + matrix_bd + # ) # (batch, head, time1, time2) + #else: + + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None