add alimeeting norm

This commit is contained in:
Yuekai Zhang 2024-06-26 10:41:49 +08:00
parent b594a3875b
commit 9a5fc2ab71

View File

@ -775,6 +775,42 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
@ -788,6 +824,9 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [normalize_text_alimeeting(text) for text in texts]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)