diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 5d5d38947..9582e34fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -907,9 +907,9 @@ def deprecated_greedy_search_batch_for_cross_attn(
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
-            attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out),
-            None,
-            apply_attn=True,
+            attn_encoder_out if t < 0 else torch.zeros_like(current_encoder_out),
+            encoder_out_lens,
+            apply_attn=False,
             project_input=False,
         )
         # logits'shape (batch_size, 1, 1, vocab_size)
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
index 205819544..4fcfc002c 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 import torch
 import torch.nn as nn
 from alignment_attention_module import AlignmentAttentionModule
@@ -34,6 +36,7 @@ class Joiner(nn.Module):
         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
+        self.enable_attn = False
 
     def forward(
         self,
@@ -64,7 +67,10 @@
             decoder_out.shape,
         )
 
-        if apply_attn and lengths is not None:
+        if apply_attn:
+            if not self.enable_attn:
+                self.enable_attn = True
+                logging.info("enabling ATTN!")
             attn_encoder_out = self.label_level_am_attention(
                 encoder_out, decoder_out, lengths
             )
@@ -72,7 +78,11 @@
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
-            logit = encoder_out + decoder_out + attn_encoder_out
+            if apply_attn:
+                logit = encoder_out + decoder_out + attn_encoder_out
+            else:
+                # logging.info("disabling cross attn mdl")
+                logit = encoder_out + decoder_out
 
         logit = self.output_linear(torch.tanh(logit))
 
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
index 53abdd21a..b8a5a8b79 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
@@ -24,12 +24,13 @@
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos, make_pad_mask, AttributeDict
 
 
 class AsrModel(nn.Module):
     def __init__(
         self,
+        params: AttributeDict,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
@@ -79,6 +80,8 @@
 
         assert isinstance(encoder, EncoderInterface), type(encoder)
 
+        self.params = params
+
         self.encoder_embed = encoder_embed
         self.encoder = encoder
 
@@ -180,6 +183,7 @@
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        batch_idx_train: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -264,12 +268,13 @@
 
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
+        # print(batch_idx_train)
         logits = self.joiner(
             am_pruned,
             lm_pruned,
             None,
             encoder_out_lens,
-            apply_attn=True,
+            apply_attn=batch_idx_train > self.params.warm_step,  # True,  # batch_idx_train > self.params.warm_step,
             project_input=False,
         )
 
@@ -293,6 +298,7 @@
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        batch_idx_train: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -345,6 +351,7 @@
                 prune_range=prune_range,
                 am_scale=am_scale,
                 lm_scale=lm_scale,
+                batch_idx_train=batch_idx_train,
             )
         else:
             simple_loss = torch.empty(0)
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/train.py b/egs/librispeech/ASR/zipformer_label_level_algn/train.py
index cb0ea8ef7..0d11a4bae 100755
--- a/egs/librispeech/ASR/zipformer_label_level_algn/train.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/train.py
@@ -622,6 +622,7 @@ def get_model(params: AttributeDict) -> nn.Module:
         joiner = None
 
     model = AsrModel(
+        params=params,
        encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
@@ -800,6 +801,7 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            batch_idx_train=batch_idx_train,
         )
 
         loss = 0.0
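
Note on the gating pattern (illustrative sketch, not part of the patch): the change threads the training batch counter from compute_loss() in train.py through AsrModel down to the Joiner, so the label-level cross-attention branch stays off during warm-up and only turns on once batch_idx_train exceeds params.warm_step. The sketch below shows the same pattern in isolation; ToyJoiner, its 3-D shapes, and nn.MultiheadAttention (standing in for the repo's AlignmentAttentionModule) are assumptions made for the sketch, not the actual icefall API.

    import torch
    import torch.nn as nn

    class ToyJoiner(nn.Module):
        # Stand-in for the icefall Joiner: nn.MultiheadAttention plays the
        # role of AlignmentAttentionModule, and inputs are plain 3-D
        # (batch, T, dim) tensors rather than the pruned 4-D shapes the
        # real joiner sees.
        def __init__(self, dim: int, vocab_size: int):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
            self.output_linear = nn.Linear(dim, vocab_size)

        def forward(self, am, lm, apply_attn: bool):
            logit = am + lm
            if apply_attn:
                # Cross-attend from label-level (lm) queries to acoustic (am)
                # keys/values and add the result, mirroring the patched
                # Joiner.forward when apply_attn is True.
                attn_out, _ = self.attn(lm, am, am)
                logit = logit + attn_out
            return self.output_linear(torch.tanh(logit))

    joiner = ToyJoiner(dim=512, vocab_size=500)
    am = torch.randn(8, 10, 512)  # (batch, T, dim)
    lm = torch.randn(8, 10, 512)
    warm_step = 2000  # placeholder; the real value comes from params.warm_step

    # Same gating expression the patch uses in AsrModel.forward_transducer:
    for batch_idx_train in (0, 3000):
        logits = joiner(am, lm, apply_attn=batch_idx_train > warm_step)
        print(batch_idx_train, logits.shape)  # branch off at 0, on at 3000

The one-shot enable_attn flag added in joiner.py fits the same scheme: it logs only the first batch on which the attention branch actually activates, instead of logging on every step.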