def data_collator_concate_items(batch, concat_items_num: int = 3):
    """Merge consecutive dataset items into longer training examples.

    ``batch`` is cut into non-overlapping, order-preserving groups of
    ``concat_items_num`` items. When ``len(batch)`` is not a multiple of
    the group size, the shorter trailing group is kept as well. Within a
    group the texts are joined without a separator, the speech codec
    tokens are concatenated and the durations are summed, so each group
    becomes one longer utterance.

    Args:
        batch: Sequence of dict items. Each item must provide ``"code"``,
            ``"text"`` and ``"duration"``; ``"index"`` (falling back to
            ``"id"``) and ``"language"`` are read when present.
        concat_items_num: Number of items merged into one example. Must
            be at least 1.

    Returns:
        A dict of parallel lists with one entry per group:
        ``"speech_tokens"``, ``"messages"``, ``"durations"``, ``"ids"``
        (a list of the member ids for each group) and ``"lang"``.

    Raises:
        ValueError: If ``concat_items_num`` is smaller than 1.
    """
    # A zero step makes ``range`` raise an opaque error and a negative step
    # makes it empty, which would silently drop the entire batch.
    if concat_items_num < 1:
        raise ValueError(
            f"concat_items_num must be a positive integer, got {concat_items_num}"
        )

    speech_tokens, messages, durations, ids, langs = [], [], [], [], []

    for start in range(0, len(batch), concat_items_num):
        # Every start index is < len(batch), so the slice is never empty.
        group = batch[start : start + concat_items_num]

        # Speech tokens: the first item decides between tensor and list
        # concatenation, as ``code`` may be a 1-D tensor or a list of ints.
        codes = [item["code"] for item in group]
        if isinstance(codes[0], torch.Tensor):
            merged_code = torch.cat(codes, dim=0)
        else:
            merged_code = [token for code in codes for token in code]

        merged_text = "".join(item["text"] for item in group)

        # Chat-template messages: the prompt carries the text followed by
        # the speech placeholder token; the assistant turn repeats the text.
        conversation = [
            {
                "role": "user",
                "content": f"Generate a speech from the following text:\n\n{merged_text}{DEFAULT_SPEECH_TOKEN}",
            },
            {"role": "assistant", "content": merged_text},
        ]

        speech_tokens.append(merged_code)
        messages.append(conversation)
        durations.append(sum(item["duration"] for item in group))
        ids.append([item.get("index", item.get("id")) for item in group])
        # The group's language is taken from its first member.
        langs.append(group[0].get("language", ""))

    return {
        "speech_tokens": speech_tokens,
        "messages": messages,
        "durations": durations,
        "ids": ids,
        "lang": langs,
    }