mirror of https://github.com/k2-fsa/icefall.git

First version of conformer with discrete bottleneck

parent 44b33b7f05
commit dfe773aa78
@@ -19,13 +19,14 @@
 import math
 import warnings
 from typing import Optional, Tuple
+import torch_flow_sampling
 
 import torch
 from torch import Tensor, nn
 from transformer import Supervisions, Transformer, encoder_padding_mask
 
 
-class Conformer(Transformer):
+class DiscreteBottleneckConformer(Transformer):
     """
     Args:
         num_features (int): Number of input features
@@ -40,6 +41,10 @@ class Conformer(Transformer):
         cnn_module_kernel (int): Kernel size of convolution module
         normalize_before (bool): whether to use layer_norm before the first block.
         vgg_frontend (bool): whether to use vgg frontend.
+        discrete_bottleneck_pos (int): position in the encoder at which to place
+            the discrete bottleneck (this many encoder layers will
+            precede it)
+
     """
 
     def __init__(
@@ -59,8 +64,11 @@ class Conformer(Transformer):
         is_espnet_structure: bool = False,
         mmi_loss: bool = True,
         use_feat_batchnorm: bool = False,
+        discrete_bottleneck_pos: int = 8,
+        discrete_bottleneck_tot_classes: int = 512,
+        discrete_bottleneck_num_groups: int = 2
     ) -> None:
-        super(Conformer, self).__init__(
+        super(DiscreteBottleneckConformer, self).__init__(
             num_features=num_features,
             num_classes=num_classes,
             subsampling_factor=subsampling_factor,
@@ -87,7 +95,16 @@ class Conformer(Transformer):
             normalize_before,
             is_espnet_structure,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+
+        discrete_bottleneck = DiscreteBottleneck(dim=d_model,
+                                                 tot_classes=discrete_bottleneck_tot_classes,
+                                                 num_groups=discrete_bottleneck_num_groups)
+
+        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
+                                        discrete_bottleneck=discrete_bottleneck,
+                                        discrete_bottleneck_pos=discrete_bottleneck_pos)
+
+
         self.normalize_before = normalize_before
         self.is_espnet_structure = is_espnet_structure
         if self.normalize_before and self.is_espnet_structure:
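Note (editor's sketch): the hunk above wires a DiscreteBottleneck module between two blocks of encoder layers. A minimal construction of the new class showing the three bottleneck-related arguments; the numeric values below are illustrative only and are not taken from this commit:

    # Hypothetical usage sketch (values illustrative); the remaining
    # Conformer/Transformer arguments keep their defaults.
    model = DiscreteBottleneckConformer(
        num_features=80,
        num_classes=1000,
        discrete_bottleneck_pos=8,            # 8 encoder layers precede the bottleneck
        discrete_bottleneck_tot_classes=512,  # total logits, split across groups
        discrete_bottleneck_num_groups=2,     # i.e. 2 groups of 256 classes each
    )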
@@ -131,6 +148,112 @@ class Conformer(Transformer):
         return x, mask
 
 
+
+class DiscreteBottleneck(nn.Module):
+    """
+    This layer forces its input through an information bottleneck via
+    a discretization operation with sampling. We use the torch-flow-sampling
+    package for this, to provide a differentiable softmax that should be
+    much better than Gumbel in terms of actually giving an information
+    bottleneck.
+
+    Args:
+      dim: The input and output dimension of the discrete bottleneck
+           operation.
+      tot_classes: The total number of classes (across all groups
+           of classes); each group is separately discretized.
+      num_groups: The number of groups of classes; discretization
+           is done separately within each group.
+      interp_prob: The probability with which we interpolate
+           between two classes, assuming the sampling picks
+           two distinct classes. Making this smaller gives
+           noisier derivatives but makes the operation closer
+           to a true sampling operation. However, even with
+           interp_prob = 1.0, as the distribution gets very
+           peaky we'll still mostly have a single class in
+           the output.
+      straight_through_scale: The scale on the "straight-through"
+           derivatives, in which we treat the softmax derivatives
+           as the derivatives of the softmax+sampling+interpolation
+           operation. This, and interp_prob, may need to be
+           changed as you train; just directly set the member
+           self.straight_through_scale if you need to.
+           The "true" derivative will be scaled as
+           1.0 - straight_through_scale.
+      min_prob_ratio: For any class whose average softmax
+           output (tracked as a rolling mean over minibatches)
+           falls below min_prob_ratio times the uniform
+           probability 1 / classes_per_group, we add prob_boost
+           to that class's entry in class_offsets so that no
+           class becomes too infrequent.
+    """
+    def __init__(
+        self,
+        dim: int,
+        tot_classes: int,
+        num_groups: int,
+        interp_prob: float = 1.0,
+        straight_through_scale: float = 0.333,
+        min_prob_ratio: float = 0.1
+    ):
+        super(DiscreteBottleneck, self).__init__()
+        self.norm_in = nn.LayerNorm(dim)
+        self.linear1 = nn.Linear(dim, tot_classes)
+
+        self.num_groups = num_groups
+        self.interp_prob = interp_prob
+        self.straight_through_scale = straight_through_scale
+        self.min_prob_ratio = min_prob_ratio
+        self.tot_classes = tot_classes
+        self.classes_per_group = tot_classes // num_groups
+        self.prob_boost = 1.0e-05
+
+        # class_probs is a rolling mean of the output of the sampling operation.
+        # When any element of it gets below self.min_prob_ratio / self.classes_per_group,
+        # we boost that class's probability by adding self.prob_boost to
+        # the corresponding element of self.class_offsets.
+        self.class_probs_decay = 0.9
+        self.register_buffer('class_probs', torch.ones(tot_classes) / self.classes_per_group)
+        # class_offsets is a bias term that we add to the logits before the
+        # sampling operation in order to enforce that no class is too
+        # infrequent (c.f. 'min_prob_ratio').
+        self.register_buffer('class_offsets', torch.zeros(tot_classes))
+
+        self.linear2 = nn.Linear(tot_classes, dim)
+        self.norm_out = nn.LayerNorm(dim)
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Forward computation.
+        Args:
+          x: The input tensor, of shape (S, N, E) where S is the sequence length,
+             N is the batch size and E is the embedding dim.
+        """
+        x = self.norm_in(x)
+        x = self.linear1(x)
+        x = x + self.class_offsets
+
+        (S, N, tot_classes) = x.shape
+        x = x.reshape(S, N, self.num_groups, self.classes_per_group)
+
+        x = torch_flow_sampling.flow_sample(x,
+                                            interp_prob=self.interp_prob,
+                                            straight_through_scale=self.straight_through_scale)
+
+        assert x.shape == (S, N, self.num_groups, self.classes_per_group)
+        x = x.reshape(S, N, tot_classes)
+
+        if self.training:
+            mean_class_probs = torch.mean(x.detach(), dim=(0, 1))
+            self.class_probs = (self.class_probs * self.class_probs_decay +
+                                mean_class_probs * (1.0 - self.class_probs_decay))
+            prob_floor = self.min_prob_ratio / self.classes_per_group
+            # Boost the offsets of classes whose rolling probability has fallen
+            # below the floor, so that no class becomes too infrequent.
+            self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
+
+        x = self.linear2(x)
+        x = self.norm_out(x)
+        return x
+
+
+
 class ConformerEncoderLayer(nn.Module):
     """
     ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
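Note (editor's sketch): the DiscreteBottleneck added above maps dim -> tot_classes logits, samples within each group via torch_flow_sampling.flow_sample, and maps back to dim, so its output shape matches its input shape. A small standalone check of that contract (assumes the torch_flow_sampling package is installed; the dim/tot_classes/num_groups values are illustrative):

    import torch

    bottleneck = DiscreteBottleneck(dim=256, tot_classes=512, num_groups=2)
    x = torch.randn(50, 8, 256)   # (S, N, E): sequence length, batch, embedding dim
    y = bottleneck(x)             # 256 -> 512 logits -> 2 groups of 256 -> sample -> 256
    assert y.shape == x.shape
    # The rolling per-class statistics (updated in training mode) hold one entry per class:
    assert bottleneck.class_probs.shape == (512,)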
@@ -288,11 +411,15 @@ class ConformerEncoder(nn.TransformerEncoder):
     """
 
     def __init__(
-        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
+        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None,
+        discrete_bottleneck: Optional[nn.Module] = None,
+        discrete_bottleneck_pos: Optional[int] = None
     ) -> None:
         super(ConformerEncoder, self).__init__(
             encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
         )
+        self.discrete_bottleneck = discrete_bottleneck
+        self.discrete_bottleneck_pos = discrete_bottleneck_pos
 
     def forward(
         self,
@@ -319,7 +446,9 @@ class ConformerEncoder(nn.TransformerEncoder):
         """
         output = src
 
-        for mod in self.layers:
+        for i, mod in enumerate(self.layers):
+            if i == self.discrete_bottleneck_pos:
+                output = self.discrete_bottleneck(output)
             output = mod(
                 output,
                 pos_emb,
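Note (editor's sketch): the loop change above runs the bottleneck immediately before the layer whose index equals discrete_bottleneck_pos, which is why the class docstring says that many encoder layers "precede it". A toy illustration of that ordering (the layer names and the position 2 are made up):

    layers = ["layer0", "layer1", "layer2", "layer3"]
    discrete_bottleneck_pos = 2
    trace = []
    for i, name in enumerate(layers):
        if i == discrete_bottleneck_pos:
            trace.append("discrete_bottleneck")
        trace.append(name)
    print(trace)  # ['layer0', 'layer1', 'discrete_bottleneck', 'layer2', 'layer3']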
@@ -931,3 +1060,19 @@ class Swish(torch.nn.Module):
 
 def identity(x):
     return x
+
+
+
+def test_discrete_bottleneck_conformer():
+    num_features = 40
+    num_classes = 1000
+    m = DiscreteBottleneckConformer(num_features, num_classes)
+    T = 35
+    N = 10
+    C = num_features
+    feats = torch.randn(N, T, C)
+    ctc_output, _, _ = m(feats)
+    # [N, T, C].
+
+if __name__ == '__main__':
+    test_discrete_bottleneck_conformer()
@@ -28,7 +28,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
+from conformer import DiscreteBottleneckConformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -150,7 +150,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp_gloam_5e-4_0.85"),
+            "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -647,7 +647,7 @@ def run(rank, world_size, args):
     )
 
     logging.info("About to create model")
-    model = Conformer(
+    model = DiscreteBottleneckConformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
         d_model=params.attention_dim,