First version of conformer with discrete bottleneck

Daniel Povey 2021-09-10 18:51:16 +08:00
parent 44b33b7f05
commit dfe773aa78
2 changed files with 153 additions and 8 deletions


@@ -19,13 +19,14 @@
import math
import warnings
from typing import Optional, Tuple
import torch_flow_sampling
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):
class DiscreteBottleneckConformer(Transformer):
"""
Args:
num_features (int): Number of input features
@@ -40,6 +41,10 @@ class Conformer(Transformer):
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
discrete_bottleneck_pos (int): position in the encoder at which to place
the discrete bottleneck (this many encoder layers will
precede it)
"""
def __init__(
@@ -59,8 +64,11 @@
is_espnet_structure: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
discrete_bottleneck_pos: int = 8,
discrete_bottleneck_tot_classes: int = 512,
discrete_bottleneck_num_groups: int = 2
) -> None:
super(Conformer, self).__init__(
super(DiscreteBottleneckConformer, self).__init__(
num_features=num_features,
num_classes=num_classes,
subsampling_factor=subsampling_factor,
@@ -87,7 +95,16 @@ class Conformer(Transformer):
normalize_before,
is_espnet_structure,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
discrete_bottleneck = DiscreteBottleneck(dim=d_model,
tot_classes=discrete_bottleneck_tot_classes,
num_groups=discrete_bottleneck_num_groups)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
discrete_bottleneck=discrete_bottleneck,
discrete_bottleneck_pos=discrete_bottleneck_pos)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
@@ -131,6 +148,112 @@ class Conformer(Transformer):
return x, mask
class DiscreteBottleneck(nn.Module):
"""
This layer forces its input through an information bottleneck via
a discretization operation with sampling. We use the torch-flow-sampling
package for this; it provides a differentiable approximation to
discrete sampling that should be much better than Gumbel-softmax
in terms of actually giving an information bottleneck.
Args:
dim: The input and output dimension of the discrete bottleneck
operation.
tot_classes: The total number of classes (across all groups
of classes); each group is separately discretized.
num_groups: The number of groups of classes; discretization
is done separately within each group.
interp_prob: The probability with which we interpolate
between two classes, assuming the sampling picks
two distinct classes. Making this smaller gives
noisier derivatives but makes the operation closer
to a true sampling operation. However, even with
interp_prob = 1.0, as the distribution gets very
peaky we'll still mostly have a single class in
the output.
straight_through_scale: The scale on the "straight-through"
derivatives, in which we treat the softmax derivatives
as the derivatives of the softmax+sampling+interpolation
operation. This, and interp_prob, may need to be
changed as you train; just set the attribute
self.straight_through_scale directly if you need to.
The "true" derivative will be scaled by
(1.0 - straight_through_scale).
min_prob_ratio: For any class whose rolling average probability
at the output of the sampling operation (see self.class_probs)
falls below min_prob_ratio times the uniform probability
1 / classes_per_group, we boost that class by adding
prob_boost to its entry in self.class_offsets (see the
comments in __init__).
"""
def __init__(
self,
dim: int,
tot_classes: int,
num_groups: int,
interp_prob: float = 1.0,
straight_through_scale: float = 0.333,
min_prob_ratio: float = 0.1
):
super(DiscreteBottleneck, self).__init__()
self.norm_in = nn.LayerNorm(dim)
self.linear1 = nn.Linear(dim, tot_classes)
self.num_groups = num_groups
self.interp_prob = interp_prob
self.straight_through_scale = straight_through_scale
self.min_prob_ratio = min_prob_ratio
self.tot_classes = tot_classes
self.classes_per_group = tot_classes // num_groups
self.prob_boost = 1.0e-05
# class_probs is a rolling mean of the output of the sampling operation.
# When any element of it gets below self.min_prob_ratio / self.classes_per_group,
# we boost that class's probability by adding self.prob_boost to
# the corresponding element of self.class_offsets.
self.class_probs_decay = 0.9
self.register_buffer('class_probs', torch.ones(tot_classes) / self.classes_per_group)
# class_offsets is a bias term that we add to logits before the sampling
# operation in order to enforce that no class is too infrequent
# (c.f. 'min_prob_ratio').
self.register_buffer('class_offsets', torch.zeros(tot_classes))
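# For example, with the defaults used by DiscreteBottleneckConformer
# (tot_classes=512, num_groups=2, hence classes_per_group=256, and
# min_prob_ratio=0.1), the floor is 0.1 / 256, about 3.9e-04; any
# class whose rolling mean probability falls below that has its
# offset raised by prob_boost=1e-05 on each training minibatch until
# it is sampled often enough again.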
self.linear2 = nn.Linear(tot_classes, dim)
self.norm_out = nn.LayerNorm(dim)
def forward(self, x: Tensor) -> Tensor:
"""
Forward computation.
Args:
x: The input tensor, of shape (S, N, E) where S is the sequence length,
N is the batch size and E is the embedding dim.
"""
x = self.norm_in(x)
x = self.linear1(x)
x = x + self.class_offsets
(S, N, tot_classes) = x.shape
x = x.reshape(S, N, self.num_groups, self.classes_per_group)
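# Sample over the last dimension, i.e. separately within each group
# of classes; interp_prob and straight_through_scale are described
# in the class docstring.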
x = torch_flow_sampling.flow_sample(x,
interp_prob=self.interp_prob,
straight_through_scale=self.straight_through_scale)
assert x.shape == (S, N, self.num_groups, self.classes_per_group)
x = x.reshape(S, N, tot_classes)
if self.training:
# Update the rolling mean of the per-class output probabilities.
mean_class_probs = torch.mean(x.detach(), dim=(0,1))
self.class_probs = (self.class_probs * self.class_probs_decay +
mean_class_probs * (1.0 - self.class_probs_decay))
prob_floor = self.min_prob_ratio / self.classes_per_group
# Boost the offsets of any classes whose rolling mean probability
# has fallen below the floor, so that no class becomes too
# infrequent (c.f. 'min_prob_ratio').
self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
x = self.linear2(x)
x = self.norm_out(x)
return x
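# A minimal usage sketch (sizes are made up for illustration; shapes
# follow the forward() docstring above):
#
#   bn = DiscreteBottleneck(dim=256, tot_classes=512, num_groups=2)
#   x = torch.randn(100, 8, 256)   # (S, N, E)
#   y = bn(x)                      # (S, N, E); same shape, but the
#                                  # information has been forced through
#                                  # the discrete bottleneck
#   assert y.shape == x.shape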
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@@ -288,11 +411,15 @@ class ConformerEncoder(nn.TransformerEncoder):
"""
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None,
discrete_bottleneck: Optional[nn.Module] = None,
discrete_bottleneck_pos: Optional[int] = None
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
self.discrete_bottleneck = discrete_bottleneck
self.discrete_bottleneck_pos = discrete_bottleneck_pos
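# E.g. if num_layers is 12 and discrete_bottleneck_pos is 8, layers
# 0 through 7 run first, then the discrete bottleneck, then layers
# 8 through 11 (see forward() below).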
def forward(
self,
@@ -319,7 +446,9 @@ class ConformerEncoder(nn.TransformerEncoder):
"""
output = src
for mod in self.layers:
for i, mod in enumerate(self.layers):
if i == self.discrete_bottleneck_pos:
output = self.discrete_bottleneck(output)
output = mod(
output,
pos_emb,
@@ -931,3 +1060,19 @@ class Swish(torch.nn.Module):
def identity(x):
return x
def test_discrete_bottleneck_conformer():
num_features = 40
num_classes = 1000
m = DiscreteBottleneckConformer(num_features, num_classes)
T = 35
N = 10
C = num_features
feats = torch.randn(N, T, C)
ctc_output, _, _ = m(feats)
# ctc_output has shape (N, T, num_classes).
if __name__ == '__main__':
test_discrete_bottleneck_conformer()
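# Note: this test requires the torch-flow-sampling package (imported
# at the top of this file) to be installed.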


@@ -28,7 +28,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from conformer import DiscreteBottleneckConformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
@@ -150,7 +150,7 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp_gloam_5e-4_0.85"),
"exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"subsampling_factor": 4,
@@ -647,7 +647,7 @@ def run(rank, world_size, args):
)
logging.info("About to create model")
model = Conformer(
model = DiscreteBottleneckConformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,