Remove the data collator that concatenates three items

This commit is contained in:
root 2025-06-03 00:18:21 -07:00
parent 4c0396f8f2
commit 5becf6927d

View File

@ -189,71 +189,6 @@ def data_collator(batch):
"lang": lang, "lang": lang,
} }
def data_collator_concate_items(batch, concat_items_num: int = 3):
    """Concatenate *concat_items_num* consecutive dataset items into one.

    The incoming ``batch`` (a list of dataset items) is split into
    non-overlapping groups of *concat_items_num* items. Within each group the
    speech codec tokens and the text are concatenated so that the model sees
    one longer utterance instead of several short ones. When ``len(batch)``
    is not divisible by *concat_items_num*, the remainder is kept as a final,
    smaller group.

    Args:
        batch: Sequence of dict-like items. Each item is read through the
            keys ``"code"``, ``"text"`` and ``"duration"``, plus the optional
            keys ``"index"`` / ``"id"`` and ``"language"``.
        concat_items_num: Number of consecutive items merged into one group.
            Must be at least 1.

    Returns:
        A dict of parallel lists with one entry per group:
        ``"speech_tokens"`` (concatenated codes), ``"messages"`` (a two-turn
        user/assistant message list), ``"durations"`` (sum of the item
        durations), ``"ids"`` (list of the per-item ids) and ``"lang"`` (the
        language of the first item in the group, ``""`` if absent).

    Raises:
        ValueError: If *concat_items_num* is smaller than 1.
    """
    # Validate up front: a stride of 0 would otherwise surface as an opaque
    # ``range()`` error and a negative stride would silently yield no groups.
    if concat_items_num < 1:
        raise ValueError(
            f"concat_items_num must be a positive integer, got {concat_items_num}"
        )

    grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
    grouped_ids, grouped_lang = [], []

    # Walk the batch in strides of *concat_items_num*; every slice taken at a
    # start index below ``len(batch)`` contains at least one item.
    for start_idx in range(0, len(batch), concat_items_num):
        group = batch[start_idx : start_idx + concat_items_num]

        # 1) Speech tokens: ``item["code"]`` may be a tensor or an iterable
        #    of ints; the first item of the group decides which path is used.
        codes = [item["code"] for item in group]
        if isinstance(codes[0], torch.Tensor):
            concat_code = torch.cat(codes, dim=0)
        else:
            concat_code = [token for code in codes for token in code]

        # 2) Text: the pieces are joined with no separator.
        concat_text = "".join(item["text"] for item in group)

        # 3) Chat-template messages for the concatenated utterance.
        message_list_item = [
            {
                "role": "user",
                "content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
            },
            {"role": "assistant", "content": concat_text},
        ]

        # 4) Metadata: durations add up, ids are kept per item ("index" is
        #    preferred over "id"), language comes from the first item.
        total_duration = sum(item["duration"] for item in group)
        group_ids = [item.get("index", item.get("id")) for item in group]
        language = group[0].get("language", "")

        # 5) Append to the output lists.
        grouped_speech_tokens.append(concat_code)
        grouped_messages.append(message_list_item)
        grouped_durations.append(total_duration)
        grouped_ids.append(group_ids)
        grouped_lang.append(language)

    return {
        "speech_tokens": grouped_speech_tokens,
        "messages": grouped_messages,
        "durations": grouped_durations,
        "ids": grouped_ids,
        "lang": grouped_lang,
    }
def data_collator_ultra_chat(batch): def data_collator_ultra_chat(batch):
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], [] speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
for i, item in enumerate(batch): for i, item in enumerate(batch):
@ -550,7 +485,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
model.to(device) model.to(device)
assert params.deepspeed and world_size > 1 # assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed") logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize( model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters() args=params, model=model, model_parameters=model.parameters()