Revert transducer_stateless/ to state in upstream/master

Daniel Povey 2022-04-02 21:16:39 +08:00
parent 807fcada68
commit 9f62a0296c
7 changed files with 108 additions and 454 deletions

View File

@@ -18,8 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple, Sequence
-from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from typing import Optional, Tuple

 import torch
 from torch import Tensor, nn
@@ -57,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        aux_layer_period: int = 3
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -82,13 +80,17 @@ class Conformer(Transformer):
             cnn_module_kernel,
             normalize_before,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
-                                        aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
+        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = nn.LayerNorm(d_model)
+        else:
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            #       and throws an error without this change.
+            self.after_norm = identity

     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False
+        self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -115,8 +117,10 @@ class Conformer(Transformer):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)

-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup_mode=warmup_mode)  # (T, N, C)
+        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
+
+        if self.normalize_before:
+            x = self.after_norm(x)

         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -154,41 +158,42 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.d_model = d_model
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )

         self.feed_forward = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1),
-            DoubleSwish(),
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            nn.Linear(dim_feedforward, d_model),
         )

         self.feed_forward_macaron = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1),
-            DoubleSwish(),
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            nn.Linear(dim_feedforward, d_model),
         )

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

-        self.norm_final = BasicNorm(d_model)
-        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(channel_dim=-1,
-                                           min_positive=0.45,
-                                           max_positive=0.55,
-                                           max_positive=6.0)
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+
+        self.ff_scale = 0.5
+
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block

         self.dropout = nn.Dropout(dropout)
+        self.normalize_before = normalize_before

     def forward(
         self,
@@ -215,10 +220,19 @@ class ConformerEncoderLayer(nn.Module):
         """
         # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff_macaron(src)
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
+        if not self.normalize_before:
+            src = self.norm_ff_macaron(src)

         # multi-headed self-attention module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -227,15 +241,28 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = src + self.dropout(src_att)
+        src = residual + self.dropout(src_att)
+        if not self.normalize_before:
+            src = self.norm_mha(src)

         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(src))
+        if not self.normalize_before:
+            src = self.norm_conv(src)

         # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff(src)
+        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+        if not self.normalize_before:
+            src = self.norm_ff(src)

-        src = self.norm_final(self.balancer(src))
+        if self.normalize_before:
+            src = self.norm_final(src)

         return src
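
The hunk above restores the upstream pre-norm/post-norm residual wiring: each sub-module keeps an explicit residual, optionally normalizes its input (normalize_before=True), and the two feed-forward branches are scaled by ff_scale=0.5. A minimal standalone sketch of one such feed-forward block, with nn.SiLU standing in for the Swish class defined later in this file and all sizes illustrative:

    import torch
    import torch.nn as nn

    d_model, dim_feedforward, ff_scale = 256, 1024, 0.5
    norm_ff = nn.LayerNorm(d_model)
    feed_forward = nn.Sequential(
        nn.Linear(d_model, dim_feedforward),
        nn.SiLU(),  # stand-in for Swish: x * sigmoid(x)
        nn.Dropout(0.1),
        nn.Linear(dim_feedforward, d_model),
    )
    dropout = nn.Dropout(0.1)

    def ff_block(src: torch.Tensor, normalize_before: bool = True) -> torch.Tensor:
        residual = src
        if normalize_before:      # pre-norm: normalize the input to the block
            src = norm_ff(src)
        src = residual + ff_scale * dropout(feed_forward(src))
        if not normalize_before:  # post-norm: normalize the residual sum instead
            src = norm_ff(src)
        return src

    x = torch.randn(20, 4, d_model)  # (T, N, C), as in the encoder
    print(ff_block(x).shape)         # torch.Size([20, 4, 256])
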
@@ -255,20 +282,12 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
     """

-    def __init__(self, encoder_layer: nn.Module, num_layers: int,
-                 aux_layers: Sequence[int]) -> None:
+    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
-        self.aux_layers = set(aux_layers + [num_layers - 1])
-        assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
-        num_channels = encoder_layer.d_model
-        self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
-                                      final_weight=0.5,
-                                      pure_prob=0.333,
-                                      stddev=2.0)

     def forward(
         self,
@@ -276,7 +295,6 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup_mode: bool = False
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -296,19 +314,14 @@ class ConformerEncoder(nn.Module):
         """
         output = src

-        outputs = []
-
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-            if i in self.aux_layers:
-                outputs.append(output)
-
-        output = self.combiner(outputs, warmup_mode)

         return output
@@ -331,6 +344,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
@@ -382,6 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """
         self.extend_pe(x)
+        x = x * self.xscale
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
@@ -413,7 +428,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        scale_speed: float = 5.0
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -424,29 +438,25 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"

-        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.scale_speed = scale_speed
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())

         self._reset_parameters()

-    def _pos_bias_u(self):
-        return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
-
     def _reset_parameters(self) -> None:
-        nn.init.normal_(self.pos_bias_u, std=0.05)
-        nn.init.normal_(self.pos_bias_v, std=0.05)
+        nn.init.xavier_uniform_(self.in_proj.weight)
+        nn.init.constant_(self.in_proj.bias, 0.0)
+        nn.init.constant_(self.out_proj.bias, 0.0)
+
+        nn.init.xavier_uniform_(self.pos_bias_u)
+        nn.init.xavier_uniform_(self.pos_bias_v)

     def forward(
         self,
@@ -506,11 +516,11 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.get_weight(),
-            self.in_proj.get_bias(),
+            self.in_proj.weight,
+            self.in_proj.bias,
             self.dropout,
-            self.out_proj.get_weight(),
-            self.out_proj.get_bias(),
+            self.out_proj.weight,
+            self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
@@ -614,12 +624,13 @@ class RelPositionMultiheadAttention(nn.Module):
         assert (
             head_dim * num_heads == embed_dim
         ), "embed_dim must be divisible by num_heads"

         scaling = float(head_dim) ** -0.5

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)

         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -651,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
-
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -670,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -720,7 +729,7 @@ class RelPositionMultiheadAttention(nn.Module):
             )
             key_padding_mask = key_padding_mask.to(torch.bool)

-        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
@@ -741,11 +750,11 @@ class RelPositionMultiheadAttention(nn.Module):
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+        q_with_bias_u = (q + self.pos_bias_u).transpose(
             1, 2
         )  # (batch, head, time1, d_k)

-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+        q_with_bias_v = (q + self.pos_bias_v).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
@@ -765,7 +774,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output_weights = (
             matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        ) * scaling  # (batch, head, time1, time2)

         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
@@ -840,7 +849,7 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0

-        self.pointwise_conv1 = ScaledConv1d(
+        self.pointwise_conv1 = nn.Conv1d(
             channels,
             2 * channels,
             kernel_size=1,
@@ -848,25 +857,7 @@
             padding=0,
             bias=bias,
         )
-
-        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
-        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
-        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
-        # between 50 and 100 for different channels. This will cause very peaky and
-        # sparse derivatives for the sigmoid gating function, which will tend to make
-        # the loss function not learn effectively. (for most layers the average absolute values
-        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
-        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
-        # layers, which likely breaks down as 0.5 for the "linear" half and
-        # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we
-        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
-        # it will be in a better position to start learning something, i.e. to latch onto
-        # the correct range.
-        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
-
-        self.depthwise_conv = ScaledConv1d(
+        self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
@@ -875,22 +866,16 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-
-        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
-
-        self.activation = DoubleSwish()
-
-        self.pointwise_conv2 = ScaledConv1d(
+        self.norm = nn.LayerNorm(channels)
+        self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25
         )
+        self.activation = Swish()

     def forward(self, x: Tensor) -> Tensor:
         """Compute convolution module.
@@ -907,14 +892,15 @@ class ConvolutionModule(nn.Module):
         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
-        x = self.deriv_balancer1(x)
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)

-        x = self.deriv_balancer2(x)
         x = self.activation(x)

         x = self.pointwise_conv2(x)  # (batch, channel, time)
@@ -922,197 +908,13 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)


-class Identity(torch.nn.Module):
-    def forward(self, x: Tensor) -> Tensor:
-        return x
+class Swish(torch.nn.Module):
+    """Construct an Swish object."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swich activation function."""
+        return x * torch.sigmoid(x)
+
+
+def identity(x):
+    return x


-class RandomCombine(torch.nn.Module):
-    """
-    This module combines a list of Tensors, all with the same shape, to
-    produce a single output of that same shape which, in training time,
-    is a random combination of all the inputs; but which in test time
-    will be just the last input.
-
-    The idea is that the list of Tensors will be a list of outputs of multiple
-    conformer layers. This has a similar effect as iterated loss. (See:
-    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
-    NETWORKS).
-    """
-    def __init__(self, num_inputs: int,
-                 final_weight: float = 0.5,
-                 pure_prob: float = 0.5,
-                 stddev: float = 2.0) -> None:
-        """
-        Args:
-          num_inputs: The number of tensor inputs, which equals the number of layers'
-                outputs that are fed into this module. E.g. in an 18-layer neural
-                net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight: The amount of weight or probability we assign to the
-                final layer when randomly choosing layers or when choosing
-                continuous layer weights.
-          pure_prob: The probability, on each frame, with which we choose
-                only a single layer to output (rather than an interpolation)
-          stddev: A standard deviation that we add to log-probs for computing
-                randomized weights.
-
-        The method of choosing which layers,
-        or combinations of layers, to use, is conceptually as follows.
-            With probability `pure_prob`:
-               With probability `final_weight`: choose final layer,
-               Else: choose random non-final layer.
-            Else:
-               Choose initial log-weights that correspond to assigning
-               weight `final_weight` to the final layer and equal
-               weights to other layers; then add Gaussian noise
-               with variance `stddev` to these log-weights, and normalize
-               to weights (note: the average weight assigned to the
-               final layer here will not be `final_weight` if stddev>0).
-        """
-        super(RandomCombine, self).__init__()
-        assert pure_prob >= 0 and pure_prob <= 1
-        assert final_weight > 0 and final_weight < 1
-        assert num_inputs >= 1
-        self.num_inputs = num_inputs
-        self.final_weight = final_weight
-        self.pure_prob = pure_prob
-        self.stddev = stddev
-
-        self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
-
-    def forward(self, inputs: Sequence[Tensor],
-                warmup_mode: bool) -> Tensor:
-        """
-        Forward function.
-        Args:
-          inputs: a list of Tensor, e.g. from various layers of a transformer.
-              All must be the same shape, of (*, num_channels)
-        Returns:
-          a Tensor of shape (*, num_channels). In test mode
-          this is just the final input.
-        """
-        num_inputs = self.num_inputs
-        assert len(inputs) == num_inputs
-        if not (self.training and warmup_mode):
-            return inputs[-1]
-
-        # Shape of weights: (*, num_inputs)
-        num_channels = inputs[0].shape[-1]
-        num_frames = inputs[0].numel() // num_channels
-        ndim = inputs[0].ndim
-        # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames,
-                                                                num_channels,
-                                                                num_inputs))
-        # weights: (num_frames, num_inputs)
-        weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
-                                           num_frames)
-
-        weights = weights.reshape(num_frames, num_inputs, 1)
-        # ans: (num_frames, num_channels, 1)
-        ans = torch.matmul(stacked_inputs, weights)
-        # ans: (*, num_channels)
-        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
-
-        if __name__ == "__main__":
-            # for testing only...
-            print("Weights = ", weights.reshape(num_frames, num_inputs))
-        return ans
-
-    def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
-        """
-        Return a tensor of random weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), such that
-          ans.sum(dim=1) is all ones.
-        """
-        pure_prob = self.pure_prob
-        if pure_prob == 0.0:
-            return self._get_random_mixed_weights(dtype, device, num_frames)
-        elif pure_prob == 1.0:
-            return self._get_random_pure_weights(dtype, device, num_frames)
-        else:
-            p = self._get_random_pure_weights(dtype, device, num_frames)
-            m = self._get_random_mixed_weights(dtype, device, num_frames)
-            return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m)
-
-    def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with
-          exactly one weight equal to 1.0 on each frame.
-        """
-        final_prob = self.final_weight
-        # final contains self.num_inputs - 1 in all elements
-        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
-        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
-        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
-
-        indexes = torch.where(torch.rand(num_frames, device=device) < final_prob,
-                              final, nonfinal)
-        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype)
-        return ans
-
-    def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that
-          sum to one over the second axis, i.e. ans.sum(dim=1) is all ones.
-        """
-        logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev
-        logprobs[:,-1] += self.final_log_weight
-        return logprobs.softmax(dim=1)
-
-
-def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
-    print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
-    num_inputs = 3
-    num_channels = 50
-    m = RandomCombine(num_inputs=num_inputs,
-                      final_weight=final_weight,
-                      pure_prob=pure_prob,
-                      stddev=stddev)
-
-    x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ]
-
-    y = m(x, True)
-    assert y.shape == x[0].shape
-    assert torch.allclose(y, x[0])  # .. since actually all ones.
-
-
-if __name__ == '__main__':
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.0)
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.3)
-    _test_random_combine(0.5, 1, 0.3)
-    _test_random_combine(0.5, 0.5, 0.3)
-
-    feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
-    batch_size = 5
-    seq_len = 20
-    # Just make sure the forward pass runs.
-    f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64),
-          warmup_mode=True)
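
The RandomCombine module deleted above interpolated the outputs of several conformer layers during warm-up. Its non-pure ("mixed") weight draw can be reproduced standalone; the constants below are the values the removed ConformerEncoder passed in (final_weight=0.5, stddev=2.0), and the snippet is only an illustrative sketch:

    import torch

    num_inputs, final_weight, stddev = 3, 0.5, 2.0
    num_frames = 4

    # Bias on the last layer's log-weight so that, without noise, the softmax
    # would give it exactly `final_weight` of the total mass.
    final_log_weight = (
        torch.tensor((final_weight / (1 - final_weight)) * (num_inputs - 1)).log().item()
    )

    logprobs = torch.randn(num_frames, num_inputs) * stddev  # Gaussian noise in log space
    logprobs[:, -1] += final_log_weight
    weights = logprobs.softmax(dim=1)  # each row sums to 1; the last column is favoured on average
    print(weights)
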

View File

@@ -17,9 +17,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional
-from subsampling import ScaledConv1d


 class Decoder(nn.Module):
@@ -55,7 +52,7 @@ class Decoder(nn.Module):
             1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
-        self.embedding = ScaledEmbedding(
+        self.embedding = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
@@ -65,7 +62,7 @@ class Decoder(nn.Module):
         assert context_size >= 1, context_size
         self.context_size = context_size
         if context_size > 1:
-            self.conv = ScaledConv1d(
+            self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
                 out_channels=embedding_dim,
                 kernel_size=context_size,
@@ -85,7 +82,6 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        y = y.to(torch.int64)
         embedding_out = self.embedding(y)
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
@@ -100,139 +96,3 @@ class Decoder(nn.Module):
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
         return embedding_out
-
-
-class ScaledEmbedding(nn.Module):
-    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
-
-    This module is often used to store word embeddings and retrieve them using indices.
-    The input to the module is a list of indices, and the output is the corresponding
-    word embeddings.
-
-    Args:
-        num_embeddings (int): size of the dictionary of embeddings
-        embedding_dim (int): the size of each embedding vector
-        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
-                                     (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
-        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
-                                                the words in the mini-batch. Default ``False``.
-        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
-                                 See Notes for more details regarding sparse gradients.
-
-    Attributes:
-        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
-                         initialized from :math:`\mathcal{N}(0, 1)`
-
-    Shape:
-        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
-        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
-
-    .. note::
-        Keep in mind that only a limited number of optimizers support
-        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
-        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
-
-    .. note::
-        With :attr:`padding_idx` set, the embedding vector at
-        :attr:`padding_idx` is initialized to all zeros. However, note that this
-        vector can be modified afterwards, e.g., using a customized
-        initialization method, and thus changing the vector used to pad the
-        output. The gradient for this vector from :class:`~torch.nn.Embedding`
-        is always zero.
-
-    Examples::
-
-        >>> # an Embedding module containing 10 tensors of size 3
-        >>> embedding = nn.Embedding(10, 3)
-        >>> # a batch of 2 samples of 4 indices each
-        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
-        >>> embedding(input)
-        tensor([[[-0.0251, -1.6902, 0.7172],
-                 [-0.6431, 0.0748, 0.6969],
-                 [ 1.4970, 1.3448, -0.9685],
-                 [-0.3677, -2.7265, -0.1685]],
-
-                [[ 1.4970, 1.3448, -0.9685],
-                 [ 0.4362, -0.4004, 0.9400],
-                 [-0.6431, 0.0748, 0.6969],
-                 [ 0.9124, -2.3616, 1.1151]]])
-
-        >>> # example with padding_idx
-        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
-        >>> input = torch.LongTensor([[0,2,0,5]])
-        >>> embedding(input)
-        tensor([[[ 0.0000, 0.0000, 0.0000],
-                 [ 0.1535, -2.0309, 0.9315],
-                 [ 0.0000, 0.0000, 0.0000],
-                 [-0.1655, 0.9897, 0.0635]]])
-    """
-    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
-                     'scale_grad_by_freq', 'sparse']
-
-    num_embeddings: int
-    embedding_dim: int
-    padding_idx: int
-    scale_grad_by_freq: bool
-    weight: Tensor
-    sparse: bool
-
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 scale_speed: float = 5.0) -> None:
-        super(ScaledEmbedding, self).__init__()
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        if padding_idx is not None:
-            if padding_idx > 0:
-                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
-            elif padding_idx < 0:
-                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
-                padding_idx = self.num_embeddings + padding_idx
-        self.padding_idx = padding_idx
-        self.scale_grad_by_freq = scale_grad_by_freq
-        self.scale_speed = scale_speed
-        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
-        self.sparse = sparse
-
-        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        nn.init.normal_(self.weight, std=0.05)
-        nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
-        if self.padding_idx is not None:
-            with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = (self.scale * self.scale_speed).exp()
-        if input.numel() < self.num_embeddings:
-            return F.embedding(
-                input, self.weight, self.padding_idx,
-                None, 2.0,  # None, 2.0 relate to normalization
-                self.scale_grad_by_freq, self.sparse) * scale
-        else:
-            return F.embedding(
-                input, self.weight * scale, self.padding_idx,
-                None, 2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq, self.sparse)
-
-    def extra_repr(self) -> str:
-        s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
-        if self.padding_idx is not None:
-            s += ', padding_idx={padding_idx}'
-        if self.scale_grad_by_freq is not False:
-            s += ', scale_grad_by_freq={scale_grad_by_freq}'
-        if self.sparse is not False:
-            s += ', sparse=True'
-        return s.format(**self.__dict__)
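
The ScaledEmbedding removed above stores a small-magnitude weight (std=0.05) and learns the overall magnitude in log space, so the effective table at lookup time is weight * exp(scale * scale_speed). A standalone numeric sketch of that forward path, using the initial values from the removed reset_parameters():

    import torch
    import torch.nn.functional as F

    num_embeddings, embedding_dim, scale_speed = 10, 4, 5.0
    weight = torch.randn(num_embeddings, embedding_dim) * 0.05  # small parameterisation
    scale = torch.tensor(1.0 / 0.05).log() / scale_speed        # initial value of the learned scale

    tokens = torch.tensor([1, 2, 4])
    embedded = F.embedding(tokens, weight) * (scale * scale_speed).exp()
    print(embedded.shape)  # torch.Size([3, 4]); effective std is ~1.0 at initialization
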

View File

@@ -22,7 +22,7 @@ import torch.nn as nn

 class EncoderInterface(nn.Module):
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:

View File

@@ -16,7 +16,7 @@
 import torch
 import torch.nn as nn
-from subsampling import ScaledLinear


 class Joiner(nn.Module):
@@ -24,7 +24,7 @@ class Joiner(nn.Module):
         self.input_dim = input_dim
         self.output_dim = output_dim

-        self.output_linear = ScaledLinear(input_dim, output_dim)
+        self.output_linear = nn.Linear(input_dim, output_dim)

     def forward(
         self,

View File

@@ -65,7 +65,6 @@ class Transducer(nn.Module):
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
         modified_transducer_prob: float = 0.0,
-        warmup_mode: bool = False
     ) -> torch.Tensor:
         """
         Args:
@@ -88,7 +87,7 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode)
+        encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network

View File

@@ -111,8 +111,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization..
-        default="transducer_stateless/randcombine1_expscale3_rework2d",
+        default="transducer_stateless/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -223,7 +222,6 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
-            "warmup_minibatches": 3000,  # use warmup mode for 3k minibatches.
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -381,7 +379,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    is_warmup_mode: bool = False
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -418,7 +415,6 @@ def compute_loss(
             x_lens=feature_lens,
             y=y,
             modified_transducer_prob=params.modified_transducer_prob,
-            warmup_mode=is_warmup_mode
         )
         assert loss.requires_grad == is_training
@@ -455,7 +451,6 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
-            is_warmup_mode=False
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -517,7 +512,6 @@ def train_one_epoch(
             sp=sp,
             batch=batch,
             is_training=True,
-            is_warmup_mode=(params.batch_idx_train<params.warmup_minibatches)
         )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -750,7 +744,6 @@ def scan_pessimistic_batches_for_oom(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                is_warmup_mode=False
             )
             loss.backward()
             clip_grad_norm_(model.parameters(), 5.0, 2.0)
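
The warm-up plumbing removed from train.py above reduces to one gate: warm-up mode was on for the first warmup_minibatches (3000) training batches, always off for validation and the OOM scan, and the flag was threaded through compute_loss() into the model. A tiny sketch of the removed condition:

    warmup_minibatches = 3000  # value from the removed get_params() entry

    def is_warmup_mode(batch_idx_train: int) -> bool:
        return batch_idx_train < warmup_minibatches

    assert is_warmup_mode(0) and is_warmup_mode(2999)
    assert not is_warmup_mode(3000)  # warm-up ends after 3k minibatches
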

View File

@@ -21,7 +21,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear
+from subsampling import Conv2dSubsampling, VggSubsampling

 from icefall.utils import make_pad_mask
@@ -106,7 +106,7 @@ class Transformer(EncoderInterface):
         # TODO(fangjun): remove dropout
         self.encoder_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
+            nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
         )

     def forward(