From 9f62a0296cd072083399f6862d1df6bee0134555 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 2 Apr 2022 21:16:39 +0800
Subject: [PATCH] Revert transducer_stateless/ to state in upstream/master

---
 .../ASR/transducer_stateless/conformer.py     | 396 +++++-------
 .../ASR/transducer_stateless/decoder.py       | 144 +------
 .../transducer_stateless/encoder_interface.py |   2 +-
 .../ASR/transducer_stateless/joiner.py        |   4 +-
 .../ASR/transducer_stateless/model.py         |   3 +-
 .../ASR/transducer_stateless/train.py         |   9 +-
 .../ASR/transducer_stateless/transformer.py   |   4 +-
 7 files changed, 108 insertions(+), 454 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index ae95d95b4..488c82386 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -18,8 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple, Sequence
-from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -57,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        aux_layer_period: int = 3
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -82,13 +80,17 @@ class Conformer(Transformer):
             cnn_module_kernel,
             normalize_before,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
-                                        aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
+        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
-
+        if self.normalize_before:
+            self.after_norm = nn.LayerNorm(d_model)
+        else:
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            # and throws an error without this change.
+            self.after_norm = identity
 
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False
+        self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -115,8 +117,10 @@ class Conformer(Transformer):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup_mode=warmup_mode)  # (T, N, C)
+        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
+
+        if self.normalize_before:
+            x = self.after_norm(x)
 
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -154,41 +158,42 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.d_model = d_model
-
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
         self.feed_forward = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1),
-            DoubleSwish(),
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            nn.Linear(dim_feedforward, d_model),
         )
         self.feed_forward_macaron = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1),
-            DoubleSwish(),
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            nn.Linear(dim_feedforward, d_model),
         )
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
-        self.norm_final = BasicNorm(d_model)
+        self.ff_scale = 0.5
 
-        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(channel_dim=-1,
-                                           min_positive=0.45,
-                                           max_positive=0.55,
-                                           max_abs=6.0)
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
+        self.normalize_before = normalize_before
 
     def forward(
         self,
@@ -215,10 +220,19 @@ class ConformerEncoderLayer(nn.Module):
         """
 
         # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
-
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff_macaron(src)
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
+        if not self.normalize_before:
+            src = self.norm_ff_macaron(src)
 
         # multi-headed self-attention module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -227,15 +241,28 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = src + self.dropout(src_att)
+        src = residual + self.dropout(src_att)
+        if not self.normalize_before:
+            src = self.norm_mha(src)
 
         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(src))
+        if not self.normalize_before:
+            src = self.norm_conv(src)
 
         # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff(src)
+        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+        if not self.normalize_before:
+            src = self.norm_ff(src)
 
-        src = self.norm_final(self.balancer(src))
+        if self.normalize_before:
+            src = self.norm_final(src)
 
         return src
@@ -255,20 +282,12 @@ class ConformerEncoder(nn.Module):
     >>> out = conformer_encoder(src, pos_emb)
     """
 
-    def __init__(self, encoder_layer: nn.Module, num_layers: int,
-                 aux_layers: Sequence[int]) -> None:
+    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
        super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
-        self.aux_layers = set(aux_layers + [num_layers - 1])
-        assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
-        num_channels = encoder_layer.d_model
-        self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
-                                      final_weight=0.5,
-                                      pure_prob=0.333,
-                                      stddev=2.0)
 
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup_mode: bool = False
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -296,19 +314,14 @@ class ConformerEncoder(nn.Module):
         """
         output = src
 
-        outputs = []
-
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-            if i in self.aux_layers:
-                outputs.append(output)
-        output = self.combiner(outputs, warmup_mode)
 
         return output
@@ -331,6 +344,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
@@ -382,6 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """
         self.extend_pe(x)
+        x = x * self.xscale
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
@@ -413,7 +428,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        scale_speed: float = 5.0
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -424,29 +438,25 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"
 
-        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
 
         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.scale_speed = scale_speed
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+
         self._reset_parameters()
 
-    def _pos_bias_u(self):
-        return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
-
     def _reset_parameters(self) -> None:
