Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)
Fix code according to review
This commit is contained in:
parent 73a31db1b2
commit 66afaf402d
@@ -201,7 +201,7 @@ def get_parser():
         "--rnn-lm-tie-weights",
         type=str2bool,
         default=False,
-        help="""True share the weights between the input embedding layer and the
+        help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,
     )
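The `--rnn-lm-tie-weights` option controls whether the RNN LM's input embedding and its final output projection share one parameter tensor. For illustration only (this snippet is not part of the diff, and the toy `TiedRnnLm` below is an assumption, not icefall's actual RNN LM), weight tying in PyTorch typically looks like this and requires the embedding and hidden dimensions to match:

import torch
import torch.nn as nn


class TiedRnnLm(nn.Module):
    """Toy RNN LM showing embedding/output weight tying."""

    def __init__(self, vocab_size: int, hidden_dim: int, tie_weights: bool):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)
        if tie_weights:
            # Share the parameter tensor; needs embedding dim == hidden dim.
            self.output.weight = self.embedding.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embedding(tokens)
        x, _ = self.rnn(x)
        return self.output(x)


if __name__ == "__main__":
    lm = TiedRnnLm(vocab_size=500, hidden_dim=256, tie_weights=True)
    logits = lm(torch.randint(0, 500, (2, 10)))
    print(logits.shape)  # torch.Size([2, 10, 500])

Tying roughly halves the parameter count of the embedding/projection pair, which is why it is exposed as a boolean flag.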
@@ -235,7 +235,7 @@ def get_params() -> AttributeDict:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    rnn_lm_model: nn.Module,
+    rnn_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
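Typing `rnn_lm_model` as `Optional[nn.Module]` makes explicit that callers may pass `None` when no RNN-LM rescoring is requested. A minimal sketch of the usual guard, for illustration only (the helper name, tensor shapes and `lm_scale` value are assumptions, not the actual `decode_one_batch` body):

from typing import Optional

import torch
import torch.nn as nn


def maybe_add_rnn_lm_scores(
    path_scores: torch.Tensor,
    rnn_lm_model: Optional[nn.Module],
    token_ids: torch.Tensor,
    lm_scale: float = 0.3,
) -> torch.Tensor:
    """Rescore candidate paths with an RNN LM only if one was supplied."""
    if rnn_lm_model is None:
        # Decoding without an RNN LM: keep the original scores.
        return path_scores
    with torch.no_grad():
        # Illustrative call: a real RNN LM would return per-path log-probs.
        lm_scores = rnn_lm_model(token_ids)
    return path_scores + lm_scale * lm_scores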
@@ -1,146 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c)  2021  Xiaomi Corporation (authors: Daniel Povey
-#                                                   Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script takes a `bpe.model` and a text file such as
-`download/ptb.train.txt`,
-and outputs the LM training data to a supplied directory such
-as data/bpe_500. The format is as follows:
-
-It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
-representation of a dict with the following format:
-
-  'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32
-             containing the BPE representations of each word, indexed by
-             integer word ID. (These integer word IDs are present in
-             'lm_data').  The sentencepiece object can be used to turn the
-             words and BPE units into string form.
-  'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
-                 torch.int32 containing all the sentences, as word-ids (we don't
-                 output the string form of this directly but it can be worked out
-                 together with 'words' and the bpe.model).
-  'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
-                        number of BPE tokens of each sentence.
-"""
-
-import argparse
-import logging
-from pathlib import Path
-
-import k2
-import sentencepiece as spm
-import torch
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        help="Input BPE model, e.g. data/bpe_500/bpe.model",
-    )
-    parser.add_argument(
-        "--lm-data",
-        type=str,
-        help="""Input LM training data as text, e.g.
-        download/ptb.train.txt""",
-    )
-    parser.add_argument(
-        "--lm-archive",
-        type=str,
-        help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
-        look at the source of this script to see the format.""",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-
-    if Path(args.lm_archive).exists():
-        logging.warning(f"{args.lm_archive} exists - skipping")
-        return
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    # word2index is a dictionary from words to integer ids. No need to reserve
-    # space for epsilon, etc.; the words are just used as a convenient way to
-    # compress the sequences of BPE pieces.
-    word2index = dict()
-
-    word2bpe = []  # Will be a list-of-list-of-int, representing BPE pieces.
-
-    # ptb.train.txt has already converted OOV words to <unk>
-    word2bpe.append([sp.unk_id()])
-    word2index["<unk>"] = 0
-
-    sentences = []  # Will be a list-of-list-of-int, representing word-ids.
-
-    with open(args.lm_data) as f:
-        while True:
-            line = f.readline()
-            if line == "":
-                break
-            line_words = line.split()
-            for w in line_words:
-                if w not in word2index:
-                    w_bpe = sp.encode(w)
-                    word2index[w] = len(word2bpe)
-                    word2bpe.append(w_bpe)
-            sentences.append([word2index[w] for w in line_words])
-
-    words = k2.ragged.RaggedTensor(word2bpe)
-    sentences = k2.ragged.RaggedTensor(sentences)
-
-    output = dict(words=words, sentences=sentences)
-
-    num_sentences = sentences.dim0
-    sentence_lengths = [0] * num_sentences
-    for i in range(num_sentences):
-        word_ids = sentences[i]
-
-        # NOTE: If word_ids is a tensor with only 1 entry,
-        # token_ids is a torch.Tensor
-        token_ids = words[word_ids]
-        if isinstance(token_ids, k2.RaggedTensor):
-            token_ids = token_ids.values
-
-        # token_ids is a 1-D tensor containing the BPE tokens
-        # of the current sentence
-
-        sentence_lengths[i] = token_ids.numel()
-
-    output["sentence_lengths"] = torch.tensor(
-        sentence_lengths, dtype=torch.int32
-    )
-
-    torch.save(output, args.lm_archive)
-    logging.info(f"Saved to {args.lm_archive}")
-
-
-if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
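The archive layout described in the docstring above can be inspected after the script has run. A rough sketch, for illustration only (not part of the diff; the paths come from the script's own help text, and k2, sentencepiece and torch are assumed to be installed):

import k2
import sentencepiece as spm
import torch

# Restore the dict that the script saved with torch.save().
data = torch.load("data/bpe_500/lm_data.pt")
words = data["words"]                # k2.RaggedTensor, axes [word][token]
sentences = data["sentences"]        # k2.RaggedTensor, axes [sentence][word]
lengths = data["sentence_lengths"]   # 1-D torch.int32 tensor

sp = spm.SentencePieceProcessor()
sp.load("data/bpe_500/bpe.model")

# Reconstruct the first sentence as text, using the same indexing pattern
# as the script above.
word_ids = sentences[0]
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
    token_ids = token_ids.values
print(sp.decode(token_ids.tolist()))
print("num sentences:", sentences.dim0, "first length:", lengths[0].item())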
egs/ptb/LM/local/prepare_lm_training_data.py (symbolic link)
@@ -0,0 +1 @@
+/Users/ezerhoun/repos/open_source/icefall/egs/librispeech/ASR/local/prepare_lm_training_data.py
@@ -1,95 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# You can install sentencepiece via:
-#
-#  pip install sentencepiece
-#
-# Due to an issue reported in
-# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
-#
-# Please install a version >=0.1.96
-
-import argparse
-import shutil
-from pathlib import Path
-
-import sentencepiece as spm
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--out-dir",
-        type=str,
-        help="""Input and output directory.
-        The generated bpe.model is saved to this directory.
-        """,
-    )
-
-    parser.add_argument(
-        "--transcript",
-        type=str,
-        help="Training transcript.",
-    )
-
-    parser.add_argument(
-        "--vocab-size",
-        type=int,
-        help="Vocabulary size for BPE training",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-    vocab_size = args.vocab_size
-    model_type = "unigram"
-
-    model_prefix = f"{args.out_dir}/{model_type}_{vocab_size}"
-    train_text = args.transcript
-    character_coverage = 1.0
-    input_sentence_size = 100000000
-
-    user_defined_symbols = ["<blk>", "<sos/eos>"]
-    unk_id = len(user_defined_symbols)
-    # Note: unk_id is fixed to 2.
-    # If you change it, you should also change other
-    # places that are using it.
-
-    model_file = Path(model_prefix + ".model")
-    if not model_file.is_file():
-        spm.SentencePieceTrainer.train(
-            input=train_text,
-            vocab_size=vocab_size,
-            model_type=model_type,
-            model_prefix=model_prefix,
-            input_sentence_size=input_sentence_size,
-            character_coverage=character_coverage,
-            user_defined_symbols=user_defined_symbols,
-            unk_id=unk_id,
-            bos_id=-1,
-            eos_id=-1,
-        )
-
-    shutil.copyfile(model_file, f"{args.out_dir}/bpe.model")
-
-
-if __name__ == "__main__":
-    main()
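Because the script above pins `<blk>` and `<sos/eos>` as user-defined symbols and fixes `unk_id` to 2, a trained model can be sanity-checked as below; a small sketch for illustration only (not part of the diff; the data/bpe_500 path is an assumption):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/bpe_500/bpe.model")

# With user_defined_symbols=["<blk>", "<sos/eos>"] and unk_id=2,
# the special pieces get these fixed IDs:
print(sp.piece_to_id("<blk>"))      # expected: 0
print(sp.piece_to_id("<sos/eos>"))  # expected: 1
print(sp.unk_id())                  # expected: 2

# Encode a sample transcript into BPE pieces and IDs.
print(sp.encode("HELLO WORLD", out_type=str))
print(sp.encode("HELLO WORLD"))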
egs/ptb/LM/local/train_bpe_model.py (symbolic link)
@@ -0,0 +1 @@
+/Users/ezerhoun/repos/open_source/icefall/egs/librispeech/ASR/local/train_bpe_model.py