diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 07fe934ae..b68aced9f 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -88,7 +88,7 @@ class Conformer(Transformer): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -112,7 +112,8 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask, + warmup_mode=warmup_mode) # (T, N, C) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -258,7 +259,6 @@ class ConformerEncoder(nn.Module): self.num_layers = num_layers num_channels = encoder_layer.d_model self.combiner = RandomCombine(num_inputs=len(self.aux_layers), - num_channels=num_channels, final_weight=0.5, pure_prob=0.333, stddev=2.0) @@ -269,6 +269,7 @@ class ConformerEncoder(nn.Module): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + warmup_mode: bool = False ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -300,7 +301,7 @@ class ConformerEncoder(nn.Module): if i in self.aux_layers: outputs.append(output) - output = self.combiner(outputs) + output = self.combiner(outputs, warmup_mode) return output @@ -946,17 +947,12 @@ class RandomCombine(torch.nn.Module): is a random combination of all the inputs; but which in test time will be just the last input. - All but the last input will have a linear transform before we - randomly combine them; these linear transforms will be initialzed - to the identity transform. - The idea is that the list of Tensors will be a list of outputs of multiple conformer layers. This has a similar effect as iterated loss. (See: DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER NETWORKS). """ def __init__(self, num_inputs: int, - num_channels: int, final_weight: float = 0.5, pure_prob: float = 0.5, stddev: float = 2.0) -> None: @@ -965,7 +961,6 @@ class RandomCombine(torch.nn.Module): num_inputs: The number of tensor inputs, which equals the number of layers' outputs that are fed into this module. E.g. in an 18-layer neural net if we output layers 16, 12, 18, num_inputs would be 3. - num_channels: The number of channels on the input, e.g. 512. final_weight: The amount of weight or probability we assign to the final layer when randomly choosing layers or when choosing continuous layer weights. @@ -991,8 +986,6 @@ class RandomCombine(torch.nn.Module): assert pure_prob >= 0 and pure_prob <= 1 assert final_weight > 0 and final_weight < 1 assert num_inputs >= 1 - self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True) - for _ in range(num_inputs - 1)]) self.num_inputs = num_inputs self.final_weight = final_weight @@ -1000,14 +993,10 @@ class RandomCombine(torch.nn.Module): self.stddev= stddev self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - self._reset_parameters() - def _reset_parameters(self): - for i in range(len(self.linear)): - nn.init.eye_(self.linear[i].weight) - nn.init.constant_(self.linear[i].bias, 0.0) - def forward(self, inputs: Sequence[Tensor]) -> Tensor: + def forward(self, inputs: Sequence[Tensor], + warmup_mode: bool) -> Tensor: """ Forward function. Args: @@ -1019,24 +1008,18 @@ class RandomCombine(torch.nn.Module): """ num_inputs = self.num_inputs assert len(inputs) == num_inputs - if not self.training: + if not (self.training and warmup_mode): return inputs[-1] # Shape of weights: (*, num_inputs) num_channels = inputs[0].shape[-1] num_frames = inputs[0].numel() // num_channels - mod_inputs = [] - for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) - mod_inputs.append(inputs[num_inputs - 1]) - - ndim = inputs[0].ndim # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) # weights: (num_frames, num_inputs) weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, @@ -1118,12 +1101,14 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") num_inputs = 3 num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels, - final_weight=final_weight, pure_prob=pure_prob, stddev=stddev) + m = RandomCombine(num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev) x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - y = m(x) + y = m(x, True) assert y.shape == x[0].shape assert torch.allclose(y, x[0]) # .. since actually all ones. diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py index 257facce4..b295ce94b 100644 --- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py +++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py @@ -22,7 +22,7 @@ import torch.nn as nn class EncoderInterface(nn.Module): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -32,6 +32,8 @@ class EncoderInterface(nn.Module): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. + warmup_mode: for training only, if true then train in + "warmup mode" (use this for the first few thousand minibatches). Returns: Return a tuple containing two tensors: - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 17b5f63e5..a45f0e295 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -62,6 +62,7 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + warmup_mode: bool = False ) -> torch.Tensor: """ Args: @@ -82,7 +83,7 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens) + encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index f2d89b099..6c318c242 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -203,6 +203,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 + "warmup_minibatches": 3000, # use warmup mode for 3k minibatches. # parameters for conformer "feature_dim": 80, "encoder_out_dim": 512, @@ -360,6 +361,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + is_warmup_mode: bool = False ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -391,7 +393,8 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + loss = model(x=feature, x_lens=feature_lens, y=y, + warmup_mode=is_warmup_mode) assert loss.requires_grad == is_training @@ -423,6 +426,7 @@ def compute_validation_loss( sp=sp, batch=batch, is_training=False, + is_warmup_mode=False ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -484,6 +488,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, + is_warmup_mode=(params.batch_idx_train