diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
index 9850cf251..fb2751c0f 100755
--- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
 
 
 def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
-    src_dir = Path("data/manifests")
+    src_dir = Path("data/manifests/aidatatang_200zh")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
 
@@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 3da783006..039951354 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -50,28 +50,19 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process aidatatang_200zh"
-  if [ ! -f data/fbank/aidatatang_200zh/.fbank.done ]; then
-    mkdir -p data/fbank/aidatatang_200zh
-    lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
-    touch data/fbank/aidatatang_200zh/.fbank.done
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  if [ ! -f data/manifests/.manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests/
+    lhotse prepare musan $dl_dir/musan data/manifests/
+    touch data/manifests/.manifests.done
   fi
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare musan manifest"
-  # We assume that you have downloaded the musan corpus
-  # to data/musan
-  if [ ! -f data/manifests/.musan_manifests.done ]; then
-    log "It may take 6 minutes"
-    mkdir -p data/manifests
-    lhotse prepare musan $dl_dir/musan data/manifests
-    touch data/manifests/.musan_manifests.done
-  fi
-fi
-
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Compute fbank for musan"
+  log "Stage 3: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
     mkdir -p data/fbank
     ./local/compute_fbank_musan.py
@@ -79,8 +70,8 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   fi
 fi
 
-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for aidatatang_200zh"
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for aidatatang_200zh"
   if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
     mkdir -p data/fbank
     ./local/compute_fbank_aidatatang_200zh.py
@@ -88,31 +79,38 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   fi
 fi
 
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir
 
-  # Prepare text.
-  grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
-    | sed -e 's/["text:\t ]*//g' | sed 's/,//g' \
-    | ./local/text2token.py -t "char" > $lang_char_dir/text
-
+  # Note: in Linux, you can install jq with the following commands:
+  # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+  # 2. chmod +x ./jq
+  # 3. cp jq /usr/bin
+  if [ ! -f $lang_char_dir/text ]; then
+    gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
+      |jq '.text' |sed -e 's/["text:\t ]*//g' | sed 's/"//g' \
+      | ./local/text2token.py -t "char" > $lang_char_dir/text
+  fi
   # Prepare words.txt
-  grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
-    | sed -e 's/["text:\t]*//g' | sed 's/,//g' \
-    | ./local/text2token.py -t "char" > $lang_char_dir/text_words
+  if [ ! -f $lang_char_dir/text_words ]; then
+    gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
+      | jq '.text' | sed -e 's/["text:\t]*//g' | sed 's/"//g' \
+      | ./local/text2token.py -t "char" > $lang_char_dir/text_words
+  fi
 
   cat $lang_char_dir/text_words | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
     | uniq > $lang_char_dir/words_no_ids.txt
 
   if [ ! -f $lang_char_dir/words.txt ]; then
     ./local/prepare_words.py \
-      --input-file $lang_char_dir/words_no_ids.txt
-      --output-file $lang_char_dir/words.txt
+      --input-file $lang_char_dir/words_no_ids.txt \
+      --output-file $lang_char_dir/words.txt
   fi
 
   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
     ./local/prepare_char.py
   fi
 fi
+
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index a185567da..f0407f429 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -522,63 +522,14 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
 
-    dev = "dev"
-    test = "test"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = aidatatang_200zh.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test}/shared-0.tar"):
-        os.makedirs(test)
-        test_cuts = aidatatang_200zh.test_cuts()
-        export_to_webdataset(
-            test_cuts,
-            output_path=f"{test}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
-    ]
-    cuts_test_webdataset = CutSet.from_webdataset(
-        test_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
-    test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
+    dev_cuts = aidatatang_200zh.valid_cuts()
+    test_cuts = aidatatang_200zh.test_cuts()
+    dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
+    test_dl = aidatatang_200zh.test_dataloaders(test_cuts)
 
     test_sets = ["dev", "test"]
     test_dl = [dev_dl, test_dl]
diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
index 8cdfad71f..42700a972 100755
--- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index e27e35ec5..deab6c809 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -62,6 +62,13 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
index 7bc969a1a..d8d3622bd 100755
--- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -62,6 +62,13 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
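A note on the assertion this patch adds to every compute_fbank_*.py and preprocess script: lhotse's read_manifests_if_cached() returns a dict keyed by dataset part, and a part whose manifest files are not found on disk is simply absent from that dict, so a forgotten preparation stage used to pass silently and features were computed for only part of the data. The new check fails immediately and reports what was loaded versus what was expected. A minimal sketch of the guarded call follows; the part names and manifest directory here are hypothetical, for illustration only, not taken from the recipes above.

    from lhotse.recipes.utils import read_manifests_if_cached

    dataset_parts = ("train", "dev", "test")  # hypothetical part names
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir="data/manifests",  # hypothetical manifest directory
    )
    assert manifests is not None
    # If, say, the "test" manifests were never generated, `manifests` holds
    # only two entries; the tuple below ends up in the AssertionError message
    # and shows the loaded keys next to the expected parts.
    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )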
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index 09f885636..3f50d9e3e 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -63,6 +63,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
index 2ff473c60..af926aa53 100755
--- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
 
 
 def compute_fbank_alimeeting(num_mel_bins: int = 80):
-    src_dir = Path("data/manifests")
+    src_dir = Path("data/manifests/alimeeting")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
 
@@ -63,6 +63,13 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py
index 3df727c67..7c1019aa8 100644
--- a/egs/alimeeting/ASR/local/text2segments.py
+++ b/egs/alimeeting/ASR/local/text2segments.py
@@ -30,9 +30,11 @@ with word segmenting:
 
 import argparse
 
+import paddle
 import jieba
 from tqdm import tqdm
 
+paddle.enable_static()
 jieba.enable_paddle()
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
index eb2ac697d..17224bb68 100755
--- a/egs/alimeeting/ASR/prepare.sh
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -107,7 +107,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   # Prepare text.
   # Note: in Linux, you can install jq with the following command:
   # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
-  gunzip -c data/manifests/alimeeting/supervisions_train.jsonl.gz \
+  gunzip -c data/manifests/alimeeting/alimeeting_supervisions_train.jsonl.gz \
     | jq ".text" | sed 's/"//g' \
     | ./local/text2token.py -t "char" > $lang_char_dir/text
diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
index 0cec82ad5..48d10a157 100755
--- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
+++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
@@ -62,6 +62,13 @@ def preprocess_giga_speech():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh
index 9c47e8eae..2a69d3921 100755
--- a/egs/librispeech/ASR/distillation_with_hubert.sh
+++ b/egs/librispeech/ASR/distillation_with_hubert.sh
@@ -81,9 +81,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
   # or
   # pip install multi_quantization
-  has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
+  has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('multi_quantization') is not None)")
 
   if [ $has_quantization == 'False' ]; then
-    log "Please install quantization before running following stages"
+    log "Please install multi_quantization before running following stages"
     exit 1
   fi
diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py
index 642d9fd32..f3e15e039 100755
--- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py
+++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py
@@ -66,6 +66,13 @@ def compute_fbank_librispeech():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index fef372129..056da29e5 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -65,6 +65,8 @@ def compute_fbank_musan():
     assert len(manifests) == len(dataset_parts), (
         len(manifests),
         len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
     )
 
     musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py
index 0f4ae820b..077f23039 100644
--- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py
+++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py
@@ -68,6 +68,13 @@ def preprocess_giga_speech():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
index 432bf8220..041a81f45 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
@@ -164,6 +164,10 @@ class Eve(Optimizer):
                     p.mul_(1 - (weight_decay * is_above_target_rms))
                 p.addcdiv_(exp_avg, denom, value=-step_size)
 
+                # Constrain the range of scalar weights
+                if p.numel() == 1:
+                    p.clamp_(min=-10, max=2)
+
         return loss
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index fb3db282a..a4687f35d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -652,13 +652,13 @@ def main():
 
         # Also export encoder/decoder/joiner separately
         encoder_filename = params.exp_dir / "encoder_jit_script.pt"
-        export_encoder_model_jit_trace(model.encoder, encoder_filename)
+        export_encoder_model_jit_script(model.encoder, encoder_filename)
 
         decoder_filename = params.exp_dir / "decoder_jit_script.pt"
-        export_decoder_model_jit_trace(model.decoder, decoder_filename)
+        export_decoder_model_jit_script(model.decoder, decoder_filename)
 
         joiner_filename = params.exp_dir / "joiner_jit_script.pt"
-        export_joiner_model_jit_trace(model.joiner, joiner_filename)
+        export_joiner_model_jit_script(model.joiner, joiner_filename)
 
     elif params.jit_trace is True:
         convert_scaled_to_non_scaled(model, inplace=True)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py
index a9feea83c..2e131158f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py
@@ -181,7 +181,7 @@ def test_convert_scaled_to_non_scaled():
     y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
 
     d1 = model.decoder(y)
-    d2 = model.decoder(y)
+    d2 = converted_model.decoder(y)
 
     assert torch.allclose(d1, d2)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py
index 10b0e5edc..49b557814 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py
@@ -81,18 +81,17 @@ def decode_dataset(
     results = defaultdict(list)
 
     for batch_idx, batch in enumerate(dl):
-
+        # hyps is a list, every element is decode result of a sentence.
         hyps = hubert_model.ctc_greedy_search(batch)
 
         texts = batch["supervisions"]["text"]
-        assert len(hyps) == len(texts)
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         this_batch = []
-
-        for hyp_text, ref_text in zip(hyps, texts):
+        assert len(hyps) == len(texts)
+        for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
             hyp_words = hyp_text.split()
-            this_batch.append((ref_words, hyp_words))
-
+            this_batch.append((cut_id, ref_words, hyp_words))
         results["ctc_greedy_search"].extend(this_batch)
 
         num_cuts += len(texts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
index e3dcd039b..65895c920 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
@@ -28,7 +28,7 @@ from typing import List, Tuple
 import numpy as np
 import torch
 import torch.multiprocessing as mp
-import quantization
+import multi_quantization as quantization
 
 from asr_datamodule import LibriSpeechAsrDataModule
 from hubert_xlarge import HubertXlargeFineTuned
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 70372af2b..6cb8b65ae 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -69,6 +69,13 @@ def compute_fbank_musan():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"
 
     if musan_cuts_path.is_file():
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 367e098f7..4582609ac 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -62,6 +62,13 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index e324b5025..327962a79 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -62,6 +62,13 @@ def compute_fbank_tedlium():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index 094769c8c..f25786a0c 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -63,6 +63,13 @@ def compute_fbank_timit():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 64733eb15..817969c47 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -23,6 +23,8 @@ from pathlib import Path
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
 
+from icefall import setup_logger
+
 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
 
@@ -48,13 +50,17 @@ def preprocess_wenet_speech():
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
 
+    # Note: By default, we preprocess all sub-parts.
+    # You can delete those that you don't need.
+    # For instance, if you don't want to use the L subpart, just remove
+    # the line below containing "L"
     dataset_parts = (
-        "L",
-        "M",
-        "S",
         "DEV",
         "TEST_NET",
         "TEST_MEETING",
+        "S",
+        "M",
+        "L",
     )
 
     logging.info("Loading manifest (may take 10 minutes)")
@@ -66,6 +72,13 @@ def preprocess_wenet_speech():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
@@ -81,10 +94,13 @@ def preprocess_wenet_speech():
             logging.info(f"Normalizing text in {partition}")
             for sup in m["supervisions"]:
                 text = str(sup.text)
-                logging.info(f"Original text: {text}")
+                orig_text = text
                 sup.text = normalize_text(sup.text)
                 text = str(sup.text)
-                logging.info(f"Normalize text: {text}")
+                if len(orig_text) != len(text):
+                    logging.info(
+                        f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
+                    )
 
         # Create long-recording cut manifests.
         logging.info(f"Processing {partition}")
@@ -109,12 +125,10 @@ def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-    logging.basicConfig(format=formatter, level=logging.INFO)
+    setup_logger(log_filename="./log-preprocess-wenetspeech")
 
     preprocess_wenet_speech()
+    logging.info("Done")
 
 
 if __name__ == "__main__":
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 5208dbefe..d3cc7c9c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -81,7 +81,6 @@ For training with the S subset:
 
 import argparse
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -120,8 +119,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]
 
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -162,7 +159,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
         If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
         """,
     )
 
@@ -361,8 +358,8 @@ def get_params() -> AttributeDict:
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
-            "batch_idx_train": 10,
-            "log_interval": 1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
             "reset_interval": 200,
             # parameters for conformer
             "feature_dim": 80,
@@ -545,7 +542,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.
     Args:
       params:
         Parameters for training. See :func:`get_params`.
@@ -573,7 +570,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
@@ -697,7 +694,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
 
         batch_size = len(batch["supervisions"]["text"])
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 5a5925d55..2052e9da7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -61,7 +61,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -103,8 +102,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]
 
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
 
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
@@ -684,7 +681,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index fb48b6f8e..9a4e8a36f 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -47,6 +47,13 @@ def compute_fbank_yesno():
     )
     assert manifests is not None
 
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
     extractor = Fbank(
         FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
     )
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 2e6087ad5..609e25626 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -130,6 +130,8 @@ class TensorDiagnostic(object):
             x = x[0]
         if not isinstance(x, Tensor):
             return
+        if x.numel() == 0:  # for empty tensor
+            return
         x = x.detach().clone()
         if x.ndim == 0:
             x = x.unsqueeze(0)
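On the final hunk in icefall/diagnostics.py: PyTorch reductions such as min() and max() raise a RuntimeError on tensors with zero elements, so collecting statistics for an empty activation would abort a diagnostics run; returning early sidesteps that. A small illustration of the behaviour being avoided (plain PyTorch, not icefall code):

    import torch

    x = torch.empty(0)
    print(x.numel())  # prints 0
    try:
        x.min()  # reducing an empty tensor raises
    except RuntimeError as err:
        print("caught:", err)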