Merge branch 'k2-fsa:master' into master

Zengwei Yao 2022-08-29 15:18:11 +08:00 committed by GitHub
commit 077719c9ab
29 changed files with 186 additions and 123 deletions

View File

@@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
 def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
-    src_dir = Path("data/manifests")
+    src_dir = Path("data/manifests/aidatatang_200zh")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())

@@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
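
The same length check is added to more than a dozen compute_fbank_*.py and preprocess scripts in this commit. A minimal sketch of what it buys (the part names and manifests dict below are hypothetical, not from the patch): read_manifests_if_cached() returns only the parts it actually finds on disk, so the old `assert manifests is not None` passes even when some parts are missing, while the new assertion fails loudly and shows exactly which parts were found.

# Hypothetical illustration of the new guard; not code from this repository.
dataset_parts = ("train", "dev", "test")          # assumed part names
manifests = {"train": object(), "dev": object()}  # "test" manifests were never prepared

assert manifests is not None  # old check: passes despite the missing part
assert len(manifests) == len(dataset_parts), (
    len(manifests),
    len(dataset_parts),
    list(manifests.keys()),
    dataset_parts,
)
# AssertionError: (2, 3, ['train', 'dev'], ('train', 'dev', 'test'))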

View File

@@ -50,28 +50,19 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process aidatatang_200zh"
-  if [ ! -f data/fbank/aidatatang_200zh/.fbank.done ]; then
-    mkdir -p data/fbank/aidatatang_200zh
-    lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
-    touch data/fbank/aidatatang_200zh/.fbank.done
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  if [ ! -f data/manifests/.manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests/
+    lhotse prepare musan $dl_dir/musan data/manifests/
+    touch data/manifests/.manifests.done
   fi
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare musan manifest"
-  # We assume that you have downloaded the musan corpus
-  # to data/musan
-  if [ ! -f data/manifests/.musan_manifests.done ]; then
-    log "It may take 6 minutes"
-    mkdir -p data/manifests
-    lhotse prepare musan $dl_dir/musan data/manifests
-    touch data/manifests/.musan_manifests.done
-  fi
-fi
-
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Compute fbank for musan"
+  log "Stage 3: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
     mkdir -p data/fbank
     ./local/compute_fbank_musan.py

@@ -79,8 +70,8 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   fi
 fi

-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for aidatatang_200zh"
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for aidatatang_200zh"
   if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
     mkdir -p data/fbank
     ./local/compute_fbank_aidatatang_200zh.py

@@ -88,31 +79,38 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   fi
 fi

-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

   # Prepare text.
-  grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
-    | sed -e 's/["text:\t ]*//g' | sed 's/,//g' \
-    | ./local/text2token.py -t "char" > $lang_char_dir/text
+  # Note: in Linux, you can install jq with the following command:
+  # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+  # 2. chmod +x ./jq
+  # 3. cp jq /usr/bin
+  if [ ! -f $lang_char_dir/text ]; then
+    gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
+      |jq '.text' |sed -e 's/["text:\t ]*//g' | sed 's/"//g' \
+      | ./local/text2token.py -t "char" > $lang_char_dir/text
+  fi

   # Prepare words.txt
-  grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
-    | sed -e 's/["text:\t]*//g' | sed 's/,//g' \
-    | ./local/text2token.py -t "char" > $lang_char_dir/text_words
+  if [ ! -f $lang_char_dir/text_words ]; then
+    gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
+      | jq '.text' | sed -e 's/["text:\t]*//g' | sed 's/"//g' \
+      | ./local/text2token.py -t "char" > $lang_char_dir/text_words
+  fi

   cat $lang_char_dir/text_words | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
     | uniq > $lang_char_dir/words_no_ids.txt

   if [ ! -f $lang_char_dir/words.txt ]; then
     ./local/prepare_words.py \
-      --input-file $lang_char_dir/words_no_ids.txt
+      --input-file $lang_char_dir/words_no_ids.txt \
       --output-file $lang_char_dir/words.txt
   fi

   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
     ./local/prepare_char.py
   fi
 fi
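
For readers without jq: the new pipeline simply pulls the `text` field out of each line of the gzipped JSONL supervision manifest before handing it to text2token.py. A rough Python equivalent (illustrative only; it assumes the one-JSON-object-per-line layout that lhotse uses for supervisions):

import gzip
import json

path = "data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz"
with gzip.open(path, "rt", encoding="utf-8") as f:
    for line in f:
        sup = json.loads(line)
        print(sup["text"])  # one transcript per line (the script pipes these into ./local/text2token.py)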

View File

@@ -522,63 +522,14 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)

-    dev = "dev"
-    test = "test"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = aidatatang_200zh.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test}/shared-0.tar"):
-        os.makedirs(test)
-        test_cuts = aidatatang_200zh.test_cuts()
-        export_to_webdataset(
-            test_cuts,
-            output_path=f"{test}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
-    ]
-    cuts_test_webdataset = CutSet.from_webdataset(
-        test_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
-    test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
+    dev_cuts = aidatatang_200zh.valid_cuts()
+    test_cuts = aidatatang_200zh.test_cuts()
+    dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
+    test_dl = aidatatang_200zh.test_dataloaders(test_cuts)

     test_sets = ["dev", "test"]
     test_dl = [dev_dl, test_dl]

View File

@@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -62,6 +62,13 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -62,6 +62,13 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -63,6 +63,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
 def compute_fbank_alimeeting(num_mel_bins: int = 80):
-    src_dir = Path("data/manifests")
+    src_dir = Path("data/manifests/alimeeting")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())

@@ -63,6 +63,13 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -30,9 +30,11 @@ with word segmenting:
 import argparse

+import paddle
 import jieba
 from tqdm import tqdm

+paddle.enable_static()
 jieba.enable_paddle()

View File

@@ -107,7 +107,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   # Prepare text.
   # Note: in Linux, you can install jq with the following command:
   # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
-  gunzip -c data/manifests/alimeeting/supervisions_train.jsonl.gz \
+  gunzip -c data/manifests/alimeeting/alimeeting_supervisions_train.jsonl.gz \
     | jq ".text" | sed 's/"//g' \
     | ./local/text2token.py -t "char" > $lang_char_dir/text

View File

@@ -62,6 +62,13 @@ def preprocess_giga_speech():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"

View File

@@ -81,9 +81,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
   # or
   # pip install multi_quantization

-  has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
+  has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('multi_quantization') is not None)")
   if [ $has_quantization == 'False' ]; then
-    log "Please install quantization before running following stages"
+    log "Please install multi_quantization before running following stages"
     exit 1
   fi
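
The probe itself is plain standard library; the fix only changes which module name is looked up, matching the `import multi_quantization as quantization` change further down. A sketch of what the python3 -c one-liner does:

import importlib.util

# find_spec() returns None when the module cannot be located, so this mirrors the
# has_quantization=$(python3 -c ...) check in the script above.
print(importlib.util.find_spec("multi_quantization") is not None)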

View File

@@ -66,6 +66,13 @@ def compute_fbank_librispeech():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -65,6 +65,8 @@ def compute_fbank_musan():
     assert len(manifests) == len(dataset_parts), (
         len(manifests),
         len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
     )

     musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"

View File

@@ -68,6 +68,13 @@ def preprocess_giga_speech():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"

View File

@@ -164,6 +164,10 @@ class Eve(Optimizer):
                     p.mul_(1 - (weight_decay * is_above_target_rms))
                 p.addcdiv_(exp_avg, denom, value=-step_size)

+                # Constrain the range of scalar weights
+                if p.numel() == 1:
+                    p.clamp_(min=-10, max=2)
+
         return loss
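
The added lines only touch single-element ("scalar") parameters, presumably the learned scale factors used in these scaled models, pulling them back into [-10, 2] after each update. A standalone sketch of the idea with a made-up parameter list (not the Eve optimizer itself):

import torch

params = [torch.randn(256, 256), torch.tensor(5.0)]  # hypothetical matrix weight + scalar
for p in params:
    if p.numel() == 1:  # only scalar parameters are constrained
        p.clamp_(min=-10, max=2)

print(params[1])  # tensor(2.) -- the out-of-range scalar was clamped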

View File

@@ -652,13 +652,13 @@ def main():
         # Also export encoder/decoder/joiner separately
         encoder_filename = params.exp_dir / "encoder_jit_script.pt"
-        export_encoder_model_jit_trace(model.encoder, encoder_filename)
+        export_encoder_model_jit_script(model.encoder, encoder_filename)

         decoder_filename = params.exp_dir / "decoder_jit_script.pt"
-        export_decoder_model_jit_trace(model.decoder, decoder_filename)
+        export_decoder_model_jit_script(model.decoder, decoder_filename)

         joiner_filename = params.exp_dir / "joiner_jit_script.pt"
-        export_joiner_model_jit_trace(model.joiner, joiner_filename)
+        export_joiner_model_jit_script(model.joiner, joiner_filename)

     elif params.jit_trace is True:
         convert_scaled_to_non_scaled(model, inplace=True)
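
The old code wrote files named *_jit_script.pt but produced them with the *_jit_trace exporters. For readers unfamiliar with the distinction, here is a toy contrast between torch.jit.script and torch.jit.trace (a generic example, not the icefall export helpers):

import torch
import torch.nn as nn


class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:  # data-dependent branch: kept by script(), frozen by trace()
            return x * 2
        return x


m = Toy()
torch.jit.script(m).save("toy_jit_script.pt")               # compiles the code itself
torch.jit.trace(m, torch.ones(3)).save("toy_jit_trace.pt")  # records one forward pass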

View File

@@ -181,7 +181,7 @@ def test_convert_scaled_to_non_scaled():
     y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))

     d1 = model.decoder(y)
-    d2 = model.decoder(y)
+    d2 = converted_model.decoder(y)

     assert torch.allclose(d1, d2)

View File

@@ -81,18 +81,17 @@ def decode_dataset(
     results = defaultdict(list)

     for batch_idx, batch in enumerate(dl):
+        # hyps is a list, every element is decode result of a sentence.
         hyps = hubert_model.ctc_greedy_search(batch)

         texts = batch["supervisions"]["text"]
-        assert len(hyps) == len(texts)
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         this_batch = []
+        assert len(hyps) == len(texts)

-        for hyp_text, ref_text in zip(hyps, texts):
+        for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
             hyp_words = hyp_text.split()
-            this_batch.append((ref_words, hyp_words))
+            this_batch.append((cut_id, ref_words, hyp_words))

         results["ctc_greedy_search"].extend(this_batch)

         num_cuts += len(texts)
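
With the cut ids threaded through, every scored row can be traced back to its utterance. A tiny sketch with made-up ids and transcripts:

cut_ids = ["cut-001", "cut-002"]       # hypothetical utterance ids
hyps = ["hello word", "good morning"]  # decoder output
texts = ["hello world", "good morning"]  # reference transcripts

this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
    this_batch.append((cut_id, ref_text.split(), hyp_text.split()))

print(this_batch[0])  # ('cut-001', ['hello', 'world'], ['hello', 'word'])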

View File

@@ -28,7 +28,7 @@ from typing import List, Tuple
 import numpy as np
 import torch
 import torch.multiprocessing as mp
-import quantization
+import multi_quantization as quantization

 from asr_datamodule import LibriSpeechAsrDataModule
 from hubert_xlarge import HubertXlargeFineTuned

View File

@@ -69,6 +69,13 @@ def compute_fbank_musan():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"

     if musan_cuts_path.is_file():

View File

@@ -62,6 +62,13 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -62,6 +62,13 @@ def compute_fbank_tedlium():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -63,6 +63,13 @@ def compute_fbank_timit():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -23,6 +23,8 @@ from pathlib import Path
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached

+from icefall import setup_logger
+
 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh

@@ -48,13 +50,17 @@ def preprocess_wenet_speech():
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)

+    # Note: By default, we preprocess all sub-parts.
+    # You can delete those that you don't need.
+    # For instance, if you don't want to use the L subpart, just remove
+    # the line below containing "L"
     dataset_parts = (
+        "L",
+        "M",
+        "S",
         "DEV",
         "TEST_NET",
         "TEST_MEETING",
-        "S",
-        "M",
-        "L",
     )

     logging.info("Loading manifest (may take 10 minutes)")

@@ -66,6 +72,13 @@ def preprocess_wenet_speech():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"

@@ -81,10 +94,13 @@ def preprocess_wenet_speech():
         logging.info(f"Normalizing text in {partition}")
         for sup in m["supervisions"]:
             text = str(sup.text)
-            logging.info(f"Original text: {text}")
+            orig_text = text
             sup.text = normalize_text(sup.text)
             text = str(sup.text)
-            logging.info(f"Normalize text: {text}")
+            if len(orig_text) != len(text):
+                logging.info(
+                    f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
+                )

         # Create long-recording cut manifests.
         logging.info(f"Processing {partition}")

@@ -109,12 +125,10 @@ def preprocess_wenet_speech():
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-    logging.basicConfig(format=formatter, level=logging.INFO)
+    setup_logger(log_filename="./log-preprocess-wenetspeech")

     preprocess_wenet_speech()
+    logging.info("Done")

 if __name__ == "__main__":

View File

@@ -81,7 +81,6 @@ For training with the S subset:
 import argparse
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile

@@ -120,8 +119,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]

-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-

 def get_parser():
     parser = argparse.ArgumentParser(

@@ -162,7 +159,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
         If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
         """,
     )

@@ -361,8 +358,8 @@ def get_params() -> AttributeDict:
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
-            "batch_idx_train": 10,
-            "log_interval": 1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
             "reset_interval": 200,
             # parameters for conformer
             "feature_dim": 80,

@@ -545,7 +542,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.
     Args:
       params:
         Parameters for training. See :func:`get_params`.

@@ -573,7 +570,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)

@@ -697,7 +694,7 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

View File

@@ -61,7 +61,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile

@@ -103,8 +102,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]

-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-

 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(

@@ -684,7 +681,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
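
Both train.py files switch from `type(y) == list` to `isinstance(y, list)`. Besides being the idiomatic (and flake8-friendly) spelling, isinstance also accepts list subclasses; a small generic illustration:

class MyList(list):
    pass


y = MyList([[1, 2], [3]])
print(type(y) == list)      # False -- the old comparison would send this down the wrong branch
print(isinstance(y, list))  # True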

View File

@@ -47,6 +47,13 @@ def compute_fbank_yesno():
     )
     assert manifests is not None

+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(
         FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
     )

View File

@@ -130,6 +130,8 @@ class TensorDiagnostic(object):
                 x = x[0]
             if not isinstance(x, Tensor):
                 return
+            if x.numel() == 0:  # for empty tensor
+                return
             x = x.detach().clone()
             if x.ndim == 0:
                 x = x.unsqueeze(0)
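
The early return matters because reductions on an empty tensor raise at runtime, so the statistics gathered further down in this method would crash on, say, a zero-length dimension. A minimal reproduction (generic PyTorch, not the diagnostics module itself):

import torch

x = torch.empty(0)
print(x.numel())  # 0
try:
    x.max()       # any reduction on an empty tensor raises
except RuntimeError as err:
    print("RuntimeError:", err)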