mirror of https://github.com/k2-fsa/icefall.git

commit 639feab4df (parent 8afb0d647f)
update dataset with aishell 2
@@ -19,14 +19,14 @@
     "optimizer": {
         "type": "Adam",
         "params": {
-            "lr": 1e-4
+            "lr": 5e-4
         }
     },
     "scheduler": {
         "type": "WarmupLR",
         "params": {
             "warmup_min_lr": 0,
-            "warmup_max_lr": 1e-4,
+            "warmup_max_lr": 5e-4,
             "warmup_num_steps": 100
         }
     },
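For context: DeepSpeed's WarmupLR scheduler ramps the learning rate from warmup_min_lr to warmup_max_lr over warmup_num_steps optimizer steps and then holds it, so this hunk raises both the Adam base lr and the warmup target from 1e-4 to 5e-4. A minimal sketch of the schedule's shape (an illustration only, not DeepSpeed's actual implementation; the log-shaped ramp matches DeepSpeed's documented default warmup_type):

    import math

    def warmup_lr(step: int,
                  warmup_min_lr: float = 0.0,
                  warmup_max_lr: float = 5e-4,
                  warmup_num_steps: int = 100,
                  warmup_type: str = "log") -> float:
        # After warmup, the rate is held at warmup_max_lr.
        if step >= warmup_num_steps:
            return warmup_max_lr
        if warmup_type == "log":
            # Log-shaped ramp: grows quickly early, flattens near the end.
            gamma = math.log(step + 1) / math.log(warmup_num_steps)
        else:
            # Linear ramp.
            gamma = step / warmup_num_steps
        return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * gamma

With the values in this config, warmup_lr(0) == 0.0 and warmup_lr(99) == 5e-4; every later step stays at 5e-4.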
@@ -6,7 +6,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index


 class EncoderProjector(nn.Module):
     # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
-    def __init__(self, encoder_dim, llm_dim, downsample_rate=4):
+    def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
         super().__init__()
         self.downsample_rate = downsample_rate
         self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
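The downsample_rate change matters because the projector stacks that many consecutive encoder frames before projecting into the LLM embedding space; with the new default of 1 the frame rate is left unchanged. A self-contained sketch of how such a frame-stacking projector typically works (modelled on the SLAM-LLM projector linked above; the forward pass shown here is an assumption, not icefall's exact code):

    import torch
    import torch.nn as nn

    class EncoderProjectorSketch(nn.Module):
        """Frame-stacking projector sketch, modelled on the SLAM-LLM
        projector; an illustration, not icefall's exact code."""

        def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 1):
            super().__init__()
            self.downsample_rate = downsample_rate
            self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
            self.relu = nn.ReLU()
            self.linear2 = nn.Linear(llm_dim, llm_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, frames, encoder_dim)
            b, t, d = x.size()
            t = t - t % self.downsample_rate  # drop frames that don't fill a group
            # Stack every `downsample_rate` consecutive frames along the feature
            # axis, shortening the sequence by the same factor.
            x = x[:, :t, :].contiguous().view(
                b, t // self.downsample_rate, d * self.downsample_rate
            )
            return self.linear2(self.relu(self.linear1(x)))

With downsample_rate=4, a 100-frame encoder output becomes 25 LLM inputs of width 4 * encoder_dim before projection; with the new default of 1, the sequence length is preserved.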
@@ -140,13 +140,16 @@ class SPEECH_LLM(nn.Module):
         speech_features = self.encoder_projector(encoder_outs)

         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
-        #print("input_ids", input_ids, input_ids.shape)
-        #print("labels", labels, labels.shape)
+        # print("input_ids", input_ids, input_ids.shape)
+        # print("labels", labels, labels.shape)
+        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
-        #print("labels", labels, labels.shape)
-        #print("speech_features", speech_features.shape)
+        # print("labels", labels, labels.shape)
+        # print("speech_features", speech_features.shape, speech_features)
+        # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
+        # input()

         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
         with torch.no_grad():
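The _merge_input_ids_with_speech_features call splices the projected speech frames into the text embedding sequence; the commented-out prints trace tensor shapes through that step. A toy, batch-size-1 sketch of the core idea (hypothetical: the real method also rebuilds attention_mask, labels, and position_ids, and speech_token_id is a stand-in name for the recipe's placeholder token id):

    import torch

    def merge_speech_into_text(speech_features, inputs_embeds, input_ids, speech_token_id):
        # Find the placeholder token and replace it with the speech frames.
        pos = int((input_ids[0] == speech_token_id).nonzero(as_tuple=True)[0][0])
        return torch.cat(
            [inputs_embeds[:, :pos], speech_features, inputs_embeds[:, pos + 1 :]],
            dim=1,
        )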
@@ -281,3 +281,40 @@ class MultiDataset:
         return {
             "aishell_test": aishell_test_cuts,
         }
+
+
+    # aishell 2
+    def aishell2_train_cuts(self) -> CutSet:
+        logging.info("About to get multidataset train cuts")
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 in lazy mode")
+        aishell_2_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
+        )
+
+        return aishell_2_cuts
+
+    def aishell2_dev_cuts(self) -> CutSet:
+        logging.info("About to get multidataset dev cuts")
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 set in lazy mode")
+        aishell2_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+        )
+
+        return aishell2_dev_cuts
+
+    def aishell2_test_cuts(self) -> CutSet:
+        logging.info("About to get multidataset test cuts")
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 set in lazy mode")
+        aishell2_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+        )
+
+        return {
+            "aishell2_test": aishell2_test_cuts,
+        }
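With these accessors in place, the AISHELL-2 cuts behave like any other lazy lhotse CutSet. A hypothetical usage sketch (joint training on both corpora is not part of this commit; CutSet.mux is lhotse's way of interleaving lazy cut sets without materialising them):

    from lhotse import CutSet

    # Hypothetical: assuming multi_dataset is an instance of the
    # MultiDataset class above, train on AISHELL-1 and AISHELL-2
    # together by muxing the two lazy CutSets.
    train_cuts = CutSet.mux(
        multi_dataset.aishell_train_cuts(),
        multi_dataset.aishell2_train_cuts(),
    )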
@@ -820,7 +820,8 @@ def run(rank, world_size, args):
         return True

     # train_cuts = multi_dataset.train_cuts()
-    train_cuts = multi_dataset.aishell_train_cuts()
+    # train_cuts = multi_dataset.aishell_train_cuts()
+    train_cuts = multi_dataset.aishell2_train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
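train_cuts is filtered through remove_short_and_long_utt before sampling; the predicate itself sits outside this diff. A typical icefall-style definition, for illustration (the exact duration bounds used by this recipe are an assumption):

    def remove_short_and_long_utt(c) -> bool:
        # Drop utterances whose duration falls outside a sane range;
        # 1-20 seconds is a common choice in icefall recipes.
        return 1.0 <= c.duration <= 20.0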