mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge branch 'k2-fsa:master' into master
This commit is contained in:
commit
077719c9ab
@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
|
|||||||
|
|
||||||
|
|
||||||
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests/aidatatang_200zh")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
|
|
||||||
@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -50,28 +50,19 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
log "Stage 2: Process aidatatang_200zh"
|
log "Stage 2: Prepare musan manifest"
|
||||||
if [ ! -f data/fbank/aidatatang_200zh/.fbank.done ]; then
|
# We assume that you have downloaded the musan corpus
|
||||||
mkdir -p data/fbank/aidatatang_200zh
|
# to data/musan
|
||||||
lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
|
if [ ! -f data/manifests/.manifests.done ]; then
|
||||||
touch data/fbank/aidatatang_200zh/.fbank.done
|
log "It may take 6 minutes"
|
||||||
|
mkdir -p data/manifests/
|
||||||
|
lhotse prepare musan $dl_dir/musan data/manifests/
|
||||||
|
touch data/manifests/.manifests.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
log "Stage 3: Prepare musan manifest"
|
log "Stage 3: Compute fbank for musan"
|
||||||
# We assume that you have downloaded the musan corpus
|
|
||||||
# to data/musan
|
|
||||||
if [ ! -f data/manifests/.musan_manifests.done ]; then
|
|
||||||
log "It may take 6 minutes"
|
|
||||||
mkdir -p data/manifests
|
|
||||||
lhotse prepare musan $dl_dir/musan data/manifests
|
|
||||||
touch data/manifests/.musan_manifests.done
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|
||||||
log "Stage 4: Compute fbank for musan"
|
|
||||||
if [ ! -f data/fbank/.msuan.done ]; then
|
if [ ! -f data/fbank/.msuan.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_musan.py
|
./local/compute_fbank_musan.py
|
||||||
@ -79,8 +70,8 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
log "Stage 5: Compute fbank for aidatatang_200zh"
|
log "Stage 4: Compute fbank for aidatatang_200zh"
|
||||||
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aidatatang_200zh.py
|
./local/compute_fbank_aidatatang_200zh.py
|
||||||
@ -88,31 +79,38 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
log "Stage 6: Prepare char based lang"
|
log "Stage 5: Prepare char based lang"
|
||||||
lang_char_dir=data/lang_char
|
lang_char_dir=data/lang_char
|
||||||
mkdir -p $lang_char_dir
|
mkdir -p $lang_char_dir
|
||||||
|
|
||||||
# Prepare text.
|
# Prepare text.
|
||||||
grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
|
# Note: in Linux, you can install jq with the following command:
|
||||||
| sed -e 's/["text:\t ]*//g' | sed 's/,//g' \
|
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
||||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
# 2. chmod +x ./jq
|
||||||
|
# 3. cp jq /usr/bin
|
||||||
|
if [ ! -f $lang_char_dir/text ]; then
|
||||||
|
gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
|
||||||
|
|jq '.text' |sed -e 's/["text:\t ]*//g' | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||||
|
fi
|
||||||
# Prepare words.txt
|
# Prepare words.txt
|
||||||
grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
|
if [ ! -f $lang_char_dir/text_words ]; then
|
||||||
| sed -e 's/["text:\t]*//g' | sed 's/,//g' \
|
gunzip -c data/manifests/aidatatang_200zh/aidatatang_supervisions_train.jsonl.gz \
|
||||||
| ./local/text2token.py -t "char" > $lang_char_dir/text_words
|
| jq '.text' | sed -e 's/["text:\t]*//g' | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $lang_char_dir/text_words
|
||||||
|
fi
|
||||||
|
|
||||||
cat $lang_char_dir/text_words | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
|
cat $lang_char_dir/text_words | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
|
||||||
| uniq > $lang_char_dir/words_no_ids.txt
|
| uniq > $lang_char_dir/words_no_ids.txt
|
||||||
|
|
||||||
if [ ! -f $lang_char_dir/words.txt ]; then
|
if [ ! -f $lang_char_dir/words.txt ]; then
|
||||||
./local/prepare_words.py \
|
./local/prepare_words.py \
|
||||||
--input-file $lang_char_dir/words_no_ids.txt
|
--input-file $lang_char_dir/words_no_ids.txt \
|
||||||
--output-file $lang_char_dir/words.txt
|
--output-file $lang_char_dir/words.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
||||||
./local/prepare_char.py
|
./local/prepare_char.py
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -522,63 +522,14 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
# Note: Please use "pip install webdataset==0.1.103"
|
|
||||||
# for installing the webdataset.
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
|
|
||||||
from lhotse import CutSet
|
|
||||||
from lhotse.dataset.webdataset import export_to_webdataset
|
|
||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
|
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
|
||||||
|
|
||||||
dev = "dev"
|
dev_cuts = aidatatang_200zh.valid_cuts()
|
||||||
test = "test"
|
test_cuts = aidatatang_200zh.test_cuts()
|
||||||
|
dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
|
||||||
if not os.path.exists(f"{dev}/shared-0.tar"):
|
test_dl = aidatatang_200zh.test_dataloaders(test_cuts)
|
||||||
os.makedirs(dev)
|
|
||||||
dev_cuts = aidatatang_200zh.valid_cuts()
|
|
||||||
export_to_webdataset(
|
|
||||||
dev_cuts,
|
|
||||||
output_path=f"{dev}/shared-%d.tar",
|
|
||||||
shard_size=300,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not os.path.exists(f"{test}/shared-0.tar"):
|
|
||||||
os.makedirs(test)
|
|
||||||
test_cuts = aidatatang_200zh.test_cuts()
|
|
||||||
export_to_webdataset(
|
|
||||||
test_cuts,
|
|
||||||
output_path=f"{test}/shared-%d.tar",
|
|
||||||
shard_size=300,
|
|
||||||
)
|
|
||||||
|
|
||||||
dev_shards = [
|
|
||||||
str(path)
|
|
||||||
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
|
|
||||||
]
|
|
||||||
cuts_dev_webdataset = CutSet.from_webdataset(
|
|
||||||
dev_shards,
|
|
||||||
split_by_worker=True,
|
|
||||||
split_by_node=True,
|
|
||||||
shuffle_shards=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_shards = [
|
|
||||||
str(path)
|
|
||||||
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
|
|
||||||
]
|
|
||||||
cuts_test_webdataset = CutSet.from_webdataset(
|
|
||||||
test_shards,
|
|
||||||
split_by_worker=True,
|
|
||||||
split_by_node=True,
|
|
||||||
shuffle_shards=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
|
|
||||||
test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
|
|
||||||
|
|
||||||
test_sets = ["dev", "test"]
|
test_sets = ["dev", "test"]
|
||||||
test_dl = [dev_dl, test_dl]
|
test_dl = [dev_dl, test_dl]
|
||||||
|
@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -62,6 +62,13 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -62,6 +62,13 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -63,6 +63,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
|
|||||||
|
|
||||||
|
|
||||||
def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests/alimeeting")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
|
|
||||||
@ -63,6 +63,13 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -30,9 +30,11 @@ with word segmenting:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import paddle
|
||||||
import jieba
|
import jieba
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
paddle.enable_static()
|
||||||
jieba.enable_paddle()
|
jieba.enable_paddle()
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
# Prepare text.
|
# Prepare text.
|
||||||
# Note: in Linux, you can install jq with the following command:
|
# Note: in Linux, you can install jq with the following command:
|
||||||
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
||||||
gunzip -c data/manifests/alimeeting/supervisions_train.jsonl.gz \
|
gunzip -c data/manifests/alimeeting/alimeeting_supervisions_train.jsonl.gz \
|
||||||
| jq ".text" | sed 's/"//g' \
|
| jq ".text" | sed 's/"//g' \
|
||||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||||
|
|
||||||
|
@ -62,6 +62,13 @@ def preprocess_giga_speech():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
logging.info(f"Processing {partition}")
|
logging.info(f"Processing {partition}")
|
||||||
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
|
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
|
||||||
|
@ -81,9 +81,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
|
|||||||
# or
|
# or
|
||||||
# pip install multi_quantization
|
# pip install multi_quantization
|
||||||
|
|
||||||
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
|
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('multi_quantization') is not None)")
|
||||||
if [ $has_quantization == 'False' ]; then
|
if [ $has_quantization == 'False' ]; then
|
||||||
log "Please install quantization before running following stages"
|
log "Please install multi_quantization before running following stages"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -66,6 +66,13 @@ def compute_fbank_librispeech():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -65,6 +65,8 @@ def compute_fbank_musan():
|
|||||||
assert len(manifests) == len(dataset_parts), (
|
assert len(manifests) == len(dataset_parts), (
|
||||||
len(manifests),
|
len(manifests),
|
||||||
len(dataset_parts),
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
|
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
|
||||||
|
@ -68,6 +68,13 @@ def preprocess_giga_speech():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
logging.info(f"Processing {partition}")
|
logging.info(f"Processing {partition}")
|
||||||
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
|
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
|
||||||
|
@ -164,6 +164,10 @@ class Eve(Optimizer):
|
|||||||
p.mul_(1 - (weight_decay * is_above_target_rms))
|
p.mul_(1 - (weight_decay * is_above_target_rms))
|
||||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||||
|
|
||||||
|
# Constrain the range of scalar weights
|
||||||
|
if p.numel() == 1:
|
||||||
|
p.clamp_(min=-10, max=2)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -652,13 +652,13 @@ def main():
|
|||||||
|
|
||||||
# Also export encoder/decoder/joiner separately
|
# Also export encoder/decoder/joiner separately
|
||||||
encoder_filename = params.exp_dir / "encoder_jit_script.pt"
|
encoder_filename = params.exp_dir / "encoder_jit_script.pt"
|
||||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
export_encoder_model_jit_script(model.encoder, encoder_filename)
|
||||||
|
|
||||||
decoder_filename = params.exp_dir / "decoder_jit_script.pt"
|
decoder_filename = params.exp_dir / "decoder_jit_script.pt"
|
||||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
export_decoder_model_jit_script(model.decoder, decoder_filename)
|
||||||
|
|
||||||
joiner_filename = params.exp_dir / "joiner_jit_script.pt"
|
joiner_filename = params.exp_dir / "joiner_jit_script.pt"
|
||||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
export_joiner_model_jit_script(model.joiner, joiner_filename)
|
||||||
|
|
||||||
elif params.jit_trace is True:
|
elif params.jit_trace is True:
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
@ -181,7 +181,7 @@ def test_convert_scaled_to_non_scaled():
|
|||||||
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
|
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
|
||||||
|
|
||||||
d1 = model.decoder(y)
|
d1 = model.decoder(y)
|
||||||
d2 = model.decoder(y)
|
d2 = converted_model.decoder(y)
|
||||||
|
|
||||||
assert torch.allclose(d1, d2)
|
assert torch.allclose(d1, d2)
|
||||||
|
|
||||||
|
@ -81,18 +81,17 @@ def decode_dataset(
|
|||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
# hyps is a list, every element is decode result of a sentence.
|
||||||
hyps = hubert_model.ctc_greedy_search(batch)
|
hyps = hubert_model.ctc_greedy_search(batch)
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
assert len(hyps) == len(texts)
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
this_batch = []
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
for hyp_text, ref_text in zip(hyps, texts):
|
for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
hyp_words = hyp_text.split()
|
hyp_words = hyp_text.split()
|
||||||
this_batch.append((ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
results["ctc_greedy_search"].extend(this_batch)
|
results["ctc_greedy_search"].extend(this_batch)
|
||||||
|
|
||||||
num_cuts += len(texts)
|
num_cuts += len(texts)
|
||||||
|
@ -28,7 +28,7 @@ from typing import List, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import quantization
|
import multi_quantization as quantization
|
||||||
|
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from hubert_xlarge import HubertXlargeFineTuned
|
from hubert_xlarge import HubertXlargeFineTuned
|
||||||
|
@ -69,6 +69,13 @@ def compute_fbank_musan():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"
|
musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"
|
||||||
|
|
||||||
if musan_cuts_path.is_file():
|
if musan_cuts_path.is_file():
|
||||||
|
@ -62,6 +62,13 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -62,6 +62,13 @@ def compute_fbank_tedlium():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -63,6 +63,13 @@ def compute_fbank_timit():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
|
@ -23,6 +23,8 @@ from pathlib import Path
|
|||||||
from lhotse import CutSet, SupervisionSegment
|
from lhotse import CutSet, SupervisionSegment
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
|
from icefall import setup_logger
|
||||||
|
|
||||||
# Similar text filtering and normalization procedure as in:
|
# Similar text filtering and normalization procedure as in:
|
||||||
# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
|
# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
|
||||||
|
|
||||||
@ -48,13 +50,17 @@ def preprocess_wenet_speech():
|
|||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
output_dir.mkdir(exist_ok=True)
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Note: By default, we preprocess all sub-parts.
|
||||||
|
# You can delete those that you don't need.
|
||||||
|
# For instance, if you don't want to use the L subpart, just remove
|
||||||
|
# the line below containing "L"
|
||||||
dataset_parts = (
|
dataset_parts = (
|
||||||
"L",
|
|
||||||
"M",
|
|
||||||
"S",
|
|
||||||
"DEV",
|
"DEV",
|
||||||
"TEST_NET",
|
"TEST_NET",
|
||||||
"TEST_MEETING",
|
"TEST_MEETING",
|
||||||
|
"S",
|
||||||
|
"M",
|
||||||
|
"L",
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Loading manifest (may take 10 minutes)")
|
logging.info("Loading manifest (may take 10 minutes)")
|
||||||
@ -66,6 +72,13 @@ def preprocess_wenet_speech():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
logging.info(f"Processing {partition}")
|
logging.info(f"Processing {partition}")
|
||||||
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
|
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
|
||||||
@ -81,10 +94,13 @@ def preprocess_wenet_speech():
|
|||||||
logging.info(f"Normalizing text in {partition}")
|
logging.info(f"Normalizing text in {partition}")
|
||||||
for sup in m["supervisions"]:
|
for sup in m["supervisions"]:
|
||||||
text = str(sup.text)
|
text = str(sup.text)
|
||||||
logging.info(f"Original text: {text}")
|
orig_text = text
|
||||||
sup.text = normalize_text(sup.text)
|
sup.text = normalize_text(sup.text)
|
||||||
text = str(sup.text)
|
text = str(sup.text)
|
||||||
logging.info(f"Normalize text: {text}")
|
if len(orig_text) != len(text):
|
||||||
|
logging.info(
|
||||||
|
f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create long-recording cut manifests.
|
# Create long-recording cut manifests.
|
||||||
logging.info(f"Processing {partition}")
|
logging.info(f"Processing {partition}")
|
||||||
@ -109,12 +125,10 @@ def preprocess_wenet_speech():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
formatter = (
|
setup_logger(log_filename="./log-preprocess-wenetspeech")
|
||||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
)
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
|
|
||||||
preprocess_wenet_speech()
|
preprocess_wenet_speech()
|
||||||
|
logging.info("Done")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -81,7 +81,6 @@ For training with the S subset:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -120,8 +119,6 @@ LRSchedulerType = Union[
|
|||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||||
]
|
]
|
||||||
|
|
||||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -162,7 +159,7 @@ def get_parser():
|
|||||||
default=0,
|
default=0,
|
||||||
help="""Resume training from from this epoch.
|
help="""Resume training from from this epoch.
|
||||||
If it is positive, it will load checkpoint from
|
If it is positive, it will load checkpoint from
|
||||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -361,8 +358,8 @@ def get_params() -> AttributeDict:
|
|||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
"best_valid_epoch": -1,
|
"best_valid_epoch": -1,
|
||||||
"batch_idx_train": 10,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 1,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
@ -545,7 +542,7 @@ def compute_loss(
|
|||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
) -> Tuple[Tensor, MetricsTracker]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute RNN-T loss given the model and its inputs.
|
||||||
Args:
|
Args:
|
||||||
params:
|
params:
|
||||||
Parameters for training. See :func:`get_params`.
|
Parameters for training. See :func:`get_params`.
|
||||||
@ -573,7 +570,7 @@ def compute_loss(
|
|||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
|
||||||
y = graph_compiler.texts_to_ids(texts)
|
y = graph_compiler.texts_to_ids(texts)
|
||||||
if type(y) == list:
|
if isinstance(y, list):
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
else:
|
else:
|
||||||
y = y.to(device)
|
y = y.to(device)
|
||||||
@ -697,7 +694,6 @@ def train_one_epoch(
|
|||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
|
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
|
@ -61,7 +61,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -103,8 +102,6 @@ LRSchedulerType = Union[
|
|||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||||
]
|
]
|
||||||
|
|
||||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -684,7 +681,7 @@ def compute_loss(
|
|||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
|
||||||
y = graph_compiler.texts_to_ids(texts)
|
y = graph_compiler.texts_to_ids(texts)
|
||||||
if type(y) == list:
|
if isinstance(y, list):
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
else:
|
else:
|
||||||
y = y.to(device)
|
y = y.to(device)
|
||||||
|
@ -47,6 +47,13 @@ def compute_fbank_yesno():
|
|||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
|
assert len(manifests) == len(dataset_parts), (
|
||||||
|
len(manifests),
|
||||||
|
len(dataset_parts),
|
||||||
|
list(manifests.keys()),
|
||||||
|
dataset_parts,
|
||||||
|
)
|
||||||
|
|
||||||
extractor = Fbank(
|
extractor = Fbank(
|
||||||
FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
|
FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
|
||||||
)
|
)
|
||||||
|
@ -130,6 +130,8 @@ class TensorDiagnostic(object):
|
|||||||
x = x[0]
|
x = x[0]
|
||||||
if not isinstance(x, Tensor):
|
if not isinstance(x, Tensor):
|
||||||
return
|
return
|
||||||
|
if x.numel() == 0: # for empty tensor
|
||||||
|
return
|
||||||
x = x.detach().clone()
|
x = x.detach().clone()
|
||||||
if x.ndim == 0:
|
if x.ndim == 0:
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user