mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
use padding instead of trimming (suggested by @shylockasr)
use ctc compress (suggested by @shylockasr) fix revert revert revert
This commit is contained in:
parent
05e3094429
commit
34639d5249
3
egs/speech_llm/ASR_LLM/.gitignore
vendored
3
egs/speech_llm/ASR_LLM/.gitignore
vendored
@ -1 +1,4 @@
|
||||
models
|
||||
train*.sh
|
||||
decode*.sh
|
||||
sync*.sh
|
||||
|
@ -47,103 +47,13 @@ class MultiDataset:
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset train cuts")
|
||||
|
||||
# THCHS-30
|
||||
logging.info("Loading THCHS-30 in lazy mode")
|
||||
thchs_30_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# AISHELL-1
|
||||
logging.info("Loading Aishell-1 in lazy mode")
|
||||
aishell_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 in lazy mode")
|
||||
aishell_2_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# AISHELL-4
|
||||
logging.info("Loading Aishell-4 in lazy mode")
|
||||
aishell_4_L_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
|
||||
)
|
||||
aishell_4_M_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
|
||||
)
|
||||
aishell_4_S_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
|
||||
)
|
||||
|
||||
# ST-CMDS
|
||||
logging.info("Loading ST-CMDS in lazy mode")
|
||||
stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")
|
||||
|
||||
# Primewords
|
||||
logging.info("Loading Primewords in lazy mode")
|
||||
primewords_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "primewords_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# MagicData
|
||||
logging.info("Loading MagicData in lazy mode")
|
||||
magicdata_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# Ali-Meeting
|
||||
logging.info("Loading Ali-Meeting in lazy mode")
|
||||
alimeeting_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# WeNetSpeech
|
||||
logging.info("Loading WeNetSpeech in lazy mode")
|
||||
wenetspeech_L_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
|
||||
)
|
||||
|
||||
# KeSpeech
|
||||
logging.info("Loading KeSpeech in lazy mode")
|
||||
kespeech_1_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
|
||||
)
|
||||
kespeech_2_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
|
||||
)
|
||||
|
||||
return CutSet.mux(
|
||||
thchs_30_cuts,
|
||||
aishell_cuts,
|
||||
aishell_2_cuts,
|
||||
aishell_4_L_cuts,
|
||||
aishell_4_M_cuts,
|
||||
aishell_4_S_cuts,
|
||||
alimeeting_cuts,
|
||||
stcmds_cuts,
|
||||
primewords_cuts,
|
||||
magicdata_cuts,
|
||||
wenetspeech_L_cuts,
|
||||
kespeech_1_cuts,
|
||||
kespeech_2_cuts,
|
||||
weights=[
|
||||
len(thchs_30_cuts),
|
||||
len(aishell_cuts),
|
||||
len(aishell_2_cuts),
|
||||
len(aishell_4_L_cuts),
|
||||
len(aishell_4_M_cuts),
|
||||
len(aishell_4_S_cuts),
|
||||
len(alimeeting_cuts),
|
||||
len(stcmds_cuts),
|
||||
len(primewords_cuts),
|
||||
len(magicdata_cuts),
|
||||
len(wenetspeech_L_cuts),
|
||||
len(kespeech_1_cuts),
|
||||
len(kespeech_2_cuts),
|
||||
],
|
||||
)
|
||||
return wenetspeech_L_cuts
|
||||
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset dev cuts")
|
||||
|
@ -30,9 +30,11 @@ class EncoderProjector(nn.Module):
|
||||
def forward(self, x):
|
||||
|
||||
batch_size, seq_len, feat_dim = x.size()
|
||||
num_frames_to_discard = seq_len % self.downsample_rate
|
||||
if num_frames_to_discard > 0:
|
||||
x = x[:, :-num_frames_to_discard, :]
|
||||
num_padding_frames = (
|
||||
self.downsample_rate - seq_len % self.downsample_rate
|
||||
) % self.downsample_rate
|
||||
if num_padding_frames > 0:
|
||||
x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
|
||||
seq_len = x.size(1)
|
||||
|
||||
x = x.contiguous()
|
||||
@ -62,6 +64,7 @@ class SPEECH_LLM(nn.Module):
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
ctc_output: nn.Module,
|
||||
llm: nn.Module,
|
||||
encoder_projector: nn.Module,
|
||||
):
|
||||
@ -230,6 +233,57 @@ class SPEECH_LLM(nn.Module):
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def ctc_compress(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
blank_id: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Remove frames from encoder_out where CTC argmax predicts blank.
|
||||
Args:
|
||||
encoder_out: Tensor of shape (N, T, C), encoder output.
|
||||
encoder_out_lens: Tensor of shape (N,), lengths before padding.
|
||||
blank_id: CTC blank token ID (default: 0).
|
||||
|
||||
Returns:
|
||||
Compressed CTC output of shape (N, T', C).
|
||||
"""
|
||||
# 1. Compute CTC argmax predictions
|
||||
ctc_output = self.ctc_output(encoder_out)
|
||||
ctc_preds = ctc_output.argmax(dim=-1)
|
||||
|
||||
# 2. Create non-blank, non-pad mask
|
||||
padding_mask = make_pad_mask(encoder_out_lens)
|
||||
non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)
|
||||
|
||||
# 3. Compute lengths after compress
|
||||
compressed_lens = non_blank_mask.sum(dim=1)
|
||||
max_len = compressed_lens.max().item()
|
||||
|
||||
# 4. Pre-pad output
|
||||
pad_lens_list = (
|
||||
torch.full_like(
|
||||
compressed_lens,
|
||||
max_len,
|
||||
device=ctc_output.device,
|
||||
)
|
||||
- compressed_lens
|
||||
)
|
||||
max_pad_len = int(pad_lens_list.max())
|
||||
padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])
|
||||
|
||||
# 5. Create final mask
|
||||
padding_mask = ~make_pad_mask(pad_lens_list)
|
||||
total_mask = torch.concat([non_blank_mask, padding_mask], dim=1)
|
||||
|
||||
# 6. Apply mask and reshape
|
||||
compressed_output = padded_ctc_output[total_mask].reshape(
|
||||
ctc_output.shape[0], -1, ctc_output.shape[2]
|
||||
)
|
||||
|
||||
return compressed_output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fbank: torch.Tensor,
|
||||
@ -238,7 +292,7 @@ class SPEECH_LLM(nn.Module):
|
||||
attention_mask: torch.Tensor,
|
||||
labels: torch.LongTensor,
|
||||
):
|
||||
encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
|
||||
encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)
|
||||
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user