mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
dfecc5b81a
commit
352e1d221a
@ -0,0 +1,90 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||||
|
"""
|
||||||
|
Convert a transcript based on words to a list of BPE ids.
|
||||||
|
|
||||||
|
For example, if we use 2 as the encoding id of <unk>:
|
||||||
|
|
||||||
|
texts = ['this is a <unk> day']
|
||||||
|
spm_ids = [[38, 33, 6, 2, 316]]
|
||||||
|
|
||||||
|
texts = ['<unk> this is a sunny day']
|
||||||
|
spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316]]
|
||||||
|
|
||||||
|
texts = ['<unk>']
|
||||||
|
spm_ids = [[2]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_texts_into_ids(
|
||||||
|
texts: List[str],
|
||||||
|
unk_id: int,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
texts:
|
||||||
|
A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
|
||||||
|
unk_id:
|
||||||
|
A number id for the token '<unk>'.
|
||||||
|
Returns:
|
||||||
|
Return an integer list of bpe ids.
|
||||||
|
"""
|
||||||
|
y = []
|
||||||
|
for text in texts:
|
||||||
|
y_ids = []
|
||||||
|
if "<unk>" in text:
|
||||||
|
text_segments = text.split("<unk>")
|
||||||
|
id_segments = sp.encode(text_segments, out_type=int)
|
||||||
|
for i in range(len(id_segments)):
|
||||||
|
if i != len(id_segments) - 1:
|
||||||
|
y_ids.extend(id_segments[i] + [unk_id])
|
||||||
|
else:
|
||||||
|
y_ids.extend(id_segments[i])
|
||||||
|
else:
|
||||||
|
y_ids = sp.encode(text, out_type=int)
|
||||||
|
y.append(y_ids)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
texts = args.texts
|
||||||
|
bpe_model = args.bpe_model
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(bpe_model)
|
||||||
|
unk_id = sp.piece_to_id("<unk>")
|
||||||
|
|
||||||
|
y = convert_texts_into_ids(
|
||||||
|
texts=texts,
|
||||||
|
unk_id=unk_id,
|
||||||
|
sp=sp,
|
||||||
|
)
|
||||||
|
logging.info(f"The input texts: {texts}")
|
||||||
|
logging.info(f"The encoding ids: {y}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
x
Reference in New Issue
Block a user