mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
use padding instead of trimming (suggested by @shylockasr)
use ctc compress (suggested by @shylockasr) fix revert revert revert
This commit is contained in:
parent
05e3094429
commit
34639d5249
3
egs/speech_llm/ASR_LLM/.gitignore
vendored
3
egs/speech_llm/ASR_LLM/.gitignore
vendored
@ -1 +1,4 @@
|
|||||||
models
|
models
|
||||||
|
train*.sh
|
||||||
|
decode*.sh
|
||||||
|
sync*.sh
|
||||||
|
@ -47,103 +47,13 @@ class MultiDataset:
|
|||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get multidataset train cuts")
|
logging.info("About to get multidataset train cuts")
|
||||||
|
|
||||||
# THCHS-30
|
|
||||||
logging.info("Loading THCHS-30 in lazy mode")
|
|
||||||
thchs_30_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
# AISHELL-1
|
|
||||||
logging.info("Loading Aishell-1 in lazy mode")
|
|
||||||
aishell_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
# AISHELL-2
|
|
||||||
logging.info("Loading Aishell-2 in lazy mode")
|
|
||||||
aishell_2_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
# AISHELL-4
|
|
||||||
logging.info("Loading Aishell-4 in lazy mode")
|
|
||||||
aishell_4_L_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
|
|
||||||
)
|
|
||||||
aishell_4_M_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
|
|
||||||
)
|
|
||||||
aishell_4_S_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ST-CMDS
|
|
||||||
logging.info("Loading ST-CMDS in lazy mode")
|
|
||||||
stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")
|
|
||||||
|
|
||||||
# Primewords
|
|
||||||
logging.info("Loading Primewords in lazy mode")
|
|
||||||
primewords_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "primewords_cuts_train.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
# MagicData
|
|
||||||
logging.info("Loading MagicData in lazy mode")
|
|
||||||
magicdata_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ali-Meeting
|
|
||||||
logging.info("Loading Ali-Meeting in lazy mode")
|
|
||||||
alimeeting_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
# WeNetSpeech
|
# WeNetSpeech
|
||||||
logging.info("Loading WeNetSpeech in lazy mode")
|
logging.info("Loading WeNetSpeech in lazy mode")
|
||||||
wenetspeech_L_cuts = load_manifest_lazy(
|
wenetspeech_L_cuts = load_manifest_lazy(
|
||||||
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
|
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
# KeSpeech
|
return wenetspeech_L_cuts
|
||||||
logging.info("Loading KeSpeech in lazy mode")
|
|
||||||
kespeech_1_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
|
|
||||||
)
|
|
||||||
kespeech_2_cuts = load_manifest_lazy(
|
|
||||||
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
return CutSet.mux(
|
|
||||||
thchs_30_cuts,
|
|
||||||
aishell_cuts,
|
|
||||||
aishell_2_cuts,
|
|
||||||
aishell_4_L_cuts,
|
|
||||||
aishell_4_M_cuts,
|
|
||||||
aishell_4_S_cuts,
|
|
||||||
alimeeting_cuts,
|
|
||||||
stcmds_cuts,
|
|
||||||
primewords_cuts,
|
|
||||||
magicdata_cuts,
|
|
||||||
wenetspeech_L_cuts,
|
|
||||||
kespeech_1_cuts,
|
|
||||||
kespeech_2_cuts,
|
|
||||||
weights=[
|
|
||||||
len(thchs_30_cuts),
|
|
||||||
len(aishell_cuts),
|
|
||||||
len(aishell_2_cuts),
|
|
||||||
len(aishell_4_L_cuts),
|
|
||||||
len(aishell_4_M_cuts),
|
|
||||||
len(aishell_4_S_cuts),
|
|
||||||
len(alimeeting_cuts),
|
|
||||||
len(stcmds_cuts),
|
|
||||||
len(primewords_cuts),
|
|
||||||
len(magicdata_cuts),
|
|
||||||
len(wenetspeech_L_cuts),
|
|
||||||
len(kespeech_1_cuts),
|
|
||||||
len(kespeech_2_cuts),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def dev_cuts(self) -> CutSet:
|
def dev_cuts(self) -> CutSet:
|
||||||
logging.info("About to get multidataset dev cuts")
|
logging.info("About to get multidataset dev cuts")
|
||||||
|
@ -30,9 +30,11 @@ class EncoderProjector(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
batch_size, seq_len, feat_dim = x.size()
|
batch_size, seq_len, feat_dim = x.size()
|
||||||
num_frames_to_discard = seq_len % self.downsample_rate
|
num_padding_frames = (
|
||||||
if num_frames_to_discard > 0:
|
self.downsample_rate - seq_len % self.downsample_rate
|
||||||
x = x[:, :-num_frames_to_discard, :]
|
) % self.downsample_rate
|
||||||
|
if num_padding_frames > 0:
|
||||||
|
x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
|
||||||
seq_len = x.size(1)
|
seq_len = x.size(1)
|
||||||
|
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
@ -62,6 +64,7 @@ class SPEECH_LLM(nn.Module):
|
|||||||
self,
|
self,
|
||||||
encoder_embed: nn.Module,
|
encoder_embed: nn.Module,
|
||||||
encoder: EncoderInterface,
|
encoder: EncoderInterface,
|
||||||
|
ctc_output: nn.Module,
|
||||||
llm: nn.Module,
|
llm: nn.Module,
|
||||||
encoder_projector: nn.Module,
|
encoder_projector: nn.Module,
|
||||||
):
|
):
|
||||||
@ -230,6 +233,57 @@ class SPEECH_LLM(nn.Module):
|
|||||||
|
|
||||||
return encoder_out, encoder_out_lens
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
def ctc_compress(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    blank_id: int = 0,
) -> torch.Tensor:
    """
    Drop blank and padding frames from the CTC head output.

    ``self.ctc_output`` is applied to ``encoder_out``.  A frame is removed
    when its argmax equals ``blank_id`` or when its index is at or beyond
    the length given in ``encoder_out_lens``.  The surviving frames keep
    their original order, and every utterance is right-padded with
    all-zero frames up to the longest compressed length in the batch.

    NOTE(review): the frames returned are rows of the CTC head output,
    not rows of ``encoder_out`` itself -- confirm this is what callers
    expect.

    Args:
      encoder_out: Tensor of shape (N, T, C), encoder output.
      encoder_out_lens: Tensor of shape (N,), lengths before padding.
      blank_id: CTC blank token ID (default: 0).

    Returns:
      Tensor of shape (N, T', V), where T' is the largest number of kept
      frames in the batch (0 if nothing is kept) and V is the size of the
      last dimension of ``self.ctc_output(encoder_out)``.
    """
    # 1. Compute CTC argmax predictions
    ctc_output = self.ctc_output(encoder_out)
    ctc_preds = ctc_output.argmax(dim=-1)
    batch_size, num_frames, vocab_size = ctc_output.shape
    device = ctc_output.device

    # 2. Create non-blank, non-pad mask.
    # The validity mask is built at width T so it always lines up with
    # ctc_preds, even if T is larger than encoder_out_lens.max().
    frame_idx = torch.arange(num_frames, device=device).unsqueeze(0)
    valid_mask = frame_idx < encoder_out_lens.to(device).unsqueeze(1)
    non_blank_mask = (ctc_preds != blank_id) & valid_mask

    # 3. Compute lengths after compress
    compressed_lens = non_blank_mask.sum(dim=1)
    max_len = int(compressed_lens.max())

    # 4. Pre-pad output: append enough zero frames along the time axis
    # that the shortest utterance can be topped up to max_len.
    pad_lens = max_len - compressed_lens
    max_pad_len = int(pad_lens.max())
    padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])

    # 5. Create final mask: per utterance, select its kept frames plus
    # exactly pad_lens[i] of the appended zero frames, so every row of
    # the mask selects max_len frames.
    pad_idx = torch.arange(max_pad_len, device=device).unsqueeze(0)
    pad_keep_mask = pad_idx < pad_lens.unsqueeze(1)
    total_mask = torch.concat([non_blank_mask, pad_keep_mask], dim=1)

    # 6. Apply mask and reshape.
    # The time dimension is given explicitly: reshape(..., -1, ...) cannot
    # infer a size when zero frames are selected (all-blank batch).
    compressed_output = padded_ctc_output[total_mask].reshape(
        batch_size, max_len, vocab_size
    )

    return compressed_output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
fbank: torch.Tensor,
|
fbank: torch.Tensor,
|
||||||
@ -238,7 +292,7 @@ class SPEECH_LLM(nn.Module):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
labels: torch.LongTensor,
|
labels: torch.LongTensor,
|
||||||
):
|
):
|
||||||
encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
|
encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)
|
||||||
|
|
||||||
speech_features = self.encoder_projector(encoder_outs)
|
speech_features = self.encoder_projector(encoder_outs)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user