use padding instead of trimming (suggested by @shylockasr)

use ctc compress (suggested by @shylockasr)

fix

revert

revert

revert
Author: Yifan Yang, 2025-06-03 21:45:47 +08:00; committed by yfyeung
parent 05e3094429
commit 34639d5249
3 changed files with 62 additions and 95 deletions


@@ -1 +1,4 @@
 models
+train*.sh
+decode*.sh
+sync*.sh


@@ -47,103 +47,13 @@ class MultiDataset:
     def train_cuts(self) -> CutSet:
         logging.info("About to get multidataset train cuts")

-        # THCHS-30
-        logging.info("Loading THCHS-30 in lazy mode")
-        thchs_30_cuts = load_manifest_lazy(
-            self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
-        )
-
-        # AISHELL-1
-        logging.info("Loading Aishell-1 in lazy mode")
-        aishell_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
-        )
-
-        # AISHELL-2
-        logging.info("Loading Aishell-2 in lazy mode")
-        aishell_2_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
-        )
-
-        # AISHELL-4
-        logging.info("Loading Aishell-4 in lazy mode")
-        aishell_4_L_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
-        )
-        aishell_4_M_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
-        )
-        aishell_4_S_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
-        )
-
-        # ST-CMDS
-        logging.info("Loading ST-CMDS in lazy mode")
-        stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")
-
-        # Primewords
-        logging.info("Loading Primewords in lazy mode")
-        primewords_cuts = load_manifest_lazy(
-            self.fbank_dir / "primewords_cuts_train.jsonl.gz"
-        )
-
-        # MagicData
-        logging.info("Loading MagicData in lazy mode")
-        magicdata_cuts = load_manifest_lazy(
-            self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
-        )
-
-        # Ali-Meeting
-        logging.info("Loading Ali-Meeting in lazy mode")
-        alimeeting_cuts = load_manifest_lazy(
-            self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
-        )
-
         # WeNetSpeech
         logging.info("Loading WeNetSpeech in lazy mode")
         wenetspeech_L_cuts = load_manifest_lazy(
             self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
         )

-        # KeSpeech
-        logging.info("Loading KeSpeech in lazy mode")
-        kespeech_1_cuts = load_manifest_lazy(
-            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
-        )
-        kespeech_2_cuts = load_manifest_lazy(
-            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
-        )
-
-        return CutSet.mux(
-            thchs_30_cuts,
-            aishell_cuts,
-            aishell_2_cuts,
-            aishell_4_L_cuts,
-            aishell_4_M_cuts,
-            aishell_4_S_cuts,
-            alimeeting_cuts,
-            stcmds_cuts,
-            primewords_cuts,
-            magicdata_cuts,
-            wenetspeech_L_cuts,
-            kespeech_1_cuts,
-            kespeech_2_cuts,
-            weights=[
-                len(thchs_30_cuts),
-                len(aishell_cuts),
-                len(aishell_2_cuts),
-                len(aishell_4_L_cuts),
-                len(aishell_4_M_cuts),
-                len(aishell_4_S_cuts),
-                len(alimeeting_cuts),
-                len(stcmds_cuts),
-                len(primewords_cuts),
-                len(magicdata_cuts),
-                len(wenetspeech_L_cuts),
-                len(kespeech_1_cuts),
-                len(kespeech_2_cuts),
-            ],
-        )
+        return wenetspeech_L_cuts

     def dev_cuts(self) -> CutSet:
         logging.info("About to get multidataset dev cuts")


@@ -30,9 +30,11 @@ class EncoderProjector(nn.Module):
     def forward(self, x):
         batch_size, seq_len, feat_dim = x.size()

-        num_frames_to_discard = seq_len % self.downsample_rate
-        if num_frames_to_discard > 0:
-            x = x[:, :-num_frames_to_discard, :]
+        num_padding_frames = (
+            self.downsample_rate - seq_len % self.downsample_rate
+        ) % self.downsample_rate
+        if num_padding_frames > 0:
+            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
         seq_len = x.size(1)

         x = x.contiguous()
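To see why the new arithmetic pads up to the next multiple of the downsample rate instead of trimming down to the previous one, here is a minimal standalone sketch (the tensor shape and downsample_rate value are hypothetical):

    import torch

    downsample_rate = 4
    x = torch.randn(2, 10, 8)  # (batch, seq_len, feat_dim); 10 % 4 != 0

    seq_len = x.size(1)
    # (4 - 10 % 4) % 4 == 2; the outer modulo keeps this at 0 when seq_len
    # is already a multiple of downsample_rate.
    num_padding_frames = (downsample_rate - seq_len % downsample_rate) % downsample_rate
    # Zero-pad the time axis rather than discarding the trailing frames,
    # so no encoder output is thrown away before downsampling.
    x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
    print(x.shape)  # torch.Size([2, 12, 8])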
@@ -62,6 +64,7 @@ class SPEECH_LLM(nn.Module):
         self,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
+        ctc_output: nn.Module,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
@@ -230,6 +233,57 @@ class SPEECH_LLM(nn.Module):
         return encoder_out, encoder_out_lens

+    def ctc_compress(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        blank_id: int = 0,
+    ) -> torch.Tensor:
+        """
+        Remove frames from encoder_out where the CTC argmax predicts blank.
+
+        Args:
+            encoder_out: Tensor of shape (N, T, C), encoder output.
+            encoder_out_lens: Tensor of shape (N,), lengths before padding.
+            blank_id: CTC blank token ID (default: 0).
+
+        Returns:
+            Compressed CTC output of shape (N, T', C).
+        """
+        # 1. Compute CTC argmax predictions
+        ctc_output = self.ctc_output(encoder_out)
+        ctc_preds = ctc_output.argmax(dim=-1)
+
+        # 2. Create the non-blank, non-pad mask
+        padding_mask = make_pad_mask(encoder_out_lens)
+        non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)
+
+        # 3. Compute lengths after compression
+        compressed_lens = non_blank_mask.sum(dim=1)
+        max_len = compressed_lens.max().item()
+
+        # 4. Pre-pad the output
+        pad_lens_list = (
+            torch.full_like(
+                compressed_lens,
+                max_len,
+                device=ctc_output.device,
+            )
+            - compressed_lens
+        )
+        max_pad_len = int(pad_lens_list.max())
+        padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])
+
+        # 5. Create the final mask
+        padding_mask = ~make_pad_mask(pad_lens_list)
+        total_mask = torch.concat([non_blank_mask, padding_mask], dim=1)
+
+        # 6. Apply the mask and reshape
+        compressed_output = padded_ctc_output[total_mask].reshape(
+            ctc_output.shape[0], -1, ctc_output.shape[2]
+        )
+
+        return compressed_output
+
     def forward(
         self,
         fbank: torch.Tensor,
@@ -238,7 +292,7 @@ class SPEECH_LLM(nn.Module):
         attention_mask: torch.Tensor,
         labels: torch.LongTensor,
     ):
-        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
+        encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)

         speech_features = self.encoder_projector(encoder_outs)
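To make the masking trick in ctc_compress concrete, here is a self-contained sketch on toy tensors; the shapes, values, and the inlined one-liner standing in for make_pad_mask are all illustrative, not part of the commit:

    import torch

    blank_id = 0
    # Pretend CTC argmax predictions for a batch of 2 with T = 5 frames:
    ctc_preds = torch.tensor([[0, 3, 0, 4, 4],
                              [7, 0, 0, 0, 2]])
    encoder_out_lens = torch.tensor([5, 4])  # 2nd row has one padded frame
    ctc_output = torch.randn(2, 5, 6)        # (N, T, C) frame posteriors

    # Non-blank, non-pad mask (inline equivalent of make_pad_mask).
    t = torch.arange(5).unsqueeze(0)
    padding_mask = t >= encoder_out_lens.unsqueeze(1)
    non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)

    # Rows keep different numbers of frames, so re-pad to the batch max.
    compressed_lens = non_blank_mask.sum(dim=1)   # tensor([3, 1])
    max_len = int(compressed_lens.max())          # 3
    pad_lens = max_len - compressed_lens          # tensor([0, 2])
    max_pad = int(pad_lens.max())
    padded = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad])

    # Select pad_lens[i] of the appended zero frames per row, so every row
    # keeps exactly max_len frames, then gather and reshape.
    keep_pad = torch.arange(max_pad).unsqueeze(0) < pad_lens.unsqueeze(1)
    total_mask = torch.cat([non_blank_mask, keep_pad], dim=1)
    compressed = padded[total_mask].reshape(2, -1, 6)
    print(compressed.shape)  # torch.Size([2, 3, 6])

Each row ends up with exactly max_len frames: the non-blank frames it kept plus just enough of the appended zero frames to even the batch out, which is why the flat boolean indexing can be reshaped back to (N, T', C).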