diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py index d4e6208b3..54787a126 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py @@ -493,6 +493,8 @@ class AlignmentAttentionModule(nn.Module): embed_dim=pos_dim, dropout_rate=0.15 ) + self.dropout = nn.Dropout(p=0.5) + def forward( self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor ) -> Tensor: diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py index 4fcfc002c..ebade0954 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py @@ -77,12 +77,13 @@ class Joiner(nn.Module): if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + + if apply_attn: + # print(attn_encoder_out) + logit = encoder_out + decoder_out + attn_encoder_out else: - if apply_attn: - logit = encoder_out + decoder_out + attn_encoder_out - else: - # logging.info("disabling cross attn mdl") - logit = encoder_out + decoder_out + logging.info("disabling cross attn mdl") + logit = encoder_out + decoder_out logit = self.output_linear(torch.tanh(logit)) diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py index b8a5a8b79..cb863d30d 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py @@ -84,6 +84,7 @@ class AsrModel(nn.Module): self.encoder_embed = encoder_embed self.encoder = encoder + self.dropout = nn.Dropout(p=0.5) self.use_transducer = use_transducer if use_transducer: @@ -263,7 +264,7 @@ class AsrModel(nn.Module): lm=self.joiner.decoder_proj(decoder_out), ranges=ranges, ) - + am_pruned = self.dropout(am_pruned) # logits : [B, T, prune_range, vocab_size] # project_input=False since we applied the decoder's input projections @@ -274,7 +275,7 @@ class AsrModel(nn.Module): lm_pruned, None, encoder_out_lens, - apply_attn=batch_idx_train > self.params.warm_step, # True, # batch_idx_train > self.params.warm_step, + apply_attn=True, # batch_idx_train > self.params.warm_step, project_input=False, ) diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/train.py b/egs/librispeech/ASR/zipformer_label_level_algn/train.py index 0d11a4bae..724916364 100755 --- a/egs/librispeech/ASR/zipformer_label_level_algn/train.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/train.py @@ -1236,14 +1236,14 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: