From 4a294303499cae15357bbbcb2108baa6ddb7e420 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 May 2025 01:31:21 +0000 Subject: [PATCH] add loss type --- egs/speech_llm/SPEECH2SPEECH/prepare.sh | 74 ++++- .../SPEECH2SPEECH/qwen_omni/model.py | 66 ++++ .../SPEECH2SPEECH/qwen_omni/server.py | 23 +- .../SPEECH2SPEECH/qwen_omni/train.py | 287 +++++++++++++----- 4 files changed, 367 insertions(+), 83 deletions(-) diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh index fd8070691..e92e90a2f 100644 --- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh +++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh @@ -239,7 +239,8 @@ fi if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then log "stage 14: Client" - datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa) + datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa) + datasets=(openbookqa commoneval) for dataset in ${datasets[@]}; do # sd-qa should use usa split if [ $dataset == "sd-qa" ]; then @@ -250,17 +251,16 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then echo $dataset $split_name python3 ./qwen_omni/client.py \ --subset-name $dataset --split-name $split_name \ - --output-dir test_result + --output-dir result_adapter_librispeech_kl_div_qa_template done fi - if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then log "stage 15: Training Speech2Speech Model, adaptor only" exp_dir=./qwen_omni/exp_speech2text ngpu=2 torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ - --max-duration 600 \ + --max-duration 700 \ --enable-musan False \ --audio-key audio --text-key continuation \ --exp-dir $exp_dir \ @@ -271,7 +271,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then --deepspeed_config ./qwen_omni/ds_config_zero1.json \ --use-flash-attn True \ --dataset-format speech_continuation \ - --start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \ + --start-epoch 4 --pretrained-model-path $exp_dir/epoch-3/pytorch_model.bin \ --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False fi @@ -321,3 +321,67 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ $train_cmd_args fi + + +if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then + log "stage 17: Server for adapter only speech continuation" + exp_dir=./qwen_omni/exp_speech2text + python3 ./qwen_omni/server.py \ + --speech-encoder-path-or-name models/large-v2.pt \ + --llm-path-or-name models/Qwen2.5-0.5B-Instruct \ + --checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \ + --use-flash-attn True \ + --enable-speech-output False \ + --use-lora False --prompt-template continuation +fi + +if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then + log "stage 18: Training kl-div Speech2Speech Model, adaptor only" + exp_dir=./qwen_omni/exp_speech2text_kl + ngpu=2 + torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ + --max-duration 700 \ + --enable-musan False \ + --audio-key audio --text-key continuation \ + --exp-dir $exp_dir \ + --speech-encoder-path-or-name models/large-v2.pt \ + --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \ + --on-the-fly-feats True \ + --deepspeed \ + --deepspeed_config ./qwen_omni/ds_config_zero1.json \ + --use-flash-attn True \ + --dataset-format speech_continuation \ + --loss-type kl_div --dataset librispeech \ + --pretrained-model-path $exp_dir/checkpoint-1001/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-1001/sampler.pt \ + --use-lora False 
--unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False +fi + +if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then + log "stage 19: Server for kl loss" + exp_dir=./qwen_omni/exp_speech2text_kl + python3 ./qwen_omni/server.py \ + --speech-encoder-path-or-name models/large-v2.pt \ + --llm-path-or-name models/Qwen2.5-0.5B-Instruct \ + --checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \ + --use-flash-attn True \ + --enable-speech-output False \ + --use-lora False --prompt-template qa +fi + +if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then + log "stage 20: Training Speech2Speech Model, adaptor + lora, second stage" + exp_dir=./qwen_omni/exp_speech2text_kl_llm + pretrained_dir=./qwen_omni/exp_speech2text_kl + ngpu=2 + torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ + --max-duration 200 \ + --enable-musan False \ + --exp-dir $exp_dir \ + --speech-encoder-path-or-name models/large-v2.pt \ + --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \ + --deepspeed \ + --deepspeed_config ./qwen_omni/ds_config_zero1.json \ + --use-flash-attn True \ + --pretrained-model-path $pretrained_dir/epoch-10/pytorch_model.bin \ + --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False --dataset-format vocalnet +fi diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py index a0efbd319..97484486d 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py @@ -64,6 +64,8 @@ class SPEECH_LLM(nn.Module): encoder_projector: nn.Module, codec_lm: nn.Module = None, codec_lm_padding_side: str = "left", + teacher_llm: nn.Module = None, + kl_temperature: float = 2.0, ): super().__init__() self.encoder = encoder @@ -92,6 +94,9 @@ class SPEECH_LLM(nn.Module): multidim_average="global", ignore_index=IGNORE_TOKEN_ID, ) + if teacher_llm is not None: + self.teacher_llm = teacher_llm + self.kl_temperature = kl_temperature def _merge_input_ids_with_speech_features( self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None @@ -256,6 +261,67 @@ class SPEECH_LLM(nn.Module): ) return model_outputs.loss, acc + def forward_kl_div( + self, + fbank: torch.Tensor = None, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor = None, + labels: torch.LongTensor = None, + teacher_input_ids: torch.LongTensor = None, + teacher_attention_mask: torch.Tensor = None, + teacher_labels: torch.LongTensor = None, + ): + encoder_outs = self.encoder(fbank) + + speech_features = self.encoder_projector(encoder_outs) + + inputs_embeds = self.llm.get_input_embeddings()(input_ids) + + ( + inputs_embeds, + attention_mask, + labels, + _, + ) = self._merge_input_ids_with_speech_features( + speech_features, inputs_embeds, input_ids, attention_mask, labels + ) + + model_outputs = self.llm( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels + ) + + teacher_outputs = self.teacher_llm( + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + ) + + kl_loss = torch.nn.functional.kl_div( + torch.nn.functional.log_softmax( + model_outputs.logits[labels != -100] / self.kl_temperature, + dim=-1, + ), + torch.nn.functional.softmax( + teacher_outputs.logits[teacher_labels != -100] / self.kl_temperature, + dim=-1, + ), + reduction="batchmean", + ) + + with torch.no_grad(): + preds = torch.argmax(model_outputs.logits, -1) + teacher_preds = torch.argmax(teacher_outputs.logits, -1) + acc = compute_accuracy( + 
preds.detach()[:, :-1], + labels.detach()[:, 1:], + ignore_label=IGNORE_TOKEN_ID, + ) + acc_teacher = compute_accuracy( + teacher_preds.detach()[:, :-1], + teacher_labels.detach()[:, 1:], + ignore_label=IGNORE_TOKEN_ID, + ) + return kl_loss, acc, acc_teacher + def forward_with_speech_output( self, fbank: torch.Tensor = None, diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py index 2f06b923a..3c9122a09 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py @@ -21,6 +21,12 @@ def get_args(): default=None, help="Checkpoint name or path, default to %(default)r", ) + parser.add_argument( + "--prompt-template", + type=str, + default=None, + help="Prompt template", + ) add_model_arguments(parser) args = parser.parse_args() return args @@ -59,8 +65,23 @@ model, tokenizer = get_model(args) app = FastAPI() device = torch.device("cuda") +if args.prompt_template is None: + template = f"{DEFAULT_SPEECH_TOKEN}" +elif args.prompt_template == "qa": + template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}" +elif args.prompt_template == "continuation": + template = f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}" +elif args.prompt_template == "asr": + template = ( + f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}" + ) +elif args.prompt_template == "mt": + template = f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}" +else: + raise ValueError(f"Invalid prompt template: {args.prompt_template}") +print("Using template:", template) message = [ - {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}, + {"role": "user", "content": template}, {"role": "assistant", "content": ""}, ] TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py index a11ae4b76..81aac84e5 100755 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py @@ -74,8 +74,8 @@ from transformers import ( from utils import ( # filter_uneven_sized_batch, AttributeDict, MetricsTracker, - get_rank, get_local_rank, + get_rank, get_world_size, setup_logger, str2bool, @@ -234,6 +234,21 @@ def get_parser(): default="slam_omni", help="The format of the dataset.", ) + + parser.add_argument( + "--dataset", + type=str, + default="multi_en", + help="The name of the dataset.", + ) + + parser.add_argument( + "--loss-type", + type=str, + default="ce", + help="The type of loss to use.", + ) + parser = deepspeed.add_config_arguments(parser) add_model_arguments(parser) @@ -335,6 +350,22 @@ def process_batch_vocalnet(batch: dict): return messages, answer_cosyvoice_speech_token +def process_batch_text_vocalnet(batch: dict): + pass + answers = batch["supervisions"]["text"] + answer_cosyvoice_speech_token = [ + cut.custom["speech_token"] for cut in batch["supervisions"]["cut"] + ] + messages = [] + for i in range(len(answers)): + message = [ + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}, + {"role": "assistant", "content": answers[i]}, + ] + messages.append(message) + return messages, answer_cosyvoice_speech_token + + def process_batch_speech_continuation(batch: dict): 
messages = [] for i in range(len(batch["supervisions"]["text"])): @@ -350,6 +381,131 @@ def process_batch_speech_continuation(batch: dict): return messages +def process_batch_text_continuation(batch: dict): + messages = [] + for i in range(len(batch["supervisions"]["text"])): + transcript = batch["supervisions"]["cut"][i].custom["text"] + message = [ + { + "role": "user", + "content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}", + }, + {"role": "assistant", "content": batch["supervisions"]["text"][i]}, + ] + messages.append(message) + return messages + + +def preprocess( + messages, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """Preprocesses the data for supervised fine-tuning.""" + texts = [] + TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + for i, msg in enumerate(messages): + texts.append( + tokenizer.apply_chat_template( + msg, + tokenize=True, + chat_template=TEMPLATE, + add_generation_prompt=False, + padding="longest", # FIX me change padding to longest + truncation=False, + ) + ) + if len(texts) != len(messages): + logging.warning(f"Remove too long text, {messages} ") + max_len_texts = max([len(text) for text in texts]) + if tokenizer.padding_side == "right": + texts = [ + text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + for text in texts + ] + else: + texts = [ + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text + for text in texts + ] + input_ids = torch.tensor(texts, dtype=torch.int) + + target_ids = input_ids.clone() + target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID + # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID + # first get the indices of the tokens + mask_prompt = True + if mask_prompt: + default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) + mask_indices = torch.where(input_ids == default_speech_token_id) + for i in range(mask_indices[0].size(0)): + row = mask_indices[0][i] + col = mask_indices[1][i] + # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198 + # WAR: TODO FIXME check qwen3 + target_ids[row, : col + 6] = IGNORE_TOKEN_ID + attention_mask = input_ids.ne(tokenizer.pad_token_id) + return input_ids, attention_mask, target_ids + + +def preprocess_teacher( + messages, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """Preprocesses the data for supervised fine-tuning.""" + texts = [] + TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + for i, msg in enumerate(messages): + texts.append( + tokenizer.apply_chat_template( + msg, + tokenize=True, + chat_template=TEMPLATE, + add_generation_prompt=False, + padding="longest", # FIX me change padding to longest + truncation=False, + ) + ) + if len(texts) != len(messages): + logging.warning(f"Remove too long text, {messages} ") + max_len_texts = max([len(text) for text in texts]) + if tokenizer.padding_side == "right": + texts = [ + text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + for text in texts + ] + else: + texts = [ + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text + for text in texts + ] + input_ids = torch.tensor(texts, dtype=torch.int) + + target_ids = input_ids.clone() + target_ids[target_ids == 
tokenizer.pad_token_id] = IGNORE_TOKEN_ID + # mask all tokens before token_id with IGNORE_TOKEN_ID + # first get the indices of the tokens + mask_prompt = True + if mask_prompt: + default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) + mask_indices = torch.where(input_ids == default_speech_token_id) + for i in range(mask_indices[0].size(0)): + row = mask_indices[0][i] + col = mask_indices[1][i] + # + 2 to skip: 'assistant', '\n' + # WAR: TODO FIXME check qwen3 + # THIS IS THE ONLY DIFFERENCE FROM preprocess + target_ids[row, : col + 6] = IGNORE_TOKEN_ID + target_ids[row, col] = default_speech_token_id + # remove default_speech_token_id from target_ids and input_ids + batch_size = target_ids.size(0) + + target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1) + input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1) + + attention_mask = input_ids.ne(tokenizer.pad_token_id) + return input_ids, attention_mask, target_ids + + def compute_loss( params: AttributeDict, tokenizer: AutoTokenizer, @@ -374,72 +530,6 @@ def compute_loss( Returns: Return a tuple of two elements. The first element is the loss tensor. """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. - - def preprocess( - messages, - tokenizer: transformers.PreTrainedTokenizer, - ) -> Dict: - """Preprocesses the data for supervised fine-tuning.""" - texts = [] - TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" - for i, msg in enumerate(messages): - texts.append( - tokenizer.apply_chat_template( - msg, - tokenize=True, - chat_template=TEMPLATE, - add_generation_prompt=False, - padding="longest", # FIX me change padding to longest - truncation=False, - ) - ) - if len(texts) != len(messages): - logging.warning(f"Remove too long text, {messages} ") - max_len_texts = max([len(text) for text in texts]) - if tokenizer.padding_side == "right": - texts = [ - text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) - for text in texts - ] - else: - texts = [ - [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text - for text in texts - ] - input_ids = torch.tensor(texts, dtype=torch.int) - - target_ids = input_ids.clone() - target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID - # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID - # first get the indices of the tokens - mask_prompt = True - if mask_prompt: - default_speech_token_id = tokenizer.convert_tokens_to_ids( - DEFAULT_SPEECH_TOKEN - ) - mask_indices = torch.where(input_ids == default_speech_token_id) - for i in range(mask_indices[0].size(0)): - row = mask_indices[0][i] - col = mask_indices[1][i] - # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198 - # WAR: TODO FIXME check qwen3 - target_ids[row, : col + 6] = IGNORE_TOKEN_ID - - attention_mask = input_ids.ne(tokenizer.pad_token_id) - - return input_ids, attention_mask, target_ids - - # max_frames = params.max_duration * 
1000 // params.frame_shift_ms - # allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - # batch = filter_uneven_sized_batch(batch, allowed_max_frames) - device = next(model.parameters()).device feature = batch["inputs"] @@ -452,8 +542,12 @@ def compute_loss( messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch) elif params.dataset_format == "vocalnet": messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch) + if params.loss_type == "kl_div": + messages_text = process_batch_text_vocalnet(batch) elif params.dataset_format == "speech_continuation": messages = process_batch_speech_continuation(batch) + if params.loss_type == "kl_div": + messages_text = process_batch_text_continuation(batch) else: raise ValueError(f"Unknown dataset format: {params.dataset_format}") @@ -464,12 +558,30 @@ def compute_loss( with torch.set_grad_enabled(is_training): if not params.enable_speech_output: - loss, acc = model( - fbank=feature, - input_ids=input_ids.to(device), - attention_mask=attention_mask.to(device), - labels=target_ids.to(device), - ) + if params.loss_type == "ce": + loss, acc = model( + fbank=feature, + input_ids=input_ids.to(device), + attention_mask=attention_mask.to(device), + labels=target_ids.to(device), + ) + elif params.loss_type == "kl_div": + ( + teacher_input_ids, + teacher_attention_mask, + teacher_target_ids, + ) = preprocess_teacher(messages_text, tokenizer) + loss, acc, acc_teacher = model.forward_kl_div( + fbank=feature, + input_ids=input_ids.to(device), + attention_mask=attention_mask.to(device), + labels=target_ids.to(device), + teacher_input_ids=teacher_input_ids.to(device), + teacher_attention_mask=teacher_attention_mask.to(device), + teacher_labels=teacher_target_ids.to(device), + ) + else: + raise ValueError(f"Unknown loss type: {params.loss_type}") else: ( text_loss, @@ -498,6 +610,8 @@ def compute_loss( info["acc"] = ( acc * info["frames"] ) # WAR: to avoid normalization by the number of frames + if params.loss_type == "kl_div": + info["acc_teacher"] = acc_teacher * info["frames"] if params.enable_speech_output: info["codec_acc"] = codec_acc * info["frames"] info["codec_topk_acc"] = codec_topk_acc * info["frames"] @@ -820,6 +934,17 @@ def run(rank, world_size, args): codec_lm.config.mask_token_id = codec_vocab_size - 4 else: codec_lm = None + if params.loss_type == "kl_div": + teacher_llm = AutoModelForCausalLM.from_pretrained( + params.llm_path_or_name, + attn_implementation=attn_implementation, + torch_dtype=torch_dtype, + ) + for name, param in teacher_llm.named_parameters(): + param.requires_grad = False + teacher_llm.eval() + else: + teacher_llm = None model = SPEECH_LLM( speech_encoder, @@ -827,6 +952,7 @@ def run(rank, world_size, args): encoder_projector, codec_lm, codec_lm_padding_side="left" if params.use_flash_attn else "right", + teacher_llm=teacher_llm, ) if params.pretrained_model_path: @@ -834,7 +960,9 @@ def run(rank, world_size, args): missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) # set params.batch_idx_train according to the checkpoint name if "checkpoint-" in params.pretrained_model_path: - params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1]) + params.batch_idx_train = int( + params.pretrained_model_path.split("-")[-1].split("/")[0] + ) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -893,9 +1021,14 @@ def run(rank, world_size, args): train_cuts = 
data_module.train_cuts_en_vocalnet() valid_cuts = data_module.valid_cuts_en_vocalnet() elif params.dataset_format == "speech_continuation": - train_cuts = data_module.train_cuts_ultravox() - # train_cuts = data_module.train_cuts_gigaspeech() - # train_cuts = data_module.train_cuts_librispeech() + if params.dataset == "multi_en": + train_cuts = data_module.train_cuts_ultravox() + elif params.dataset == "librispeech": + train_cuts = data_module.train_cuts_librispeech() + elif params.dataset == "gigaspeech": + train_cuts = data_module.train_cuts_gigaspeech() + else: + raise ValueError(f"Unknown dataset: {params.dataset}") valid_cuts = data_module.valid_cuts_ultravox() else: raise ValueError(f"Unknown dataset format: {params.dataset_format}")
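
The core of this patch is the new `--loss-type kl_div` path: a frozen copy of the Qwen2.5 LLM (loaded in `run()` with requires_grad disabled) acts as a text-only teacher, and the speech-conditioned student LLM is trained to match the teacher's output distribution on the answer tokens. The snippet below is a minimal sketch of that temperature-scaled KL objective as `forward_kl_div` computes it; the tensor shapes, the toy data, and the helper name `kl_distillation_loss` are illustrative assumptions, not part of the repository.

# Minimal sketch of the temperature-scaled KL distillation objective used by
# forward_kl_div (illustrative shapes; not the repository's exact API).
import torch
import torch.nn.functional as F

IGNORE_TOKEN_ID = -100  # assumption: same ignore index as in the training code


def kl_distillation_loss(
    student_logits: torch.Tensor,   # (B, T_s, V) from the speech-conditioned LLM
    student_labels: torch.Tensor,   # (B, T_s), IGNORE_TOKEN_ID on prompt/padding
    teacher_logits: torch.Tensor,   # (B, T_t, V) from the frozen text-only teacher
    teacher_labels: torch.Tensor,   # (B, T_t), IGNORE_TOKEN_ID on prompt/padding
    temperature: float = 2.0,
) -> torch.Tensor:
    # Keep only the answer positions; the two masks must select the same number
    # of tokens, which is what preprocess/preprocess_teacher arrange.
    student_sel = student_logits[student_labels != IGNORE_TOKEN_ID]  # (N, V)
    teacher_sel = teacher_logits[teacher_labels != IGNORE_TOKEN_ID]  # (N, V)
    assert student_sel.shape == teacher_sel.shape, "answer spans must align"

    # KL(teacher || student) over softened distributions, averaged per batch.
    return F.kl_div(
        F.log_softmax(student_sel / temperature, dim=-1),
        F.softmax(teacher_sel / temperature, dim=-1),
        reduction="batchmean",
    )


if __name__ == "__main__":
    # Toy usage with random logits, just to show the expected shapes.
    B, T, V = 2, 8, 16
    student_logits = torch.randn(B, T, V, requires_grad=True)
    teacher_logits = torch.randn(B, T, V)
    labels = torch.full((B, T), IGNORE_TOKEN_ID)
    labels[:, 4:] = 1  # pretend the last 4 positions are answer tokens
    loss = kl_distillation_loss(student_logits, labels, teacher_logits, labels)
    loss.backward()
    print(float(loss))

As in the patch, the sketch uses reduction="batchmean" and does not rescale by temperature**2; some distillation setups add that factor to keep gradient magnitudes comparable to a cross-entropy loss, so including it is a design choice rather than a correction.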
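
`preprocess_teacher` builds the teacher's text-only view of the same conversation: both sides are tokenized with the same chat template and prompt masking, but the teacher sequence drops the single `DEFAULT_SPEECH_TOKEN` placeholder (whose position the student fills with projected speech embeddings), so the surviving answer positions line up between the two views. The sketch below condenses that alignment step into one helper; the token id value and the helper name are assumptions, and it filters `target_ids` with the mask derived from `input_ids` rather than temporarily rewriting the placeholder as the patch does, which is equivalent only when every row contains exactly one placeholder.

# Simplified view of the placeholder removal in preprocess_teacher
# (illustrative ids; the real code filters both tensors by token value and
# then recomputes the attention mask as input_ids.ne(tokenizer.pad_token_id)).
import torch

IGNORE_TOKEN_ID = -100
SPEECH_TOKEN_ID = 151665  # assumption: id of DEFAULT_SPEECH_TOKEN in this setup


def drop_speech_placeholder(input_ids: torch.Tensor, target_ids: torch.Tensor):
    """Remove the one <speech> placeholder per row so teacher and student
    answer positions stay aligned."""
    batch_size = input_ids.size(0)
    keep = input_ids != SPEECH_TOKEN_ID
    # view() only succeeds if every row drops the same number of tokens,
    # i.e. each prompt contains exactly one placeholder.
    teacher_input_ids = input_ids[keep].view(batch_size, -1)
    teacher_target_ids = target_ids[keep].view(batch_size, -1)
    return teacher_input_ids, teacher_target_ids


if __name__ == "__main__":
    ids = torch.tensor([[1, SPEECH_TOKEN_ID, 7, 8], [2, SPEECH_TOKEN_ID, 9, 0]])
    tgt = torch.tensor([[IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, 7, 8],
                        [IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, 9, 0]])
    print(drop_speech_placeholder(ids, tgt))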
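
One small but easy-to-miss change in `train.py`: the batch counter recovered from `--pretrained-model-path` is now parsed with an extra split("/"), because directory-style checkpoints such as checkpoint-1001/pytorch_model.bin leave the file name attached after splitting on "-". A throwaway illustration (the path is made up):

# Illustration of the changed checkpoint-step parsing (made-up path).
path = "exp/checkpoint-1001/pytorch_model.bin"

old = path.split("-")[-1]                # '1001/pytorch_model.bin' -> int() would raise ValueError
new = path.split("-")[-1].split("/")[0]  # '1001'
print(int(new))                          # 1001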