This commit is contained in:
Daniel Povey 2021-09-18 11:34:35 +08:00
parent c6c3750cab
commit a75f75bbad
2 changed files with 167 additions and 238 deletions


@@ -16,19 +16,23 @@
# limitations under the License.
import copy
import random
import math
import warnings
from typing import Optional, Tuple, List
import torch_flow_sampling
import torch
from torch import Tensor, nn
from subsampling import Conv2dSubsampling, VggSubsampling
from transformer import Supervisions, TransformerEncoderLayer, TransformerDecoderLayer, encoder_padding_mask, \
LabelSmoothingLoss, PositionalEncoding, pad_sequence, add_sos, add_eos, decoder_padding_mask, \
generate_square_subsequent_mask
class ConformerTrunk(nn.Module):
def __init__(self,
num_features: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
@@ -37,6 +41,7 @@ class ConformerTrunk(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
use_feat_batchnorm: bool = True) -> None:
super(ConformerTrunk, self).__init__()
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
@@ -62,12 +67,11 @@ class ConformerTrunk(nn.Module):
cnn_module_kernel,
)
self.encoder = ConformerEncoder(encoder_layer, num_layers)
def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
@@ -78,12 +82,10 @@ class ConformerTrunk(nn.Module):
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
Returns (output, pos_emb, mask), where:
output: The output embedding, of shape (T, N, C).
pos_emb: The positional embedding (this will be used by ctc_encoder_forward()).
mask: The output padding mask, a Tensor of bool, of shape [N, T].
It is None if `supervision` is None.
"""
if hasattr(self, 'feat_batchnorm'):
@@ -92,13 +94,15 @@ class ConformerTrunk(nn.Module):
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
x = self.feat_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervision)
mask = mask.to(x.device) if mask is not None else None
x = self.encoder(x, pos_emb=pos_emb, key_padding_mask=mask) # (T, N, C)
return x, pos_emb, mask
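# Illustrative usage sketch (not part of this file's API; shapes assume the
# default d_model=256 and subsampling_factor=4, with hypothetical N=8, T=100):
#   trunk = ConformerTrunk(num_features=40)
#   feats = torch.randn(8, 100, 40)          # (N, T, num_features)
#   out, pos_emb, mask = trunk(feats, None)  # supervision omitted, so mask is None
#   # out: roughly (T//4, 8, 256); pos_emb: the relative positional embedding,
#   # with the (N, 2*T-1, E) layout described in the docstrings below.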
class BidirectionalConformer(nn.Module):
@@ -178,25 +182,26 @@ class BidirectionalConformer(nn.Module):
in the discrete bottleneck
"""
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_trunk_encoder_layers: int = 12,
num_ctc_encoder_layers: int = 4,
num_decoder_layers: int = 6,
num_reverse_encoder_layers: int = 4,
num_reverse_decoder_layers: int = 4,
num_self_predictor_layers: int = 3,
bypass_bottleneck: bool = True,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
is_bpe: bool = False,
use_feat_batchnorm: bool = True,
discrete_bottleneck_tot_classes: int = 512,
discrete_bottleneck_num_groups: int = 4
) -> None:
super(BidirectionalConformer, self).__init__()
@@ -236,7 +241,7 @@ class BidirectionalConformer(nn.Module):
self.token_embed_scale = d_model ** 0.5
self.token_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model,
_weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.token_embed_scale)
)
decoder_layer = TransformerDecoderLayer(
@@ -315,8 +320,7 @@ class BidirectionalConformer(nn.Module):
if num_self_predictor_layers > 0:
encoder_layer = SimpleCausalEncoderLayer(d_model,
dropout=dropout)
self.self_predictor_encoder = encoder_layer
self.discrete_bottleneck = DiscreteBottleneck(
@@ -326,8 +330,8 @@ class BidirectionalConformer(nn.Module):
def forward(self, x: Tensor, supervision: Optional[Supervisions] = None,
need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
"""
Forward function that "encodes" the features.
@@ -343,26 +347,29 @@ class BidirectionalConformer(nn.Module):
If true, the last output ("softmax") will be computed. This can be useful
in the reverse model, but only necessary if straight_through_scale != 1.0.
Returns: (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask), where:
memory: a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T
is actually a subsampled form of the num_frames of the input `x`.
If self.bypass_bottleneck, it will be taken before the discrete
bottleneck; otherwise, from after.
bn_memory: The same shape as `memory`, but comes after the discrete bottleneck
regardless of the value of self.bypass_bottleneck.
pos_emb: The relative positional embedding; will be given to ctc_encoder_forward().
sampled: a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse' model.
softmax: a "soft" version of `sampled`. Will only be returned if need_softmax == True;
else will be None.
key_padding_mask: The padding mask for the "memory" output, a Tensor of bool of
shape [N, T] (only if supervision was supplied, else None).
"""
encoder_output, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output)
memory = encoder_output if self.bypass_bottleneck else bn_memory
return (memory, bn_memory, pos_emb, sampled, softmax, memory_key_padding_mask)
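# Sketch of how these outputs are typically threaded into the other *_forward()
# methods (this mirrors _test_bidirectional_conformer at the bottom of this
# file; the sos_id/eos_id values are placeholders):
#   memory, bn_memory, pos_emb, sampled, softmax, mask = model(feats, supervision)
#   ctc_log_probs = model.ctc_encoder_forward(memory, pos_emb, mask)
#   att_loss = model.decoder_forward(memory, mask, token_ids, sos_id=1, eos_id=2)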
def decoder_forward(
self,
@@ -373,11 +380,13 @@ class BidirectionalConformer(nn.Module):
eos_id: int,
) -> torch.Tensor:
"""
Compute the decoder loss function (given a particular list of hypotheses).
Args:
memory:
The first output of forward(), with shape [T, N, E]
memory_key_padding_mask:
The padding mask from forward(), a tensor of bool with shape [N, T]
token_ids:
A list of token-ID lists; each sublist contains the IDs for one utterance.
The IDs can be either phone IDs or word piece IDs.
@@ -390,6 +399,7 @@ class BidirectionalConformer(nn.Module):
A scalar, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
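# Worked example of the input convention just above (sketch; assuming
# sos_id=1 and eos_id=2): token_ids = [[6, 7], [8]] gives
# ys_in = [[1, 6, 7], [1, 8]] after add_sos, and pad_sequence(...,
# padding_value=eos_id) then yields ys_in_pad = [[1, 6, 7], [1, 8, 2]].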
@@ -426,6 +436,7 @@ class BidirectionalConformer(nn.Module):
tgt_mask=tgt_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
@@ -436,6 +447,7 @@ class BidirectionalConformer(nn.Module):
def ctc_encoder_forward(
self,
memory: torch.Tensor,
pos_emb: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
) -> torch.Tensor:
"""
@@ -444,17 +456,19 @@ class BidirectionalConformer(nn.Module):
Args:
memory:
The first output of forward(), with shape (T, N, E)
pos_emb:
Relative positional embedding tensor: (N, 2*T-1, E)
memory_key_padding_mask:
The padding mask from forward(), a tensor of bool of shape (N, T)
Returns:
A Tensor with shape [N, T, C] where C is the number of classes
(e.g. number of phones or word pieces). Contains normalized
log-probabilities.
"""
x = self.ctc_encoder(memory,
pos_emb,
key_padding_mask=memory_key_padding_mask)
x = self.ctc_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
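# Sketch of downstream use (assumed here, not shown in this file): the
# (N, T, C) normalized log-probabilities are suitable for a CTC-style loss;
# torch.nn.functional.ctc_loss expects (T, N, C), so one would call
#   F.ctc_loss(log_probs.transpose(0, 1), targets, input_lengths, target_lengths)
# after obtaining log_probs = model.ctc_encoder_forward(memory, pos_emb, mask).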
@@ -470,10 +484,10 @@ class BidirectionalConformer(nn.Module):
softmax: Optional[torch.Tensor],
reverse_gradient: bool = True) -> Tensor:
"""
Returns the total log-prob of the labels sampled in the discrete
bottleneck layer, as predicted using a relatively simple model that
predicts from previous frames sampled from the bottleneck layer.
[Appears in the denominator of an expression for mutual information].
Args:
memory_shifted:
@@ -482,21 +496,25 @@ class BidirectionalConformer(nn.Module):
(T, N, E) = memory.shape
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
memory_key_padding_mask:
The padding mask from the encoder, of shape [N, T], boolean, True
for masked locations.
sampled: sampled and interpolated one-hot values, as a Tensor of shape [T, N, C]
where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse'
model.
softmax: a "soft" version of `sampled`; if None, will default to `sampled`.
reverse_gradient: will likely be true. If true, the gradient is reversed twice
in this computation, so that we train predictors with the correct
gradient, i.e. to predict, not anti-predict (since the return value
of this function will appear with positive, not negative, sign in the
loss function, so will be minimized).
The gradient w.r.t. the non-self inputs to this function (i.e.
memory_shifted, sampled, softmax) will not be reversed, though.
Returns:
A scalar tensor, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""
if reverse_gradient:
# Reversing gradient for memory_shifted puts the gradient back into
# the correct sign; we reversed it in
@@ -506,6 +524,8 @@ class BidirectionalConformer(nn.Module):
# what happens to the gradients).
memory_shifted = ReverseGrad.apply(memory_shifted)
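# (For reference: a gradient-reversal Function of this kind is conventionally
# written with a pass-through forward() and a sign-flipped backward(), roughly:
#     class ReverseGrad(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x): return x
#         @staticmethod
#         def backward(ctx, grad): return -grad
# This is a sketch of the usual pattern, not necessarily the exact definition
# used elsewhere in this file.)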
# no mask is needed for self_predictor_encoder; its CNN
# layer uses left-padding only, making it causal.
predictor = self.self_predictor_encoder(memory_shifted)
prob = self.discrete_bottleneck.compute_prob(predictor,
@@ -566,15 +586,21 @@ class BidirectionalConformer(nn.Module):
tokens_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=padding_id)
# Let S be the length of the longest sentence (padded)
token_embedding = self.token_embed(tokens_padded) * self.token_embed_scale # (N, S) -> (N, S, C)
# add absolute position-encoding information
token_embedding = self.abs_pos(token_embedding)
token_embedding = token_embedding.permute(1, 0, 2) # (N, S, C) -> (S, N, C)
token_memory = self.reverse_encoder(token_embedding,
src_key_padding_mask=tokens_key_padding_mask)
# token_memory is of shape (S, N, C), if S is length of token sequence.
T = memory.shape[0]
# the targets, here, are the hidden discrete symbols we are predicting
tgt_mask = generate_square_subsequent_mask(T, device=memory.device)
hidden_predictor = self.reverse_decoder(
tgt=memory_shifted,
memory=token_memory,
@@ -584,25 +610,13 @@ class BidirectionalConformer(nn.Module):
total_prob = self.discrete_bottleneck.compute_prob(
hidden_predictor,
sampled,
softmax,
memory_key_padding_mask)
# TODO: consider using a label-smoothed loss.
return total_prob
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
return decoder_loss
class SimpleCausalEncoderLayer(nn.Module):
"""
@@ -650,129 +664,6 @@ class SimpleCausalEncoderLayer(nn.Module):
src = self.norm_final(src)
return src
# for search: SimpleCausalEncoder
def simple_causal_encoder(encoder_layer: nn.Module,
num_layers: int):
return torch.nn.Sequential(*[copy.deepcopy(encoder_layer) for _ in range(num_layers)])
class DiscreteBottleneckConformer(Transformer):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
discrete_bottleneck_pos (int): position in the encoder at which to place
the discrete bottleneck (this many encoder layers will
precede it)
"""
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
discrete_bottleneck_pos: int = 8,
discrete_bottleneck_tot_classes: int = 512,
discrete_bottleneck_num_groups: int = 2
) -> None:
super(DiscreteBottleneckConformer, self).__init__(
num_features=num_features,
num_classes=num_classes,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
is_espnet_structure,
)
discrete_bottleneck = DiscreteBottleneck(dim=d_model,
tot_classes=discrete_bottleneck_tot_classes,
num_groups=discrete_bottleneck_num_groups)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
discrete_bottleneck=discrete_bottleneck,
discrete_bottleneck_pos=discrete_bottleneck_pos)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = self.feat_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
return x, mask
class ReverseGrad(torch.autograd.Function):
def apply(ctx, x):
@@ -878,9 +769,9 @@ class DiscreteBottleneck(nn.Module):
# [ True, True, True, True]])
self.register_buffer('pred_cross_mask',
((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)))
self._reset_parameters()
def _reset_parameters(self):
if hasattr(self, 'pred_cross'):
torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
@@ -967,7 +858,8 @@ class DiscreteBottleneck(nn.Module):
This will be useful in computing a loss function that has
a likelihood term with negative sign (i.e. the self-prediction).
We'll later need to negate the gradient one more time
where we give the input to the prediction module that
generated 'x'.
Returns a scalar Tensor representing the total probability.
"""
@@ -980,6 +872,8 @@ class DiscreteBottleneck(nn.Module):
logprobs = self.pred_linear(x)
# Add "cross-terms" to logprobs; this is a regression that uses earlier
# groups to predict later groups
if self.num_groups > 1:
pred_cross = self.pred_cross * self.pred_cross_mask
t = self.tot_classes
@@ -999,11 +893,12 @@ class DiscreteBottleneck(nn.Module):
logprobs = logprobs.reshape(S, N, self.num_groups, self.classes_per_group)
# Normalize the log-probs (so they sum to one)
logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
logprobs = logprobs.reshape(S, N, C)
if padding_mask is not None:
assert padding_mask.dtype == torch.bool and padding_mask.shape == (N, S)
padding_mask = torch.logical_not(padding_mask).transpose(0, 1).unsqueeze(-1)
assert padding_mask.shape == (S, N, 1)
tot_prob = (logprobs * softmax * padding_mask).sum()
else:
tot_prob = (logprobs * softmax).sum()
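# Worked example of the pred_cross_mask used above (sketch; assuming d is
# self.tot_classes and c is self.classes_per_group, as the surrounding code
# suggests): with tot_classes=8 and num_groups=4, c=2, the expression
#   (torch.arange(8) // 2).unsqueeze(1) >= (torch.arange(8) // 2).unsqueeze(0)
# produces a block-lower-triangular 8x8 boolean matrix: entry (i, j) is True
# only when class i's group is the same as or later than class j's group, so
# the masked cross-terms let later groups be predicted from earlier ones.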
@@ -1080,8 +975,8 @@ class ConformerEncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
attn_mask: Optional[Tensor] = None,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the encoder layer.
@@ -1089,17 +984,18 @@ class ConformerEncoderLayer(nn.Module):
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
attn_mask: the mask for the src sequence (optional).
key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
attn_mask: (S, S). This probably won't be used, in fact should not
be (e.g. could in principle ensure causal behavior, but
actually the conformer does not support this).
key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
src = self.norm_ff_macaron(src)
@@ -1115,8 +1011,8 @@ class ConformerEncoderLayer(nn.Module):
src,
src,
pos_emb=pos_emb,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
@@ -1141,7 +1037,6 @@ class ConformerEncoder(nn.Module):
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -1153,7 +1048,7 @@ class ConformerEncoder(nn.Module):
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super(ConformerEncoder, self).__init__()
self.layers = torch.nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
def forward(
self,
@@ -1183,7 +1078,6 @@ class ConformerEncoder(nn.Module):
return x
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
@@ -1313,7 +1207,6 @@ class RelPositionMultiheadAttention(nn.Module):
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
@@ -1663,6 +1556,7 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
@@ -1701,7 +1595,7 @@ class CausalConvolutionModule(nn.Module):
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct a CausalConvolutionModule object."""
super(CausalConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.kernel_size = kernel_size
@@ -1752,11 +1646,12 @@ class CausalConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
(B, C, T) = x.shape
padding = self.kernel_size - 1
x = torch.cat((torch.zeros(B, C, padding, device=x.device, dtype=x.dtype), x),
dim=2)
x = self.depthwise_conv(x) # <-- This convolution module does no padding,
# so we padded manually, on the left only.
x = self.activation(self.norm(x))
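# Tiny standalone check of the causal-padding idea above (illustrative sketch,
# not part of this module):
#   conv = torch.nn.Conv1d(1, 1, kernel_size=3, padding=0, bias=False)
#   x = torch.randn(1, 1, 10)
#   y = conv(torch.nn.functional.pad(x, (2, 0)))  # pad kernel_size - 1 on the left
#   # y has the same length as x, and y[..., t] depends only on x[..., :t+1],
#   # which is exactly what the torch.cat(...) left-padding above achieves.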
@@ -1835,7 +1730,7 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1) # (time, batch, channel)
class Swish(torch.nn.Module):
@@ -1851,16 +1746,47 @@ def identity(x):
def _gen_rand_tokens(N: int) -> List[List[int]]:
ans = []
for _ in range(N):
S = random.randint(1, 20)
ans.append([random.randint(3, 30) for _ in range(S)])
return ans
def _gen_supervision(tokens: List[List[int]]):
ans = dict()
N = len(tokens)
ans['sequence_idx'] = torch.arange(N, dtype=torch.int32)
ans['start_frame'] = torch.zeros(N, dtype=torch.int32)
ans['num_frames'] = torch.tensor([ random.randint(20, 35) for _ in tokens])
return ans
def _test_bidirectional_conformer():
num_features = 40
num_classes = 1000
m = BidirectionalConformer(num_features, num_classes)
T = 35
N = 10
C = num_features
feats = torch.randn(N, T, C)
tokens = _gen_rand_tokens(N)
supervision = _gen_supervision(tokens)
print("tokens = ", tokens)
print("supervision = ", supervision)
# memory: [T, N, C]
(memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask) = m(feats, supervision)
# ctc_output: [N, T, C].
ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)
decoder_loss = m.decoder_forward(memory, key_padding_mask, tokens,
sos_id=1,
eos_id=2)
(T, N, E) = memory.shape
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
if __name__ == '__main__':
_test_bidirectional_conformer()


@@ -914,6 +914,9 @@ def encoder_padding_mask(
).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
# Assert that in each row, i.e. each utterance, at least one frame is not
# masked. Otherwise it may lead to NaNs appearing in the attention computation.
assert torch.all(torch.sum(torch.logical_not(mask), dim=1) != 0)
return mask
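# Example of the resulting mask semantics (illustrative sketch): for sequence
# lengths [4, 2] and a padded length of 4, `mask` would be
#   [[False, False, False, False],
#    [False, False, True,  True ]]
# i.e. True marks padded positions, and the assertion above guarantees that
# every row keeps at least one False entry.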