diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 447ca122f..ea8ad5000 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -775,6 +775,42 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ + def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: + """ + Text normalization similar to M2MeT challenge baseline. + See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + if normalize == "none": + return text + elif normalize == "m2met": + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) @@ -788,6 +824,9 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] + # remove spaces in texts + texts = [normalize_text_alimeeting(text) for text in texts] + y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y)