This commit is contained in:
Piotr Żelasko 2022-01-16 00:03:17 +00:00
parent 186f5f1ba4
commit 2027b55233
4 changed files with 9 additions and 6 deletions

View File

@ -717,7 +717,7 @@ def scan_pessimistic_batches_for_oom(
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) AsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir) args.lang_dir = Path(args.lang_dir)

View File

@ -152,7 +152,7 @@ def generate_lexicon(
lexicon.append((word, pieces)) lexicon.append((word, pieces))
# The OOV word is <UNK> # The OOV word is <UNK>
lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())])) lexicon.append(("[UNK]", [sp.id_to_piece(sp.unk_id())]))
token2id: Dict[str, int] = dict() token2id: Dict[str, int] = dict()
for i in range(sp.vocab_size()): for i in range(sp.vocab_size()):
@ -197,7 +197,7 @@ def main():
words = word_sym_table.symbols words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"] excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "[UNK]", "#0", "<s>", "</s>"]
for w in excluded: for w in excluded:
if w in words: if w in words:
words.remove(w) words.remove(w)

View File

@ -103,6 +103,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# to data/musan # to data/musan
mkdir -p data/manifests mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests lhotse prepare musan $dl_dir/musan data/manifests
lhotse combine data/manifests/recordings_{music,speech,noise}.json data/manifests/recordings_musan.jsonl.gz
lhotse cut simple -r data/manifests/recordings_musan.jsonl.gz data/manifests/musan_cuts.jsonl.gz
fi fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
@ -194,11 +196,11 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
>> $lang_dir/words.txt >> $lang_dir/words.txt
# Add remaining special word symbols expected by LM scripts. # Add remaining special word symbols expected by LM scripts.
num_words=$(wc -l $lang_dir/words.txt) num_words=$(cat $lang_dir/words.txt | wc -l)
echo "<s> ${num_words}" >> $lang_dir/words.txt echo "<s> ${num_words}" >> $lang_dir/words.txt
num_words=$(wc -l $lang_dir/words.txt) num_words=$(cat $lang_dir/words.txt | wc -l)
echo "</s> ${num_words}" >> $lang_dir/words.txt echo "</s> ${num_words}" >> $lang_dir/words.txt
num_words=$(wc -l $lang_dir/words.txt) num_words=$(cat $lang_dir/words.txt | wc -l)
echo "#0 ${num_words}" >> $lang_dir/words.txt echo "#0 ${num_words}" >> $lang_dir/words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then if [ ! -f $lang_dir/L_disambig.pt ]; then

View File

@ -167,6 +167,7 @@ class AsrDataModule:
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
drop_last=True, drop_last=True,
) )
train_sampler.filter(lambda cut: 1.0 <= cut.duration <= 15.0)
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
train_dl = DataLoader( train_dl = DataLoader(