From 34639d52498266f4674e81b33bd56f4fc04ca2a7 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Tue, 3 Jun 2025 21:45:47 +0800
Subject: [PATCH] use padding instead of trimming (suggested by @shylockasr)

use ctc compress (suggested by @shylockasr)

fix

revert

revert

revert

---
 egs/speech_llm/ASR_LLM/.gitignore             |  3 +
 .../ASR_LLM/whisper_llm_zh/multi_dataset.py   | 92 +------------------
 .../ASR_LLM/zipformer_llm_zh/model.py         | 62 ++++++++++++-
 3 files changed, 62 insertions(+), 95 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/.gitignore b/egs/speech_llm/ASR_LLM/.gitignore
index 604f0f2cf..72ea9549c 100644
--- a/egs/speech_llm/ASR_LLM/.gitignore
+++ b/egs/speech_llm/ASR_LLM/.gitignore
@@ -1 +1,4 @@
 models
+train*.sh
+decode*.sh
+sync*.sh
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index d116857af..3c960c716 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -47,103 +47,13 @@ class MultiDataset:
     def train_cuts(self) -> CutSet:
         logging.info("About to get multidataset train cuts")
 
-        # THCHS-30
-        logging.info("Loading THCHS-30 in lazy mode")
-        thchs_30_cuts = load_manifest_lazy(
-            self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
-        )
-
-        # AISHELL-1
-        logging.info("Loading Aishell-1 in lazy mode")
-        aishell_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
-        )
-
-        # AISHELL-2
-        logging.info("Loading Aishell-2 in lazy mode")
-        aishell_2_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
-        )
-
-        # AISHELL-4
-        logging.info("Loading Aishell-4 in lazy mode")
-        aishell_4_L_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
-        )
-        aishell_4_M_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
-        )
-        aishell_4_S_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
-        )
-
-        # ST-CMDS
-        logging.info("Loading ST-CMDS in lazy mode")
-        stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")
-
-        # Primewords
-        logging.info("Loading Primewords in lazy mode")
-        primewords_cuts = load_manifest_lazy(
-            self.fbank_dir / "primewords_cuts_train.jsonl.gz"
-        )
-
-        # MagicData
-        logging.info("Loading MagicData in lazy mode")
-        magicdata_cuts = load_manifest_lazy(
-            self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
-        )
-
-        # Ali-Meeting
-        logging.info("Loading Ali-Meeting in lazy mode")
-        alimeeting_cuts = load_manifest_lazy(
-            self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
-        )
-
         # WeNetSpeech
         logging.info("Loading WeNetSpeech in lazy mode")
         wenetspeech_L_cuts = load_manifest_lazy(
             self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
         )
 
-        # KeSpeech
-        logging.info("Loading KeSpeech in lazy mode")
-        kespeech_1_cuts = load_manifest_lazy(
-            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
-        )
-        kespeech_2_cuts = load_manifest_lazy(
-            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
-        )
-
-        return CutSet.mux(
-            thchs_30_cuts,
-            aishell_cuts,
-            aishell_2_cuts,
-            aishell_4_L_cuts,
-            aishell_4_M_cuts,
-            aishell_4_S_cuts,
-            alimeeting_cuts,
-            stcmds_cuts,
-            primewords_cuts,
-            magicdata_cuts,
-            wenetspeech_L_cuts,
-            kespeech_1_cuts,
-            kespeech_2_cuts,
-            weights=[
-                len(thchs_30_cuts),
-                len(aishell_cuts),
-                len(aishell_2_cuts),
-                len(aishell_4_L_cuts),
-                len(aishell_4_M_cuts),
-                len(aishell_4_S_cuts),
-                len(alimeeting_cuts),
-                len(stcmds_cuts),
-                len(primewords_cuts),
-                len(magicdata_cuts),
-                len(wenetspeech_L_cuts),
-                len(kespeech_1_cuts),
-                len(kespeech_2_cuts),
-            ],
-        )
+        return wenetspeech_L_cuts
 
     def dev_cuts(self) -> CutSet:
         logging.info("About to get multidataset dev cuts")
diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py
index 5f0d4b8e5..d585ec871 100644
--- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py
@@ -30,9 +30,11 @@ class EncoderProjector(nn.Module):
 
     def forward(self, x):
         batch_size, seq_len, feat_dim = x.size()
-        num_frames_to_discard = seq_len % self.downsample_rate
-        if num_frames_to_discard > 0:
-            x = x[:, :-num_frames_to_discard, :]
+        num_padding_frames = (
+            self.downsample_rate - seq_len % self.downsample_rate
+        ) % self.downsample_rate
+        if num_padding_frames > 0:
+            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
         seq_len = x.size(1)
 
         x = x.contiguous()
@@ -62,6 +64,7 @@ class SPEECH_LLM(nn.Module):
         self,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
+        ctc_output: nn.Module,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
@@ -230,6 +233,57 @@ class SPEECH_LLM(nn.Module):
 
         return encoder_out, encoder_out_lens
 
+    def ctc_compress(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        blank_id: int = 0,
+    ) -> torch.Tensor:
+        """
+        Remove frames from encoder_out where CTC argmax predicts blank.
+        Args:
+            encoder_out: Tensor of shape (N, T, C), encoder output.
+            encoder_out_lens: Tensor of shape (N,), lengths before padding.
+            blank_id: CTC blank token ID (default: 0).
+
+        Returns:
+            Compressed CTC output of shape (N, T', C).
+        """
+        # 1. Compute CTC argmax predictions
+        ctc_output = self.ctc_output(encoder_out)
+        ctc_preds = ctc_output.argmax(dim=-1)
+
+        # 2. Create non-blank, non-pad mask
+        padding_mask = make_pad_mask(encoder_out_lens)
+        non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)
+
+        # 3. Compute lengths after compress
+        compressed_lens = non_blank_mask.sum(dim=1)
+        max_len = compressed_lens.max().item()
+
+        # 4. Pre-pad output
+        pad_lens_list = (
+            torch.full_like(
+                compressed_lens,
+                max_len,
+                device=ctc_output.device,
+            )
+            - compressed_lens
+        )
+        max_pad_len = int(pad_lens_list.max())
+        padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])
+
+        # 5. Create final mask
+        padding_mask = ~make_pad_mask(pad_lens_list)
+        total_mask = torch.concat([non_blank_mask, padding_mask], dim=1)
+
+        # 6. Apply mask and reshape
+        compressed_output = padded_ctc_output[total_mask].reshape(
+            ctc_output.shape[0], -1, ctc_output.shape[2]
+        )
+
+        return compressed_output
+
     def forward(
         self,
         fbank: torch.Tensor,
@@ -238,7 +292,7 @@ class SPEECH_LLM(nn.Module):
         attention_mask: torch.Tensor,
         labels: torch.LongTensor,
     ):
-        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
+        encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)
 
         speech_features = self.encoder_projector(encoder_outs)
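
For reference, below is a minimal, self-contained sketch of the blank-removal trick that the new ctc_compress method implements: argmax the CTC posteriors, mask out blank and padded frames, then right-pad each utterance so the batch stays rectangular. The names toy_make_pad_mask and compress_blanks, and the random inputs, are illustrative stand-ins and are not part of the icefall codebase; this is not the patch's exact code.

# Illustrative sketch only; toy_make_pad_mask and compress_blanks are hypothetical names.
import torch


def toy_make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True at padding positions, shape (N, max_len).
    return torch.arange(max_len)[None, :] >= lengths[:, None]


def compress_blanks(ctc_output: torch.Tensor, lengths: torch.Tensor, blank_id: int = 0):
    # ctc_output: (N, T, C) CTC posteriors; lengths: (N,) valid frame counts.
    N, T, C = ctc_output.shape
    preds = ctc_output.argmax(dim=-1)                              # (N, T)
    keep = (preds != blank_id) & ~toy_make_pad_mask(lengths, T)    # non-blank, non-pad
    kept_lens = keep.sum(dim=1)
    max_kept = int(kept_lens.max())

    # Right-pad the time axis, then mark enough trailing pad frames per row
    # so every utterance selects exactly max_kept frames.
    pad_counts = max_kept - kept_lens
    padded = torch.nn.functional.pad(ctc_output, (0, 0, 0, int(pad_counts.max())))
    pad_keep = ~toy_make_pad_mask(pad_counts, padded.shape[1] - T)
    total_keep = torch.cat([keep, pad_keep], dim=1)
    return padded[total_keep].reshape(N, max_kept, C), kept_lens


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 10, 5)       # fake (N, T, C) posteriors
    lens = torch.tensor([10, 7])
    out, out_lens = compress_blanks(logits, lens)
    print(out.shape, out_lens)            # (2, max_kept, 5) and per-utterance kept lengths

Padding and boolean masking keeps the whole operation vectorized over the batch, which is presumably why the patch pads to a common length rather than looping over utterances.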