mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
add alimeeting norm
This commit is contained in:
parent
b594a3875b
commit
9a5fc2ab71
@ -775,6 +775,42 @@ def compute_loss(
|
|||||||
warmup: a floating point value which increases throughout training;
|
warmup: a floating point value which increases throughout training;
|
||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
|
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
    """
    Text normalization similar to M2MeT challenge baseline.
    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl

    Args:
      text:
        The transcript to normalize.
      normalize:
        Normalization mode. "none" returns `text` unchanged. "m2met"
        (the default) removes spaces, annotation tags such as "<sil>",
        and punctuation; upper-cases the text when it contains an ASCII
        letter; and maps a few full-width Latin letters to ASCII capitals.
    Returns:
      The normalized transcript.
    Raises:
      ValueError:
        If `normalize` is neither "none" nor "m2met". Previously such a
        value silently produced None.
    """
    if normalize == "none":
        return text
    if normalize != "m2met":
        raise ValueError(
            f"Unsupported normalize mode: {normalize!r} "
            "(expected 'none' or 'm2met')"
        )

    # Imported locally (as before) because the file-level imports are not
    # guaranteed to include `re`.
    import re

    # Spaces must go first so that tags written with inner spaces
    # (e.g. "< sil >") are recognised by the tag removal below.
    text = text.replace(" ", "")

    # Tags are removed one after another, in this fixed order, before any
    # single-character clean-up so that results match sequential replaces.
    for tag in ("<sil>", "<%>", "<->", "<$>", "<#>", "<_>", "<space>"):
        text = text.replace(tag, "")

    # Upper-casing is decided before full-width letters are mapped to ASCII:
    # a transcript whose only Latin letters are full-width is not upper-cased.
    if re.search("[a-zA-Z]", text):
        text = text.upper()

    # Full-width characters are written as escapes so they cannot be
    # collapsed into their ASCII look-alikes by editors or tooling (which
    # would turn these mappings into no-ops).
    table = str.maketrans(
        # full-width A, a, b, c, k, t -> ASCII capitals
        "\uff21\uff41\uff42\uff43\uff4b\uff54",
        "AABCKT",
        # characters to delete: ASCII punctuation, the full-width comma and
        # question mark, and CJK punctuation
        "`&,?\uff0c\uff1f丶。、",
    )
    return text.translate(table)
|
||||||
|
|
||||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
@ -788,6 +824,9 @@ def compute_loss(
|
|||||||
warm_step = params.warm_step
|
warm_step = params.warm_step
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
# remove spaces in texts
|
||||||
|
texts = [normalize_text_alimeeting(text) for text in texts]
|
||||||
|
|
||||||
y = sp.encode(texts, out_type=int)
|
y = sp.encode(texts, out_type=int)
|
||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user