From dfe773aa78e8a64a237e51b7e996f70dfb8b3962 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Sep 2021 18:51:16 +0800 Subject: [PATCH] First version of conformer with discrete bottleneck --- .../ASR/conformer_ctc_bn/conformer.py | 155 +++++++++++++++++- egs/librispeech/ASR/conformer_ctc_bn/train.py | 6 +- 2 files changed, 153 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc_bn/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn/conformer.py index 08287d686..566cad8cf 100644 --- a/egs/librispeech/ASR/conformer_ctc_bn/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc_bn/conformer.py @@ -19,13 +19,14 @@ import math import warnings from typing import Optional, Tuple +import torch_flow_sampling import torch from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask -class Conformer(Transformer): +class DiscreteBottleneckConformer(Transformer): """ Args: num_features (int): Number of input features @@ -40,6 +41,10 @@ class Conformer(Transformer): cnn_module_kernel (int): Kernel size of convolution module normalize_before (bool): whether to use layer_norm before the first block. vgg_frontend (bool): whether to use vgg frontend. 
+ discrete_bottleneck_pos (int): position in the encoder at which to place + the discrete bottleneck (this many encoder layers will + precede it) + """ def __init__( @@ -59,8 +64,11 @@ class Conformer(Transformer): is_espnet_structure: bool = False, mmi_loss: bool = True, use_feat_batchnorm: bool = False, + discrete_bottleneck_pos: int = 8, + discrete_bottleneck_tot_classes: int = 512, + discrete_bottleneck_num_groups: int = 2 ) -> None: - super(Conformer, self).__init__( + super(DiscreteBottleneckConformer, self).__init__( num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor, @@ -87,7 +95,16 @@ class Conformer(Transformer): normalize_before, is_espnet_structure, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + + discrete_bottleneck = DiscreteBottleneck(dim=d_model, + tot_classes=discrete_bottleneck_tot_classes, + num_groups=discrete_bottleneck_num_groups) + + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, + discrete_bottleneck=discrete_bottleneck, + discrete_bottleneck_pos=discrete_bottleneck_pos) + + self.normalize_before = normalize_before self.is_espnet_structure = is_espnet_structure if self.normalize_before and self.is_espnet_structure: @@ -131,6 +148,112 @@ class Conformer(Transformer): return x, mask + +class DiscreteBottleneck(nn.Module): + """ + This layer forces its input through an information bottleneck via + a discretization operation with sampling. We use the torch-flow-sampling + package for this, to provide a differentiable softmax that should be + much better than Gumbel in terms of actually giving an information + bottleneck. + + Args: + dim: The input and output dimension of the discrete bottleneck + operation. + tot_classes: The total number of classes (across all groups + of classes); each group is separately discretized + num_groups: The number of groups of classes; discretization + is done separately within each group. 
+ interp_prob: The probability with which we interpolate + between two classes, assuming the sampling picks + two distinct classes. Making this smaller would give + more noisy derivatives but makes the operation closer + to a true sampling operation. However, even with + interp_prob = 1.0, as the distribution gets very + peaky we'll still mostly have a single class in + the output. + straight_through_scale: The scale on the "straight-through" + derivatives, in which we treat the softmax derivatives + as the derivatives of the softmax+sampling+interpolation + operation. This, and interp_prob, may need to be + changed as you train, just directly set the + variable self.straight_through_scale if you need to. + The "true" derivative will be scaled as + 1.0 - straight_through_scale. + min_prob_ratio: For any class whose average softmax + output, for a given minibatch, is less than + min_prob_ratio times (1 / classes_per_group), we boost its logit via class_offsets. + """ + def __init__( + self, + dim: int, + tot_classes: int, + num_groups: int, + interp_prob: float = 1.0, + straight_through_scale: float = 0.333, + min_prob_ratio: float = 0.1 + ): + super(DiscreteBottleneck, self).__init__() + self.norm_in = nn.LayerNorm(dim) + self.linear1 = nn.Linear(dim, tot_classes) + + self.num_groups = num_groups + self.interp_prob = interp_prob + self.straight_through_scale = straight_through_scale + self.min_prob_ratio = min_prob_ratio + self.tot_classes = tot_classes + self.classes_per_group = tot_classes // num_groups + self.prob_boost = 1.0e-05 + + # class_probs is a rolling mean of the output of the sampling operation.
+ # When any element of it gets below self.min_prob_ratio / self.classes_per_group, + # we boost the class's probability by adding self.prob_boost to + # that element of self.class_offsets + self.class_probs_decay = 0.9 + self.register_buffer('class_probs', torch.ones(tot_classes) / self.classes_per_group) + # class_offsets is a bias term that we add to logits before the sampling + # operation in order to enforce that no class is too infrequent + # (c.f. 'min_prob_ratio'). + self.register_buffer('class_offsets', torch.zeros(tot_classes)) + + self.linear2 = nn.Linear(tot_classes, dim) + self.norm_out = nn.LayerNorm(dim) + + + def forward(self, x: Tensor) -> Tensor: + """ + Forward computation. + Args: + x: The input tensor, of shape (S, N, E) where S is the sequence length, + N is the batch size and E is the embedding dim. + """ + x = self.norm_in(x) + x = self.linear1(x) + x = x + self.class_offsets + + (S, N, tot_classes) = x.shape + x = x.reshape(S, N, self.num_groups, self.classes_per_group) + + x = torch_flow_sampling.flow_sample(x, + interp_prob=self.interp_prob, + straight_through_scale=self.straight_through_scale) + + assert x.shape == (S, N, self.num_groups, self.classes_per_group) + x = x.reshape(S, N, tot_classes) + + if self.training: + mean_class_probs = torch.mean(x.detach(), dim=(0,1)) + self.class_probs = (self.class_probs * self.class_probs_decay + + mean_class_probs * (1.0 - self.class_probs_decay)) + prob_floor = self.min_prob_ratio / self.classes_per_group + self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost + + x = self.linear2(x) + x = self.norm_out(x) + return x + + + class ConformerEncoderLayer(nn.Module): """ ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@@ -288,11 +411,15 @@ class ConformerEncoder(nn.TransformerEncoder): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None, + discrete_bottleneck: Optional[nn.Module] = None, + discrete_bottleneck_pos: Optional[int] = None ) -> None: super(ConformerEncoder, self).__init__( encoder_layer=encoder_layer, num_layers=num_layers, norm=norm ) + self.discrete_bottleneck = discrete_bottleneck + self.discrete_bottleneck_pos = discrete_bottleneck_pos def forward( self, @@ -319,7 +446,9 @@ class ConformerEncoder(nn.TransformerEncoder): """ output = src - for mod in self.layers: + for i, mod in enumerate(self.layers): + if i == self.discrete_bottleneck_pos: + output = self.discrete_bottleneck(output) output = mod( output, pos_emb, @@ -931,3 +1060,19 @@ class Swish(torch.nn.Module): def identity(x): return x + + + +def test_discrete_bottleneck_conformer(): + num_features = 40 + num_classes = 1000 + m = DiscreteBottleneckConformer(num_features, num_classes) + T = 35 + N = 10 + C = num_features + feats = torch.randn(N, T, C) + ctc_output, _, _ = m(feats) + # [N, T, C]. 
+ +if __name__ == '__main__': + test_discrete_bottleneck_conformer() diff --git a/egs/librispeech/ASR/conformer_ctc_bn/train.py b/egs/librispeech/ASR/conformer_ctc_bn/train.py index 4afe23215..48a58d96c 100755 --- a/egs/librispeech/ASR/conformer_ctc_bn/train.py +++ b/egs/librispeech/ASR/conformer_ctc_bn/train.py @@ -28,7 +28,7 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer +from conformer import DiscreteBottleneckConformer from lhotse.utils import fix_random_seed from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ @@ -150,7 +150,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp_gloam_5e-4_0.85"), + "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"), "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "subsampling_factor": 4, @@ -647,7 +647,7 @@ def run(rank, world_size, args): ) logging.info("About to create model") - model = Conformer( + model = DiscreteBottleneckConformer( num_features=params.feature_dim, nhead=params.nhead, d_model=params.attention_dim,