-        nn.init.normal_(self.pos_bias_u, std=0.05)
-        nn.init.normal_(self.pos_bias_v, std=0.05)
+        nn.init.xavier_uniform_(self.in_proj.weight)
+        nn.init.constant_(self.in_proj.bias, 0.0)
+        nn.init.constant_(self.out_proj.bias, 0.0)
+
+        nn.init.xavier_uniform_(self.pos_bias_u)
+        nn.init.xavier_uniform_(self.pos_bias_v)
 
     def forward(
         self,
@@ -506,11 +516,11 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.get_weight(),
-            self.in_proj.get_bias(),
+            self.in_proj.weight,
+            self.in_proj.bias,
             self.dropout,
-            self.out_proj.get_weight(),
-            self.out_proj.get_bias(),
+            self.out_proj.weight,
+            self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
@@ -614,12 +624,13 @@ class RelPositionMultiheadAttention(nn.Module):
         assert (
             head_dim * num_heads == embed_dim
         ), "embed_dim must be divisible by num_heads"
-
         scaling = float(head_dim) ** -0.5
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -651,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
 
-
         # This is inline in_proj function with in_proj_weight and in_proj_bias
         _b = in_proj_bias
         _start = embed_dim
@@ -670,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -720,7 +729,7 @@ class RelPositionMultiheadAttention(nn.Module):
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
-        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
@@ -741,11 +750,11 @@ class RelPositionMultiheadAttention(nn.Module):
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
 
-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+        q_with_bias_u = (q + self.pos_bias_u).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
 
-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+        q_with_bias_v = (q + self.pos_bias_v).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
@@ -765,7 +774,7 @@ class RelPositionMultiheadAttention(nn.Module):
 
         attn_output_weights = (
             matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        ) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
@@ -840,7 +849,7 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
 
-        self.pointwise_conv1 = ScaledConv1d(
+        self.pointwise_conv1 = nn.Conv1d(
             channels,
             2 * channels,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=bias,
         )
-
-        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
-        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
-        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
-        # between 50 and 100 for different channels.  This will cause very peaky and
-        # sparse derivatives for the sigmoid gating function, which will tend to make
-        # the loss function not learn effectively.  (for most layers the average absolute values
-        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
-        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
-        # layers, which likely breaks down as 0.5 for the "linear" half and
-        # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
-        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
-        # it will be in a better position to start learning something, i.e. to latch onto
-        # the correct range.
-        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
-
-        self.depthwise_conv = ScaledConv1d(
+        self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
             stride=1,
             padding=(kernel_size - 1) // 2,
             groups=channels,
             bias=bias,
         )
-
-        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
-
-        self.activation = DoubleSwish()
-
-        self.pointwise_conv2 = ScaledConv1d(
+        self.norm = nn.LayerNorm(channels)
+        self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25
         )
+        self.activation = Swish()
 
     def forward(self, x: Tensor) -> Tensor:
         """Compute convolution module.
@@ -907,14 +892,15 @@ class ConvolutionModule(nn.Module):
         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
-
-        x = self.deriv_balancer1(x)
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
 
-        x = self.deriv_balancer2(x)
         x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
@@ -922,197 +908,13 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)
 
 
-class Identity(torch.nn.Module):
+class Swish(torch.nn.Module):
+    """Construct an Swish object."""
+
     def forward(self, x: Tensor) -> Tensor:
-        return x
+        """Return Swich activation function."""
+        return x * torch.sigmoid(x)
 
-
-class RandomCombine(torch.nn.Module):
-    """
-    This module combines a list of Tensors, all with the same shape, to
-    produce a single output of that same shape which, in training time,
-    is a random combination of all the inputs; but which in test time
-    will be just the last input.
-
-    The idea is that the list of Tensors will be a list of outputs of multiple
-    conformer layers.  This has a similar effect as iterated loss. (See:
-    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
-    NETWORKS).
-    """
-    def __init__(self, num_inputs: int,
-                 final_weight: float = 0.5,
-                 pure_prob: float = 0.5,
-                 stddev: float = 2.0) -> None:
-        """
-        Args:
-          num_inputs: The number of tensor inputs, which equals the number of layers'
-                outputs that are fed into this module.  E.g. in an 18-layer neural
-                net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight: The amount of weight or probability we assign to the
-                final layer when randomly choosing layers or when choosing
-                continuous layer weights.
-          pure_prob: The probability, on each frame, with which we choose
-                only a single layer to output (rather than an interpolation)
-          stddev: A standard deviation that we add to log-probs for computing
-                randomized weights.
-
-        The method of choosing which layers,
-        or combinations of layers, to use, is conceptually as follows.
-            With probability `pure_prob`:
-               With probability `final_weight`: choose final layer,
-               Else: choose random non-final layer.
-            Else:
-               Choose initial log-weights that correspond to assigning
-               weight `final_weight` to the final layer and equal
-               weights to other layers; then add Gaussian noise
-               with variance `stddev` to these log-weights, and normalize
-               to weights (note: the average weight assigned to the
-               final layer here will not be `final_weight` if stddev>0).
- """ - super(RandomCombine, self).__init__() - assert pure_prob >= 0 and pure_prob <= 1 - assert final_weight > 0 and final_weight < 1 - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev= stddev - - self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - - - def forward(self, inputs: Sequence[Tensor], - warmup_mode: bool) -> Tensor: - """ - Forward function. - Args: - inputs: a list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - a Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not (self.training and warmup_mode): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, - num_frames) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - - def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: - """ - Return a tensor of random weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), such that - ans.sum(dim=1) is all ones. - - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) - - def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with - exactly one weight equal to 1.0 on each frame. - """ - - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. 
-        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
-
-        indexes = torch.where(torch.rand(num_frames, device=device) < final_prob,
-                              final, nonfinal)
-        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype)
-        return ans
-
-
-    def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-           dtype: the data-type desired for the answer, e.g. float, double
-           device: the device needed for the answer
-           num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that
-           sum to one over the second axis, i.e. ans.sum(dim=1) is all ones.
-        """
-        logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev
-        logprobs[:,-1] += self.final_log_weight
-        return logprobs.softmax(dim=1)
-
-
-def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
-    print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
-    num_inputs = 3
-    num_channels = 50
-    m = RandomCombine(num_inputs=num_inputs,
-                      final_weight=final_weight,
-                      pure_prob=pure_prob,
-                      stddev=stddev)
-
-    x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ]
-
-    y = m(x, True)
-    assert y.shape == x[0].shape
-    assert torch.allclose(y, x[0])  # .. since actually all ones.
-
-
-if __name__ == '__main__':
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.0)
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.3)
-    _test_random_combine(0.5, 1, 0.3)
-    _test_random_combine(0.5, 0.5, 0.3)
-
-    feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
-    batch_size = 5
-    seq_len = 20
-    # Just make sure the forward pass runs.
-    f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64),
-          warmup_mode=True)
+def identity(x):
+    return x
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index db51fb1cd..b82fed37b 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -17,9 +17,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional
-from subsampling import ScaledConv1d
 
 
 class Decoder(nn.Module):
@@ -55,7 +52,7 @@ class Decoder(nn.Module):
             1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
-        self.embedding = ScaledEmbedding(
+        self.embedding = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
@@ -65,7 +62,7 @@ class Decoder(nn.Module):
         assert context_size >= 1, context_size
         self.context_size = context_size
         if context_size > 1:
-            self.conv = ScaledConv1d(
+            self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
                 out_channels=embedding_dim,
                 kernel_size=context_size,
@@ -85,7 +82,6 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
""" - y = y.to(torch.int64) embedding_out = self.embedding(y) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) @@ -100,139 +96,3 @@ class Decoder(nn.Module): embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) return embedding_out - - - -class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. 
-
-    Examples::
-
-        >>> # an Embedding module containing 10 tensors of size 3
-        >>> embedding = nn.Embedding(10, 3)
-        >>> # a batch of 2 samples of 4 indices each
-        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
-        >>> embedding(input)
-        tensor([[[-0.0251, -1.6902,  0.7172],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 1.4970,  1.3448, -0.9685],
-                 [-0.3677, -2.7265, -0.1685]],
-
-                [[ 1.4970,  1.3448, -0.9685],
-                 [ 0.4362, -0.4004,  0.9400],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 0.9124, -2.3616,  1.1151]]])
-
-
-        >>> # example with padding_idx
-        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
-        >>> input = torch.LongTensor([[0,2,0,5]])
-        >>> embedding(input)
-        tensor([[[ 0.0000,  0.0000,  0.0000],
-                 [ 0.1535, -2.0309,  0.9315],
-                 [ 0.0000,  0.0000,  0.0000],
-                 [-0.1655,  0.9897,  0.0635]]])
-    """
-    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
-                     'scale_grad_by_freq', 'sparse']
-
-    num_embeddings: int
-    embedding_dim: int
-    padding_idx: int
-    scale_grad_by_freq: bool
-    weight: Tensor
-    sparse: bool
-
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 scale_speed: float = 5.0) -> None:
-        super(ScaledEmbedding, self).__init__()
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        if padding_idx is not None:
-            if padding_idx > 0:
-                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
-            elif padding_idx < 0:
-                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
-                padding_idx = self.num_embeddings + padding_idx
-        self.padding_idx = padding_idx
-        self.scale_grad_by_freq = scale_grad_by_freq
-
-        self.scale_speed = scale_speed
-        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
-        self.sparse = sparse
-
-        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
-
-
-    def reset_parameters(self) -> None:
-        nn.init.normal_(self.weight, std=0.05)
-        nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
-
-        if self.padding_idx is not None:
-            with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = (self.scale * self.scale_speed).exp()
-        if input.numel() < self.num_embeddings:
-            return F.embedding(
-                input, self.weight, self.padding_idx,
-                None, 2.0,  # None, 2.0 relate to normalization
-                self.scale_grad_by_freq, self.sparse) * scale
-        else:
-            return F.embedding(
-                input, self.weight * scale, self.padding_idx,
-                None, 2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq, self.sparse)
-
-    def extra_repr(self) -> str:
-        s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
-        if self.padding_idx is not None:
-            s += ', padding_idx={padding_idx}'
-        if self.scale_grad_by_freq is not False:
-            s += ', scale_grad_by_freq={scale_grad_by_freq}'
-        if self.sparse is not False:
-            s += ', sparse=True'
-        return s.format(**self.__dict__)
diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py
index 3d218dcd0..257facce4 100644
--- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py
+++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py
@@ -22,7 +22,7 @@ import torch.nn as nn
 
 class EncoderInterface(nn.Module):
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor
+        self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 241f405b6..b0ba7fd83 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -16,7 +16,7 @@
 
 import torch
 import torch.nn as nn
-from subsampling import ScaledLinear
+
 
 class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
@@ -24,7 +24,7 @@ class Joiner(nn.Module):
         self.input_dim = input_dim
         self.output_dim = output_dim
 
-        self.output_linear = ScaledLinear(input_dim, output_dim)
+        self.output_linear = nn.Linear(input_dim, output_dim)
 
     def forward(
         self,
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index fc16f2631..8281e1fb5 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -65,7 +65,6 @@ class Transducer(nn.Module):
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
         modified_transducer_prob: float = 0.0,
-        warmup_mode: bool = False
     ) -> torch.Tensor:
         """
         Args:
@@ -88,7 +87,7 @@ class Transducer(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode)
+        encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index fa0410973..d6827c17c 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -111,8 +111,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization..
-        default="transducer_stateless/randcombine1_expscale3_rework2d",
+        default="transducer_stateless/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -223,7 +222,6 @@ def get_params() -> AttributeDict:
         "log_interval": 50,
         "reset_interval": 200,
         "valid_interval": 3000,  # For the 100h subset, use 800
-        "warmup_minibatches": 3000,  # use warmup mode for 3k minibatches.
         # parameters for conformer
         "feature_dim": 80,
         "encoder_out_dim": 512,
@@ -381,7 +379,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    is_warmup_mode: bool = False
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -418,7 +415,6 @@ def compute_loss(
             x_lens=feature_lens,
             y=y,
             modified_transducer_prob=params.modified_transducer_prob,
-            warmup_mode=is_warmup_mode
         )
 
     assert loss.requires_grad == is_training
@@ -455,7 +451,6 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
-            is_warmup_mode=False
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -517,7 +512,6 @@ def train_one_epoch(
             sp=sp,
             batch=batch,
             is_training=True,
-            is_warmup_mode=(params.batch_idx_train