Update dataset with AISHELL-2

Author: root, 2024-06-07 07:49:38 +00:00, committed by Yuekai Zhang
parent 8afb0d647f
commit 639feab4df
4 changed files with 49 additions and 8 deletions


@@ -19,14 +19,14 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 1e-4
+      "lr": 5e-4
     }
   },
   "scheduler": {
     "type": "WarmupLR",
     "params": {
       "warmup_min_lr": 0,
-      "warmup_max_lr": 1e-4,
+      "warmup_max_lr": 5e-4,
       "warmup_num_steps": 100
     }
   },
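
Both edits raise the peak learning rate from 1e-4 to 5e-4. For context, WarmupLR ramps the LR from warmup_min_lr to warmup_max_lr over warmup_num_steps and then holds it; a simplified sketch with the new values (DeepSpeed's default ramp is log-space, a linear one is shown here for clarity):

# Simplified sketch of the schedule configured above; not DeepSpeed's
# actual implementation (whose default warmup curve is log-space).
def warmup_lr(step: int, min_lr: float = 0.0, max_lr: float = 5e-4, num_steps: int = 100) -> float:
    if step >= num_steps:
        return max_lr  # hold the peak LR once warmup is done
    return min_lr + (max_lr - min_lr) * step / num_steps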


@@ -6,7 +6,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 class EncoderProjector(nn.Module):
     # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
-    def __init__(self, encoder_dim, llm_dim, downsample_rate=4):
+    def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
         super().__init__()
         self.downsample_rate = downsample_rate
         self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
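
downsample_rate controls how many consecutive encoder frames are stacked on the feature axis before projection, which is why linear1 takes encoder_dim * downsample_rate inputs; the new default of 1 disables the frame merging. A shape-only sketch, with the stacking step assumed from the referenced SLAM-LLM projector:

import torch

x = torch.randn(2, 100, 512)  # hypothetical (batch, T, encoder_dim)
r = 1                         # downsample_rate after this commit
b, t, d = x.shape
# Concatenate each group of r consecutive frames along the feature axis.
x = x[:, : (t // r) * r, :].reshape(b, t // r, d * r)
# r == 1 -> (2, 100, 512): no frames merged
# r == 4 -> (2, 25, 2048): 4 frames concatenated, matching
#           nn.Linear(encoder_dim * r, llm_dim) above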
@@ -140,13 +140,16 @@ class SPEECH_LLM(nn.Module):
         speech_features = self.encoder_projector(encoder_outs)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
-        #print("input_ids", input_ids, input_ids.shape)
-        #print("labels", labels, labels.shape)
+        # print("input_ids", input_ids, input_ids.shape)
+        # print("labels", labels, labels.shape)
+        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
-        #print("labels", labels, labels.shape)
-        #print("speech_features", speech_features.shape)
+        # print("labels", labels, labels.shape)
+        # print("speech_features", speech_features.shape, speech_features)
+        # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
+        # input()
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
         with torch.no_grad():
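
The forward pass above splices the projected speech features into the token embeddings and then drives the LLM through inputs_embeds rather than input_ids; given labels, a Hugging Face causal LM computes the loss internally. A minimal sketch of that calling convention (the model name is a placeholder, not the one this recipe uses):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
llm = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tok("hello world", return_tensors="pt").input_ids
embeds = llm.get_input_embeddings()(ids)             # (1, T, llm_dim)
out = llm(inputs_embeds=embeds, labels=ids)          # loss computed internally
print(out.loss)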


@@ -280,4 +280,41 @@ class MultiDataset:
         return {
             "aishell_test": aishell_test_cuts,
         }
+
+    # AISHELL-2
+    def aishell2_train_cuts(self) -> CutSet:
+        logging.info("About to get AISHELL-2 train cuts")
+        logging.info("Loading AISHELL-2 train set in lazy mode")
+        aishell_2_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
+        )
+        return aishell_2_cuts
+
+    def aishell2_dev_cuts(self) -> CutSet:
+        logging.info("About to get AISHELL-2 dev cuts")
+        logging.info("Loading AISHELL-2 dev set in lazy mode")
+        aishell2_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+        )
+        return aishell2_dev_cuts
+
+    def aishell2_test_cuts(self) -> CutSet:
+        logging.info("About to get AISHELL-2 test cuts")
+        logging.info("Loading AISHELL-2 test set in lazy mode")
+        aishell2_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+        )
+        return {
+            "aishell2_test": aishell2_test_cuts,
+        }
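
The three new helpers mirror the existing AISHELL ones: each wraps lhotse's load_manifest_lazy, so the CutSet is only materialized as it is iterated. Hypothetical usage (the MultiDataset constructor arguments are assumed, not shown in this diff):

from pathlib import Path

multi_dataset = MultiDataset(fbank_dir=Path("data/fbank"))  # args assumed
train_cuts = multi_dataset.aishell2_train_cuts()  # lazy CutSet
dev_cuts = multi_dataset.aishell2_dev_cuts()      # lazy CutSet
test_sets = multi_dataset.aishell2_test_cuts()    # {"aishell2_test": CutSet}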


@@ -820,7 +820,8 @@ def run(rank, world_size, args):
         return True

     # train_cuts = multi_dataset.train_cuts()
-    train_cuts = multi_dataset.aishell_train_cuts()
+    # train_cuts = multi_dataset.aishell_train_cuts()
+    train_cuts = multi_dataset.aishell2_train_cuts()

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
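
remove_short_and_long_utt is defined earlier in train.py and is not part of this diff; in icefall recipes it is typically a duration gate of roughly this shape (bounds illustrative):

# Illustrative sketch; the real bounds live earlier in train.py.
def remove_short_and_long_utt(c) -> bool:
    # Drop cuts whose duration in seconds falls outside a sane range.
    return 1.0 <= c.duration <= 20.0

train_cuts = train_cuts.filter(remove_short_and_long_utt)  # filtered lazily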