mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
* Add byte level bpe tal_csasr recipe * Minor fixes to decoding and exporting * Fix prepare.sh * Update results
314 lines
3.9 KiB
Python
314 lines
3.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
#
|
|
# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py
|
|
|
|
import re
|
|
import unicodedata
|
|
|
|
|
|
# Matches any run of one or more whitespace characters.
WHITESPACE_NORMALIZER = re.compile(r"\s+")

# Plain ASCII space (code point 32).
SPACE = " "

# U+2581 (code point 9601); presumably the visible stand-in for a space
# used by the subword tokenizer -- confirm against the callers.
SPACE_ESCAPE = "\u2581"

# U+2047 (code point 8263); the character treated as the BPE unknown symbol.
BPE_UNK = "\u2047"
|
|
|
|
# Code point of the printable stand-in character for each byte value;
# the list index is the byte (0..255). Printable ASCII bytes stand for
# themselves, every other byte is assigned a code point from 256 upwards.
# The gaps in the upper ranges (306-307, 319-320, 329, 383) are presumably
# code points that are not stable under NFKC normalization -- the check
# below would reject them.
PRINTABLE_BASE_CHARS = [
    *range(256, 288),  # bytes 0..31
    *range(32, 127),  # bytes 32..126 map to themselves
    *range(288, 306),  # bytes 127..144
    *range(308, 319),  # bytes 145..155
    *range(321, 329),  # bytes 156..163
    *range(330, 383),  # bytes 164..216
    *range(384, 423),  # bytes 217..255
]

# Sanity check at import time: every stand-in character must survive NFKC
# normalization unchanged; the offending code point is reported on failure.
for c in PRINTABLE_BASE_CHARS:
    assert chr(c) == unicodedata.normalize("NFKC", chr(c)), c
|
|
|
|
# Forward table: byte value (0..255) -> its printable stand-in character.
BYTE_TO_BCHAR = {
    byte_value: chr(PRINTABLE_BASE_CHARS[byte_value]) for byte_value in range(256)
}

# Reverse table: stand-in character -> byte value.
BCHAR_TO_BYTE = dict(zip(BYTE_TO_BCHAR.values(), BYTE_TO_BCHAR.keys()))

# The unknown symbol has no byte of its own; decode it as a space (byte 32).
BCHAR_TO_BYTE[BPE_UNK] = 32
|
|
|
|
|
|
def byte_encode(x: str) -> str:
    """Encode text as a string of printable byte stand-in characters.

    Every run of whitespace in ``x`` is first replaced by a single space.
    The result is UTF-8 encoded and each byte is substituted with its
    character from ``BYTE_TO_BCHAR``, so the returned string has exactly one
    character per UTF-8 byte of the whitespace-normalized input.
    """
    collapsed = WHITESPACE_NORMALIZER.sub(SPACE, x)
    utf8_bytes = collapsed.encode("utf-8")
    return "".join(map(BYTE_TO_BCHAR.__getitem__, utf8_bytes))
|
|
|
|
|
|
def byte_decode(x: str) -> str:
    """Turn a string of byte stand-in characters back into text.

    Each character of ``x`` is looked up in ``BCHAR_TO_BYTE`` and the
    resulting byte sequence is decoded as UTF-8.

    Returns:
        The decoded text, or the empty string when the bytes are not valid
        UTF-8 (any ``ValueError`` raised while building or decoding them).

    Raises:
        KeyError: if ``x`` contains a character that has no entry in
            ``BCHAR_TO_BYTE``; only ``ValueError`` is handled here.
    """
    try:
        raw = bytes(BCHAR_TO_BYTE[bchar] for bchar in x)
        text = raw.decode("utf-8")
    except ValueError:
        text = ""
    return text
|
|
|
|
|
|
def smart_byte_decode(x: str) -> str:
    """Decode byte stand-in characters, salvaging what it can from broken input.

    When ``byte_decode(x)`` succeeds with a non-empty result, that result is
    returned as is. Otherwise a dynamic program splits ``x`` into segments of
    at most four characters (the longest UTF-8 encoding of one code point),
    choosing the split that maximizes the number of segments which decode on
    their own; characters not covered by a decodable segment are dropped.

    Returns:
        The concatenation, in input order, of the decoded segments. This is
        the empty string when nothing can be recovered (or ``x`` is empty).
    """
    decoded = byte_decode(x)
    if decoded != "":
        return decoded

    total = len(x)
    # best[i]: maximum number of decodable segments found within x[:i].
    # prev[i]: where the last step of that optimal split of x[:i] begins.
    best = [0] * (total + 1)
    prev = [0] * (total + 1)
    for end in range(1, total + 1):
        # Baseline: drop character end-1 without decoding it.
        best[end] = best[end - 1]
        prev[end] = end - 1
        for width in range(1, min(4, end) + 1):
            start = end - width
            # Only attempt a decode when it could strictly improve the score.
            if best[start] + 1 > best[end] and len(byte_decode(x[start:end])) > 0:
                best[end] = best[start] + 1
                prev[end] = start

    # Walk the back-pointers from the end, collecting decoded segments in
    # reverse; a step that raised the score by one is a decoded segment.
    pieces = []
    pos = total
    while pos > 0:
        start = prev[pos]
        if best[pos] == best[start] + 1:
            pieces.append(byte_decode(x[start:pos]))
        pos = start
    pieces.reverse()
    return "".join(pieces)
|