JinZr 2023-09-26 13:34:33 +08:00
parent 5047dd37ac
commit b8ea509067
3 changed files with 42 additions and 19 deletions


@@ -13,6 +13,7 @@ dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
vocab_sizes=(
500
+ 2000
)
@@ -76,25 +77,31 @@ fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare Byte BPE based lang"
mkdir -p data/fbank
- if [ ! -d ../../aishell2/ASR/data/lang_char ]; then
+ if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
exit 1
fi
- if [ ! -d ../../librispeech/ASR/data/lang_phone ]; then
+ if [ ! -d ../../librispeech/ASR/data/lang_phone ] && [ ! -d ./data/lang_phone ]; then
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5"
exit 1
fi
- if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ]; then
+ if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6"
exit 1
fi
cd data/
- ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
- ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
- ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+ if [ ! -d ./lang_char ]; then
+ ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
+ fi
+ if [ ! -d ./lang_phone ]; then
+ ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
+ fi
+ if [ ! -d ./lang_bpe_500 ]; then
+ ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+ fi
cd ../
for vocab_size in ${vocab_sizes[@]}; do
@@ -104,9 +111,11 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
> $lang_dir/text
- ./local/prepare_for_bpe_model.py \
- --lang-dir ./$lang_dir \
- --text $lang_dir/text
+ if [ ! -f $lang_dir/transcript_chars.txt ]; then
+ ./local/prepare_for_bpe_model.py \
+ --lang-dir ./$lang_dir \
+ --text $lang_dir/text
+ fi
if [ ! -f $lang_dir/text_words_segmentation ]; then
python3 ./local/text2segments.py \
@@ -115,6 +124,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
cat ./data/lang_bpe_500/transcript_words.txt \
>> $lang_dir/text_words_segmentation
+ cat ./data/lang_char/text \
+ >> $lang_dir/text
fi
cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
@@ -130,7 +142,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
./local/train_bbpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
- --transcript $lang_dir/transcript_chars.txt
+ --transcript $lang_dir/text
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
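
Note: with this change the byte BPE (BBPE) model is trained on $lang_dir/text instead of transcript_chars.txt, and prepare_for_bpe_model.py is only invoked when its output is missing. Below is a minimal Python sketch of the kind of byte-level preprocessing this stage relies on; it assumes icefall's byte_encode and tokenize_by_CJK_char helpers (both imported elsewhere in this commit), and the file paths are illustrative only, not the recipe's actual ones.

    # Sketch: byte-encode a mixed Chinese/English transcript before BBPE training.
    # Paths are placeholders; icefall must be importable.
    from icefall import byte_encode
    from icefall.utils import tokenize_by_CJK_char

    with open("data/lang_bbpe_2000/text", encoding="utf-8") as fin, open(
        "transcript_for_bbpe.txt", "w", encoding="utf-8"
    ) as fout:
        for line in fin:
            # Split CJK characters into separate tokens, keep Latin words whole,
            # then map the string to its byte-level representation.
            fout.write(byte_encode(tokenize_by_CJK_char(line.strip())) + "\n")

The BBPE vocabulary is then learned on this byte-level text, which lets a single model cover both Chinese and English.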


@@ -120,6 +120,7 @@ from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
+ from icefall import byte_encode, byte_decode
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@@ -193,14 +194,14 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_2000/bpe.model",
default="data/lang_bbpe_2000/bbpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_2000",
default="data/lang_bbpe_2000",
help="The lang dir containing word table and LG graph",
)


@@ -80,7 +80,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
- from icefall import diagnostics
+ from icefall import diagnostics, byte_encode
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@@ -96,6 +96,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ tokenize_by_CJK_char,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1177,18 +1178,18 @@ def run(rank, world_size, args):
train_cuts = multi_dataset.train_cuts()
def remove_short_and_long_utt(c: Cut):
- # Keep only utterances with duration between 1 second and 20 seconds
+ # Keep only utterances with duration between 1 second and 12 seconds
#
- # Caution: There is a reason to select 20.0 here. Please see
+ # Caution: There is a reason to select 12.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
- if c.duration < 1.0 or c.duration > 20.0:
- # logging.warning(
- # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
- # )
+ if c.duration < 1.0 or c.duration > 12.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
return False
# In pruned RNN-T, we require that T >= S
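
Note: the maximum kept duration drops from 20 s to 12 s, and the exclusion message is now actually logged instead of staying commented out. The comment above points to ../local/display_manifest_statistics.py for choosing this threshold; a quick way to eyeball the same statistics directly with Lhotse is sketched below (the manifest path is a placeholder, not the recipe's actual file name).

    # Sketch: print duration statistics for a cut set to pick a max-duration cutoff.
    from lhotse import load_manifest_lazy

    cuts = load_manifest_lazy("data/fbank/train_cuts.jsonl.gz")
    cuts.describe()  # prints cut counts, total duration, and duration percentiles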
@@ -1213,8 +1214,17 @@
return True
+ def tokenize_and_encode_text(c: Cut):
+ # Text normalize for each sample
+ text = c.supervisions[0].text
+ text = byte_encode(tokenize_by_CJK_char(text))
+ c.supervisions[0].text = text
+ return c
train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ train_cuts = train_cuts.map(tokenize_and_encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
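
Note on the new tokenize_and_encode_text map: it rewrites each cut's supervision text so that training sees byte-encoded, CJK-tokenized text matching the BBPE vocabulary; for lazily loaded manifests, the filter and map are typically applied on the fly as the sampler iterates the cuts. A small illustration of the transformation (with a made-up sentence) and its expected inverse:

    # Illustration: the per-cut text transform and its inverse. The sample string
    # is made up; byte_decode is expected to undo byte_encode for valid UTF-8 text.
    from icefall import byte_decode, byte_encode
    from icefall.utils import tokenize_by_CJK_char

    raw = "我们 train a zipformer"
    encoded = byte_encode(tokenize_by_CJK_char(raw))
    print(byte_decode(encoded) == tokenize_by_CJK_char(raw))  # expected: True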