mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
remove concat three items
This commit is contained in:
parent
4c0396f8f2
commit
5becf6927d
@ -189,71 +189,6 @@ def data_collator(batch):
|
||||
"lang": lang,
|
||||
}
|
||||
|
||||
def data_collator_concate_items(batch, concat_items_num: int = 3):
    """Merge every *concat_items_num* consecutive dataset items into one sample.

    ``batch`` is sliced into non-overlapping, order-preserving groups of
    ``concat_items_num`` items. When ``len(batch)`` is not a multiple of
    ``concat_items_num`` the trailing items form one final, smaller group.
    An empty ``batch`` produces empty output lists.

    For each group:

    * the ``"code"`` fields are concatenated (``torch.cat`` along dim 0 when
      the first item's code is a tensor, otherwise flattened into one list);
    * the ``"text"`` fields are joined with no separator and wrapped in a
      two-turn user/assistant message list;
    * the ``"duration"`` fields are summed;
    * the ids are collected from ``"index"``, falling back to ``"id"``;
    * the language is taken from the first item's ``"language"`` key
      (empty string when absent).

    Args:
        batch: Sequence of dict-like dataset items. Each item must provide
            ``"code"``, ``"text"`` and ``"duration"``.
        concat_items_num: Number of items merged into one sample. Must be
            at least 1.

    Returns:
        A dict with the keys ``"speech_tokens"``, ``"messages"``,
        ``"durations"``, ``"ids"`` and ``"lang"``; each value is a list with
        one entry per group.

    Raises:
        ValueError: If ``concat_items_num`` is smaller than 1.
    """
    # A zero step makes range() raise an opaque error and a negative step
    # yields no groups at all, silently discarding the whole batch. Reject
    # both up front with a clear message.
    if concat_items_num < 1:
        raise ValueError(
            f"concat_items_num must be a positive integer, got {concat_items_num}"
        )

    grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
    grouped_ids, grouped_lang = [], []

    for start_idx in range(0, len(batch), concat_items_num):
        # start_idx is always < len(batch), so the slice is never empty.
        group = batch[start_idx : start_idx + concat_items_num]

        # 1) Speech tokens: the type of the first item's code decides how the
        #    whole group is concatenated.
        # NOTE(review): assumes all items in a group share the same code
        # type (all tensors or all plain sequences) -- confirm with dataset.
        codes = [item["code"] for item in group]
        if isinstance(codes[0], torch.Tensor):
            concat_code = torch.cat(codes, dim=0)
        else:
            concat_code = [token for code in codes for token in code]

        # 2) Text: joined without any separator between utterances.
        concat_text = "".join(item["text"] for item in group)

        # 3) Chat-template messages for the concatenated utterance.
        message_list_item = [
            {
                "role": "user",
                "content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
            },
            {"role": "assistant", "content": concat_text},
        ]

        # 4) Per-group metadata.
        total_duration = sum(item["duration"] for item in group)
        group_ids = [item.get("index", item.get("id")) for item in group]
        language = group[0].get("language", "")

        # 5) Collect the group's results.
        grouped_speech_tokens.append(concat_code)
        grouped_messages.append(message_list_item)
        grouped_durations.append(total_duration)
        grouped_ids.append(group_ids)
        grouped_lang.append(language)

    return {
        "speech_tokens": grouped_speech_tokens,
        "messages": grouped_messages,
        "durations": grouped_durations,
        "ids": grouped_ids,
        "lang": grouped_lang,
    }
|
||||
|
||||
def data_collator_ultra_chat(batch):
|
||||
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
|
||||
for i, item in enumerate(batch):
|
||||
@ -550,7 +485,7 @@ def run(rank, world_size, args):
|
||||
logging.info(f"Device: {device}")
|
||||
model.to(device)
|
||||
|
||||
assert params.deepspeed and world_size > 1
|
||||
# assert params.deepspeed and world_size > 1
|
||||
logging.info("Using DeepSpeed")
|
||||
model, optimizer, _, scheduler = deepspeed.initialize(
|
||||
args=params, model=model, model_parameters=model.parameters()
|
||||
|
Loading…
x
Reference in New Issue
Block a user