Mirror of https://github.com/k2-fsa/icefall.git

Commit dd858f0cd1 (parent 9fff18edec): support instruct s2s
@@ -413,6 +413,8 @@ class AsrDataModule:
         ultrachat_cuts = load_manifest_lazy(
             self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
         )
+        VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
+        ultrachat_cuts = ultrachat_cuts.resample(16000)
         return CutSet.mux(
             VoiceAssistant_cuts,
             ultrachat_cuts,
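For readers unfamiliar with lhotse's weighted muxing used above, here is a minimal sketch (not part of the commit; the manifest paths and weights are illustrative). CutSet.mux lazily interleaves several cut sets, drawing from each source with probability proportional to its weight, which is why the training mixes below pass corpus sizes as weights:

    # Hedged sketch: weighted muxing of two lazily loaded cut sets.
    from lhotse import CutSet, load_manifest_lazy

    a = load_manifest_lazy("cuts_a.jsonl.gz")  # hypothetical manifests
    b = load_manifest_lazy("cuts_b.jsonl.gz")
    # Relative weights steer sampling; using dataset sizes keeps the mix
    # proportional to corpus size until a source is exhausted.
    mixed = CutSet.mux(a, b, weights=[10_000, 5_000])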
@@ -427,6 +429,7 @@ class AsrDataModule:
         VoiceAssistant_cuts = load_manifest_lazy(
             self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
         )
+        VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
         return VoiceAssistant_cuts

     @lru_cache()
@@ -435,6 +438,7 @@ class AsrDataModule:
         VoiceAssistant_cuts = load_manifest_lazy(
             self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
         )
+        VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
         return {"test": VoiceAssistant_cuts}

     @lru_cache()
@@ -482,36 +486,36 @@ class AsrDataModule:

         librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_100,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         librispeech_other_cuts = CutSet.from_huggingface_dataset(
             librispeech_other,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_360,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         gigaspeech_cuts = CutSet.from_huggingface_dataset(
-            gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
+            gigaspeech, audio_key="audio", text_key="text"
         )

         people_speech_clean_cuts = CutSet.from_huggingface_dataset(
             people_speech_clean,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
             people_speech_dirty_sa,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         return CutSet.mux(
@@ -540,8 +544,8 @@ class AsrDataModule:
         )
         librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_valid,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
         return librispeech_clean_valid_cuts

@@ -567,20 +571,20 @@ class AsrDataModule:

         librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_100,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         librispeech_other_cuts = CutSet.from_huggingface_dataset(
             librispeech_other,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_360,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )

         return CutSet.mux(
@@ -603,7 +607,148 @@ class AsrDataModule:
         )

         gigaspeech_cuts = CutSet.from_huggingface_dataset(
-            gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
+            gigaspeech, audio_key="audio", text_key="text"
         )

         return gigaspeech_cuts

+    @lru_cache()
+    def train_cuts_instruct_s2s(self) -> CutSet:
+        logging.info("About to get train cuts")
+        if self.args.huggingface_dataset_path_or_name is not None:
+            data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+        else:
+            data_path = "yuekai/InstructS2S-200K"
+        # 148_688
+        instruct_s2s_train = load_dataset(
+            data_path, split="train", streaming=True
+        )
+
+        instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+            instruct_s2s_train,
+            audio_key="question_audio",
+            text_key="answer",
+        )
+
+        instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+        return instruct_s2s_train_cuts
+
+    @lru_cache()
+    def train_cuts_en_speech2speech(self) -> CutSet:
+        logging.info("About to get train cuts")
+        VoiceAssistant_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
+        )
+        ultrachat_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
+        )
+
+        if self.args.huggingface_dataset_path_or_name is not None:
+            data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+        else:
+            data_path = "yuekai/InstructS2S-200K"
+        # 148_688
+        instruct_s2s_train = load_dataset(
+            data_path, split="train", streaming=True
+        )
+
+        instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+            instruct_s2s_train,
+            audio_key="question_audio",
+            text_key="answer",
+        )
+
+        instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+        return CutSet.mux(
+            VoiceAssistant_cuts,
+            ultrachat_cuts,
+            instruct_s2s_train_cuts,
+            weights=[
+                len(VoiceAssistant_cuts),
+                len(ultrachat_cuts),
+                423_000,
+            ],
+        )
+
+    @lru_cache()
+    def train_cuts_en_speech2speech_librispeech(self) -> CutSet:
+        logging.info("About to get train cuts")
+        VoiceAssistant_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
+        )
+        ultrachat_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
+        )
+
+        if self.args.huggingface_dataset_path_or_name is not None:
+            data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+        else:
+            data_path = "yuekai/InstructS2S-200K"
+        # 148_688
+        instruct_s2s_train = load_dataset(
+            data_path, split="train", streaming=True
+        )
+
+        instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+            instruct_s2s_train,
+            audio_key="question_audio",
+            text_key="answer",
+        )
+
+        instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+        if self.args.huggingface_dataset_path_or_name is not None:
+            librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
+        else:
+            librispeech_path = "fixie-ai/librispeech_asr"
+        # 148_688
+        librispeech_other = load_dataset(
+            librispeech_path, "other", split="train.500", streaming=True
+        )
+        # 104_014
+        librispeech_clean_360 = load_dataset(
+            librispeech_path, "clean", split="train.360", streaming=True
+        )
+        # 28_539
+        librispeech_clean_100 = load_dataset(
+            librispeech_path, "clean", split="train.100", streaming=True
+        )
+
+        librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_100,
+            audio_key="audio",
+            text_key="text",
+        )
+
+        librispeech_other_cuts = CutSet.from_huggingface_dataset(
+            librispeech_other,
+            audio_key="audio",
+            text_key="text",
+        )
+
+        librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_360,
+            audio_key="audio",
+            text_key="text",
+        )
+
+        return CutSet.mux(
+            librispeech_other_cuts,
+            VoiceAssistant_cuts,
+            ultrachat_cuts,
+            librispeech_clean_360_cuts,
+            instruct_s2s_train_cuts,
+            librispeech_clean_100_cuts,
+            weights=[
+                148688,
+                len(VoiceAssistant_cuts),
+                len(ultrachat_cuts),
+                104014,
+                423_000,
+                28539,
+            ],
+        )
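The new train_cuts_* helpers all follow the same pattern: stream a HuggingFace dataset, wrap it lazily as lhotse cuts, and resample to 16 kHz. A minimal standalone sketch (the dataset name and keys are taken from the commit; the rest is illustrative):

    from datasets import load_dataset
    from lhotse import CutSet

    ds = load_dataset("yuekai/InstructS2S-200K", split="train", streaming=True)
    cuts = CutSet.from_huggingface_dataset(
        ds, audio_key="question_audio", text_key="answer"
    )
    # Normalize the sampling rate before feature extraction / muxing.
    cuts = cuts.resample(16000)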
@@ -193,6 +193,13 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--last-stage-model-path",
+        type=str,
+        default=None,
+        help="""The path to the last stage model if it is not None. Training will start from this model.
+        """,
+    )
     parser.add_argument(
         "--sampler-state-dict-path",
         type=str,
@@ -229,13 +236,6 @@ def get_parser():
         help="Whether to unfreeze speech adaptor during training.",
     )

-    parser.add_argument(
-        "--prompt-template",
-        type=str,
-        default="speech_qa",
-        help="The prompt template to use.",
-    )
-
     parser.add_argument(
         "--dataset",
         type=str,
@@ -300,7 +300,6 @@ def get_params() -> AttributeDict:

 def extract_text_and_speech_token(
     batch: dict,
-    prompt_template: str,
     enable_speech_output: bool
 ) -> Tuple[List[Dict[str, str]], Optional[List[Any]]]:
     """
@@ -325,54 +324,54 @@ def extract_text_and_speech_token(
     answers = batch["supervisions"]["text"]
     batch_size = len(answers)

-    if prompt_template == "speech_qa":
-        for i in range(batch_size):
-            message_list_item = []
-            if 'round' in batch["supervisions"]["cut"][i].custom:
-                # slam_omni format dataset
-                current_question_with_history = batch["supervisions"]["cut"][i].custom["question"]
-                total_round = batch["supervisions"]["cut"][i].custom["round"]
-                history_context = current_question_with_history.rsplit("<USER>:", 1)[0].strip()
-                if total_round > 1:
-                    history_question_answer = history_context.split("USER:")
-                    history_question_answer = [item for item in history_question_answer if item]
-                    for j in range(total_round - 1):
-                        question_answer = history_question_answer[j].split("ASSISTANT:")
-                        message_list_item += [
-                            {"role": "user", "content": question_answer[0].strip()},
-                            {"role": "assistant", "content": question_answer[1].strip()},
-                        ]
-            message_list_item += [
-                {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-                {"role": "assistant", "content": answers[i]},
-            ]
-            messages.append(message_list_item)
-
-    elif prompt_template == "speech_continuation":
-        # speech_tokens remains None
-        for i in range(batch_size):
-            message_list_item = [
-                {
-                    "role": "user",
-                    "content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
-                },
-                {"role": "assistant", "content": answers[i]},
-            ]
-            messages.append(message_list_item)
-
-    elif prompt_template == "asr":
-        # speech_tokens remains None
-        for i in range(batch_size):
-            message_list_item = [
-                {
-                    "role": "user",
-                    "content": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
-                },
-                {"role": "assistant", "content": answers[i]},
-            ]
-            messages.append(message_list_item)
-    else:
-        raise ValueError(f"Unknown prompt template: {prompt_template}")
+    prompt_template_dict = {
+        "speech_qa": f"{DEFAULT_SPEECH_TOKEN}",
+        "speech_continuation": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
+        "asr": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
+    }
+
+    for i in range(batch_size):
+        # Default per-item template; continuation-style cuts override it below.
+        current_prompt_template = "speech_qa"
+        target = answers[i]
+        message_list_item = []
+
+        custom_data = batch["supervisions"]["cut"][i].custom
+
+        if 'round' in custom_data:
+            # slam_omni format dataset: the current turn uses the speech_qa
+            # template, with earlier rounds prepended as chat history.
+            current_question_with_history = custom_data["question"]
+            total_round = custom_data["round"]
+            history_context = current_question_with_history.rsplit("<USER>:", 1)[0].strip()
+            if total_round > 1:
+                history_question_answer = history_context.split("USER:")
+                history_question_answer = [item for item in history_question_answer if item]
+                for j in range(total_round - 1):
+                    question_answer = history_question_answer[j].split("ASSISTANT:")
+                    message_list_item += [
+                        {"role": "user", "content": question_answer[0].strip()},
+                        {"role": "assistant", "content": question_answer[1].strip()},
+                    ]
+        elif 'continuation' in custom_data:
+            # see https://huggingface.co/datasets/fixie-ai/librispeech_asr
+            ASR_PROBABILITY = 0.3
+            if random.random() < ASR_PROBABILITY:
+                current_prompt_template = "asr"
+            else:
+                current_prompt_template = "speech_continuation"
+                target = custom_data["continuation"]
+        else:
+            # single-round, speech2speech conversation data
+            pass
+        message_list_item += [
+            {"role": "user", "content": prompt_template_dict[current_prompt_template]},
+            {"role": "assistant", "content": target},
+        ]
+        messages.append(message_list_item)

     return messages, speech_tokens
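To see what the refactored per-cut template selection produces, here is a toy walk-through (a hedged sketch; DEFAULT_SPEECH_TOKEN's value is a placeholder, while the templates and the 0.3 ASR probability come from the diff):

    import random

    DEFAULT_SPEECH_TOKEN = "<speech>"  # placeholder value
    prompt_template_dict = {
        "speech_qa": f"{DEFAULT_SPEECH_TOKEN}",
        "speech_continuation": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
        "asr": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
    }

    def pick_template(custom: dict) -> str:
        # Continuation-style cuts are randomly split between ASR and
        # continuation prompts; everything else defaults to speech_qa.
        if "continuation" in custom:
            return "asr" if random.random() < 0.3 else "speech_continuation"
        return "speech_qa"

    print(prompt_template_dict[pick_template({"continuation": "..."})])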
@@ -428,14 +427,17 @@ def preprocess(

 def process_batch_text_continuation(batch: dict):
     messages = []
-    for i in range(len(batch["supervisions"]["text"])):
-        transcript = batch["supervisions"]["cut"][i].custom["text"]
+    transcripts = batch["supervisions"]["text"]
+    continuations = [
+        cut.custom["continuation"] for cut in batch["supervisions"]["cut"]
+    ]
+    for i in range(len(transcripts)):
         message = [
             {
                 "role": "user",
-                "content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}",
+                "content": f"Continue the following text using less than 50 words:\n\n{transcripts[i]}{DEFAULT_SPEECH_TOKEN}",
             },
-            {"role": "assistant", "content": batch["supervisions"]["text"][i]},
+            {"role": "assistant", "content": continuations[i]},
         ]
         messages.append(message)
     return messages
@@ -532,7 +534,7 @@ def compute_loss(

     # WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
     messages, answer_cosyvoice_speech_token = extract_text_and_speech_token(
-        batch, params.prompt_template, params.enable_speech_output
+        batch, params.enable_speech_output
     )

     input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
@@ -550,7 +552,6 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     elif params.loss_type == "kl_div":
-        assert params.prompt_template == "speech_continuation"
         messages_text = process_batch_text_continuation(batch)
         (
             teacher_input_ids,
@@ -942,15 +943,18 @@ def run(rank, world_size, args):
             teacher_llm=teacher_llm,
         )

-    if params.pretrained_model_path:
-        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
-        # set params.batch_idx_train according to the checkpoint name
-        if "checkpoint-" in params.pretrained_model_path:
-            params.batch_idx_train = int(
-                params.pretrained_model_path.split("-")[-1].split("/")[0]
-            )
+    if params.pretrained_model_path or params.last_stage_model_path:
+        if params.pretrained_model_path is None:
+            checkpoint = torch.load(params.last_stage_model_path, map_location="cpu")
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+        else:
+            checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+            # set params.batch_idx_train according to the checkpoint name
+            if "checkpoint-" in params.pretrained_model_path:
+                params.batch_idx_train = int(
+                    params.pretrained_model_path.split("-")[-1].split("/")[0]
+                )
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

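A quick sanity check of the checkpoint-name parsing above (the directory layout is hypothetical; the parsing expression is the one from the diff):

    # For a path like "exp/checkpoint-12000/pytorch_model.bin", the batch
    # index is the digits after the last "-" and before the next "/".
    path = "exp/checkpoint-12000/pytorch_model.bin"
    batch_idx_train = int(path.split("-")[-1].split("/")[0])
    assert batch_idx_train == 12000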
@@ -999,6 +1003,12 @@ def run(rank, world_size, args):
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
             )
             return False
+        if "question" in c.custom:
+            if len(c.custom["question"]) > 1200:
+                # logging.warning(
+                #     f"Exclude cut with ID {c.id} from training. question length: {len(c.custom['question'])}"
+                # )
+                return False
         return True

     if params.dataset == "slam_omni_belle":
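The predicate above is the kind of function icefall recipes typically pass to CutSet.filter; a hedged sketch of how it might be applied (the call site is not part of this hunk):

    def keep_cut(c) -> bool:
        # Drop cuts whose multi-round question text is very long, since it
        # would blow up the LLM context during training.
        if "question" in c.custom and len(c.custom["question"]) > 1200:
            return False
        return True

    train_cuts = train_cuts.filter(keep_cut)  # lazy, applied on iteration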
@@ -1007,6 +1017,12 @@ def run(rank, world_size, args):
     elif params.dataset == "vocalnet_ultrachat_voiceassistant":
         train_cuts = data_module.train_cuts_en_vocalnet()
         valid_cuts = data_module.valid_cuts_en_vocalnet()
+    elif params.dataset == "vocalnet_ultrachat_voiceassistant_instruct_s2s":
+        train_cuts = data_module.train_cuts_en_speech2speech()
+        valid_cuts = data_module.valid_cuts_en_vocalnet()
+    elif params.dataset == "vocalnet_ultrachat_voiceassistant_instruct_s2s_librispeech":
+        train_cuts = data_module.train_cuts_en_speech2speech_librispeech()
+        valid_cuts = data_module.valid_cuts_en_vocalnet()
     elif params.dataset == "ultravox_multi_en":
         train_cuts = data_module.train_cuts_ultravox()
         valid_cuts = data_module.valid_cuts_ultravox()