From 66afaf402d0ab79bc73b294ad8993455fd2fd437 Mon Sep 17 00:00:00 2001 From: Erwan Date: Wed, 22 Jun 2022 15:04:25 +0200 Subject: [PATCH] Fix code according to review --- egs/librispeech/ASR/RESULTS.md | 2 +- egs/librispeech/ASR/conformer_ctc/decode.py | 4 +- egs/ptb/LM/local/prepare_lm_training_data.py | 147 +------------------ egs/ptb/LM/local/train_bpe_model.py | 96 +----------- 4 files changed, 5 insertions(+), 244 deletions(-) mode change 100755 => 120000 egs/ptb/LM/local/prepare_lm_training_data.py mode change 100755 => 120000 egs/ptb/LM/local/train_bpe_model.py diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 2b5948171..b11ac867d 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1202,7 +1202,7 @@ rnn_dir=$(git rev-parse --show-toplevel)/icefall/rnn_lm --epoch 77 \ --avg 55 \ --nbest-scale 0.5 \ - --rnn-lm-exp-dir ${rnn_dir}/exp_2048_3_tied\ + --rnn-lm-exp-dir ${rnn_dir}/exp_2048_3_tied \ --rnn-lm-epoch 29 \ --rnn-lm-avg 3 \ --rnn-lm-embedding-dim 2048 \ diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 1ecd6c220..0e8247b8d 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -201,7 +201,7 @@ def get_parser(): "--rnn-lm-tie-weights", type=str2bool, default=False, - help="""True share the weights between the input embedding layer and the + help="""True to share the weights between the input embedding layer and the last output linear layer """, ) @@ -235,7 +235,7 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, - rnn_lm_model: nn.Module, + rnn_lm_model: Optional[nn.Module], HLG: Optional[k2.Fsa], H: Optional[k2.Fsa], bpe_model: Optional[spm.SentencePieceProcessor], diff --git a/egs/ptb/LM/local/prepare_lm_training_data.py b/egs/ptb/LM/local/prepare_lm_training_data.py deleted file mode 100755 index bc7555209..000000000 --- a/egs/ptb/LM/local/prepare_lm_training_data.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey -# Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script takes a `bpe.model` and a text file such as -`download/ptb.train.txt`, -and outputs the LM training data to a supplied directory such -as data/bpe_500. The format is as follows: - -It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a -representation of a dict with the following format: - - 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 - containing the BPE representations of each word, indexed by - integer word ID. (These integer word IDS are present in - 'lm_data'). The sentencepiece object can be used to turn the - words and BPE units into string form. - 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype - torch.int32 containing all the sentences, as word-ids (we don't - output the string form of this directly but it can be worked out - together with 'words' and the bpe.model). - 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing - number of BPE tokens of each sentence. -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import sentencepiece as spm -import torch - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--bpe-model", - type=str, - help="Input BPE model, e.g. data/bpe_500/bpe.model", - ) - parser.add_argument( - "--lm-data", - type=str, - help="""Input LM training data as text, e.g. - download/pb.train.txt""", - ) - parser.add_argument( - "--lm-archive", - type=str, - help="""Path to output archive, e.g. data/bpe_500/lm_data.pt; - look at the source of this script to see the format.""", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - if Path(args.lm_archive).exists(): - logging.warning(f"{args.lm_archive} exists - skipping") - return - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - # word2index is a dictionary from words to integer ids. No need to reserve - # space for epsilon, etc.; the words are just used as a convenient way to - # compress the sequences of BPE pieces. - word2index = dict() - - word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. - - # ptb.train.txt has already converted oov words to - word2bpe.append([sp.unk_id()]) - word2index[""] = 0 - - sentences = [] # Will be a list-of-list-of-int, representing word-ids. - - with open(args.lm_data) as f: - while True: - line = f.readline() - if line == "": - break - line_words = line.split() - for w in line_words: - if w not in word2index: - w_bpe = sp.encode(w) - word2index[w] = len(word2bpe) - word2bpe.append(w_bpe) - sentences.append([word2index[w] for w in line_words]) - - words = k2.ragged.RaggedTensor(word2bpe) - sentences = k2.ragged.RaggedTensor(sentences) - - output = dict(words=words, sentences=sentences) - - num_sentences = sentences.dim0 - sentence_lengths = [0] * num_sentences - for i in range(num_sentences): - word_ids = sentences[i] - - # NOTE: If word_ids is a tensor with only 1 entry, - # token_ids is a torch.Tensor - token_ids = words[word_ids] - if isinstance(token_ids, k2.RaggedTensor): - token_ids = token_ids.values - - # token_ids is a 1-D tensor containing the BPE tokens - # of the current sentence - - sentence_lengths[i] = token_ids.numel() - - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) - - torch.save(output, args.lm_archive) - logging.info(f"Saved to {args.lm_archive}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ptb/LM/local/prepare_lm_training_data.py b/egs/ptb/LM/local/prepare_lm_training_data.py new file mode 120000 index 000000000..e2afc5240 --- /dev/null +++ b/egs/ptb/LM/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +/Users/ezerhoun/repos/open_source/icefall/egs/librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/ptb/LM/local/train_bpe_model.py b/egs/ptb/LM/local/train_bpe_model.py deleted file mode 100755 index 8d87707a9..000000000 --- a/egs/ptb/LM/local/train_bpe_model.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--out-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - vocab_size = args.vocab_size - model_type = "unigram" - - model_prefix = f"{args.out_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - - shutil.copyfile(model_file, f"{args.out_dir}/bpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/ptb/LM/local/train_bpe_model.py b/egs/ptb/LM/local/train_bpe_model.py new file mode 120000 index 000000000..d9afbc014 --- /dev/null +++ b/egs/ptb/LM/local/train_bpe_model.py @@ -0,0 +1 @@ +/Users/ezerhoun/repos/open_source/icefall/egs/librispeech/ASR/local/train_bpe_model.py \ No newline at end of file