diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index d90dd34e1..83bcc3f3e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -88,7 +88,7 @@ class Conformer(Transformer):
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -97,6 +97,10 @@ class Conformer(Transformer):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames in
             `x` before padding.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
             - logits, its shape is (batch_size, output_seq_len, output_dim)
@@ -113,7 +117,7 @@ class Conformer(Transformer):
         mask = make_pad_mask(lengths)
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup_mode=warmup_mode)  # (T, N, C)
+                         warmup=warmup)  # (T, N, C)
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -193,6 +197,8 @@ class ConformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        warmup: float = 1.0,
+        position: float = 0.0
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -202,6 +208,11 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            warmup: controls selective activation of layers; if < 1.0, it's possible that
+               not all modules will be included.
+            position: the position of this module in the encoder stack (relates to
+               warmup); a value 0 <= position < 1.0.
+
         Shape:
             src: (S, N, E).
@@ -210,9 +221,9 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number
        """
-        # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
+        src = torch.add(src, self.dropout(self.feed_forward_macaron(src)),
+                        alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0))
         # multi-headed self-attention module
@@ -224,13 +235,16 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = src + self.dropout(src_att)
+        src = torch.add(src, self.dropout(src_att),
+                        alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0))
         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        src = torch.add(src, self.dropout(self.conv_module(src)),
+                        alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0))
         # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
+        src = torch.add(src, self.dropout(self.feed_forward(src)),
+                        alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0))
         src = self.norm_final(self.balancer(src))
@@ -262,10 +276,6 @@ class ConformerEncoder(nn.Module):
         assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
         num_channels = encoder_layer.d_model
-        self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
-                                      final_weight=0.5,
-                                      pure_prob=0.333,
-                                      stddev=2.0)
     def forward(
         self,
@@ -273,7 +283,7 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup_mode: bool = False
+        warmup: float = 1.0
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -293,7 +303,7 @@ class ConformerEncoder(nn.Module):
         """
         output = src
-        outputs = []
+        num_layers = len(self.layers)
         for i, mod in enumerate(self.layers):
             output = mod(
@@ -301,11 +311,10 @@ class ConformerEncoder(nn.Module):
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+                position=(i / num_layers),
             )
-            if i in self.aux_layers:
-                outputs.append(output)
-        output = self.combiner(outputs, warmup_mode)
         return output
@@ -922,187 +931,9 @@ class Identity(torch.nn.Module):
         return x
-class RandomCombine(torch.nn.Module):
-    """
-    This module combines a list of Tensors, all with the same shape, to
-    produce a single output of that same shape which, in training time,
-    is a random combination of all the inputs; but which in test time
-    will be just the last input.
-
-    The idea is that the list of Tensors will be a list of outputs of multiple
-    conformer layers. This has a similar effect as iterated loss. (See:
-    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
-    NETWORKS).
-    """
-    def __init__(self, num_inputs: int,
-                 final_weight: float = 0.5,
-                 pure_prob: float = 0.5,
-                 stddev: float = 2.0) -> None:
-        """
-        Args:
-          num_inputs: The number of tensor inputs, which equals the number of layers'
-                outputs that are fed into this module. E.g. in an 18-layer neural
-                net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight: The amount of weight or probability we assign to the
-                final layer when randomly choosing layers or when choosing
-                continuous layer weights.
-          pure_prob: The probability, on each frame, with which we choose
-                only a single layer to output (rather than an interpolation)
-          stddev: A standard deviation that we add to log-probs for computing
-                randomized weights.
-
-        The method of choosing which layers,
-        or combinations of layers, to use, is conceptually as follows.
-          With probability `pure_prob`:
-             With probability `final_weight`: choose final layer,
-             Else: choose random non-final layer.
-          Else:
-             Choose initial log-weights that correspond to assigning
-             weight `final_weight` to the final layer and equal
-             weights to other layers; then add Gaussian noise
-             with variance `stddev` to these log-weights, and normalize
-             to weights (note: the average weight assigned to the
-             final layer here will not be `final_weight` if stddev>0).
-        """
-        super(RandomCombine, self).__init__()
-        assert pure_prob >= 0 and pure_prob <= 1
-        assert final_weight > 0 and final_weight < 1
-        assert num_inputs >= 1
-
-        self.num_inputs = num_inputs
-        self.final_weight = final_weight
-        self.pure_prob = pure_prob
-        self.stddev= stddev
-
-        self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
-
-
-    def forward(self, inputs: Sequence[Tensor],
-                warmup_mode: bool) -> Tensor:
-        """
-        Forward function.
-        Args:
-          inputs: a list of Tensor, e.g. from various layers of a transformer.
-              All must be the same shape, of (*, num_channels)
-        Returns:
-          a Tensor of shape (*, num_channels). In test mode
-          this is just the final input.
-        """
-        num_inputs = self.num_inputs
-        assert len(inputs) == num_inputs
-        if not (self.training and warmup_mode):
-            return inputs[-1]
-
-        # Shape of weights: (*, num_inputs)
-        num_channels = inputs[0].shape[-1]
-        num_frames = inputs[0].numel() // num_channels
-
-        ndim = inputs[0].ndim
-        # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames,
-                                                                num_channels,
-                                                                num_inputs))
-
-        # weights: (num_frames, num_inputs)
-        weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
-                                           num_frames)
-
-        weights = weights.reshape(num_frames, num_inputs, 1)
-        # ans: (num_frames, num_channels, 1)
-        ans = torch.matmul(stacked_inputs, weights)
-        # ans: (*, num_channels)
-        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
-
-        if __name__ == "__main__":
-            # for testing only...
-            print("Weights = ", weights.reshape(num_frames, num_inputs))
-        return ans
-
-
-    def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
-        """
-        Return a tensor of random weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), such that
-          ans.sum(dim=1) is all ones.
-
-        """
-        pure_prob = self.pure_prob
-        if pure_prob == 0.0:
-            return self._get_random_mixed_weights(dtype, device, num_frames)
-        elif pure_prob == 1.0:
-            return self._get_random_pure_weights(dtype, device, num_frames)
-        else:
-            p = self._get_random_pure_weights(dtype, device, num_frames)
-            m = self._get_random_mixed_weights(dtype, device, num_frames)
-            return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m)
-    def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with
-          exactly one weight equal to 1.0 on each frame.
- """ - - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) - - indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, - final, nonfinal) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) - return ans - - - def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that - sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. - """ - logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev - logprobs[:,-1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") - num_inputs = 3 - num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev) - - x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - - y = m(x, True) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. if __name__ == '__main__': - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) batch_size = 5 @@ -1110,4 +941,4 @@ if __name__ == '__main__': # Just make sure the forward pass runs. f = c(torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup_mode=True) + warmup=0.5) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index e83d18e3e..faaebc477 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,7 +66,7 @@ class Transducer(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - warmup_mode: bool = False + warmup: float = 1.0, ) -> torch.Tensor: """ Args: @@ -87,6 +87,9 @@ class Transducer(nn.Module): lm_scale: The scale to smooth the loss with lm (output of predictor network) part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. Returns: Return the transducer loss. 
@@ -102,7 +105,7 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode=warmup_mode)
+        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
         assert torch.all(x_lens > 0)
         # Now for the decoder, i.e., the prediction network
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 096f93d77..d4a2e83d5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -296,7 +296,7 @@ def get_params() -> AttributeDict:
             "embedding_dim": 512,
             # parameters for Noam
             "warm_step": 60000,  # For the 100h subset, use 8k
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "model_warm_step": 4000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -454,7 +454,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup_mode: bool = False
+    warmup: float = 1.0
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -471,6 +471,8 @@ def compute_loss(
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
+      warmup: a floating point value that increases throughout training;
+        values >= 1.0 mean the model is fully warmed up, with all modules active.
     """
     device = model.device
     feature = batch["inputs"]
@@ -493,10 +495,10 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            warmup_mode=warmup_mode,
+            warmup=warmup,
         )
         loss = (params.simple_loss_scale * simple_loss +
-                (pruned_loss * 0.0 if warmup_mode else pruned_loss))
+                (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss))
     assert loss.requires_grad == is_training
@@ -601,7 +603,7 @@ def train_one_epoch(
             sp=sp,
             batch=batch,
             is_training=True,
-            warmup_mode=(params.batch_idx_train < params.model_warm_step)
+            warmup=(params.batch_idx_train / params.model_warm_step)
         )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -855,7 +857,6 @@ def scan_pessimistic_batches_for_oom(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup_mode=True  # may use slightly more memory
             )
             loss.backward()
             optimizer.step()
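
Taken together, the conformer.py changes replace the RandomCombine mixing of auxiliary layer outputs with a per-branch bypass: each residual branch inside ConformerEncoderLayer is added with alpha=0.0 until warmup crosses 0.2 * (position + k), where position = i / num_layers and k counts the sub-module within the layer (macaron feed-forward, self-attention, convolution, final feed-forward), so branches switch on progressively during early training. The standalone sketch below reproduces that gating outside the recipe; the helper name gated_residual and the tensor shapes are illustrative, not part of the patch.

import torch


def gated_residual(src: torch.Tensor, branch_out: torch.Tensor,
                   warmup: float, position: float, k: int) -> torch.Tensor:
    # Mirrors the torch.add(..., alpha=...) calls added above: the branch is
    # bypassed entirely (alpha=0.0) until warmup reaches 0.2 * (position + k).
    alpha = 0.0 if warmup < 0.2 * (position + k) else 1.0
    return torch.add(src, branch_out, alpha=alpha)


src = torch.randn(10, 2, 8)     # (seq_len, batch, channels); shapes are arbitrary
branch = torch.randn(10, 2, 8)  # stand-in for e.g. a feed-forward branch output

early = gated_residual(src, branch, warmup=0.1, position=0.0, k=1)
late = gated_residual(src, branch, warmup=1.0, position=0.0, k=1)

assert torch.equal(early, src)             # branch is switched off early in training
assert torch.allclose(late, src + branch)  # plain residual connection once warmed up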
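On the training side, the boolean warmup_mode flag becomes a ratio: compute_loss now receives warmup = batch_idx_train / model_warm_step (with model_warm_step raised from 3000 to 4000), and the pruned loss is zeroed out until that ratio reaches 1.0. A minimal sketch of the resulting schedule follows; the loss values and the 0.5 scale are placeholders, not numbers taken from the recipe.

model_warm_step = 4000   # as set in get_params() above
simple_loss_scale = 0.5  # placeholder for params.simple_loss_scale


def total_loss(simple_loss: float, pruned_loss: float, batch_idx_train: int) -> float:
    warmup = batch_idx_train / model_warm_step
    # The pruned loss only contributes once the model is fully warmed up.
    return simple_loss_scale * simple_loss + (
        pruned_loss * 0.0 if warmup < 1.0 else pruned_loss
    )


print(total_loss(10.0, 8.0, batch_idx_train=1000))  # 5.0  -> pruned loss still ignored
print(total_loss(10.0, 8.0, batch_idx_train=4000))  # 13.0 -> warmup == 1.0, both terms count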