mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
update dataset with aishell 2
This commit is contained in:
parent
8afb0d647f
commit
639feab4df
@ -19,14 +19,14 @@
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 1e-4
|
||||
"lr": 5e-4
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 1e-4,
|
||||
"warmup_max_lr": 5e-4,
|
||||
"warmup_num_steps": 100
|
||||
}
|
||||
},
|
||||
|
@ -6,7 +6,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||
|
||||
class EncoderProjector(nn.Module):
|
||||
# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
|
||||
def __init__(self, encoder_dim, llm_dim, downsample_rate=4):
|
||||
def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
|
||||
super().__init__()
|
||||
self.downsample_rate = downsample_rate
|
||||
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
|
||||
@ -140,13 +140,16 @@ class SPEECH_LLM(nn.Module):
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
|
||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
#print("input_ids", input_ids, input_ids.shape)
|
||||
#print("labels", labels, labels.shape)
|
||||
# print("input_ids", input_ids, input_ids.shape)
|
||||
# print("labels", labels, labels.shape)
|
||||
# print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
|
||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
#print("labels", labels, labels.shape)
|
||||
#print("speech_features", speech_features.shape)
|
||||
# print("labels", labels, labels.shape)
|
||||
# print("speech_features", speech_features.shape, speech_features)
|
||||
# print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
|
||||
# input()
|
||||
|
||||
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
|
||||
with torch.no_grad():
|
||||
|
@ -280,4 +280,41 @@ class MultiDataset:
|
||||
|
||||
return {
|
||||
"aishell_test": aishell_test_cuts,
|
||||
}
|
||||
|
||||
|
||||
# aishell 2
|
||||
def aishell2_train_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset train cuts")
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 in lazy mode")
|
||||
aishell_2_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
return aishell_2_cuts
|
||||
|
||||
def aishell2_dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset dev cuts")
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 set in lazy mode")
|
||||
aishell2_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
return aishell2_dev_cuts
|
||||
|
||||
def aishell2_test_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset test cuts")
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 set in lazy mode")
|
||||
aishell2_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
return {
|
||||
"aishell2_test": aishell2_test_cuts,
|
||||
}
|
@ -820,7 +820,8 @@ def run(rank, world_size, args):
|
||||
return True
|
||||
|
||||
# train_cuts = multi_dataset.train_cuts()
|
||||
train_cuts = multi_dataset.aishell_train_cuts()
|
||||
# train_cuts = multi_dataset.aishell_train_cuts()
|
||||
train_cuts = multi_dataset.aishell2_train_cuts()
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
# if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||
|
Loading…
x
Reference in New Issue
Block a user