JinZr 2023-09-26 13:34:33 +08:00
parent 5047dd37ac
commit b8ea509067
3 changed files with 42 additions and 19 deletions


@@ -13,6 +13,7 @@ dl_dir=$PWD/download
 
 . shared/parse_options.sh || exit 1
 
 vocab_sizes=(
+  500
   2000
 )
@@ -76,25 +77,31 @@ fi
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Prepare Byte BPE based lang"
   mkdir -p data/fbank
-  if [ ! -d ../../aishell2/ASR/data/lang_char ]; then
+  if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
     log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
     exit 1
   fi
 
-  if [ ! -d ../../librispeech/ASR/data/lang_phone ]; then
+  if [ ! -d ../../librispeech/ASR/data/lang_phone ] && [ ! -d ./data/lang_phone ]; then
     log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5"
     exit 1
   fi
 
-  if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ]; then
+  if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
     log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6"
     exit 1
   fi
 
   cd data/
-  ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
-  ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
-  ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+  if [ ! -d ./lang_char ]; then
+    ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
+  fi
+  if [ ! -d ./lang_phone ]; then
+    ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
+  fi
+  if [ ! -d ./lang_bpe_500 ]; then
+    ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+  fi
   cd ../
 
   for vocab_size in ${vocab_sizes[@]}; do
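Note: the [ ! -d ... ] guards added above make stage 4 safe to re-run. Re-invoking ln -svf against an existing symlink that resolves to a directory would otherwise create a new link inside the target (GNU ln -f only replaces such a link when -n or -T is also given), and the abort checks now pass whenever a local ./data/lang_* copy already exists, even if the sibling aishell2/librispeech recipes have not been prepared.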
@@ -104,9 +111,11 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
       > $lang_dir/text
 
-    ./local/prepare_for_bpe_model.py \
-      --lang-dir ./$lang_dir \
-      --text $lang_dir/text
+    if [ ! -f $lang_dir/transcript_chars.txt ]; then
+      ./local/prepare_for_bpe_model.py \
+        --lang-dir ./$lang_dir \
+        --text $lang_dir/text
+    fi
 
     if [ ! -f $lang_dir/text_words_segmentation ]; then
       python3 ./local/text2segments.py \
@@ -115,6 +124,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 
       cat ./data/lang_bpe_500/transcript_words.txt \
         >> $lang_dir/text_words_segmentation
+
+      cat ./data/lang_char/text \
+        >> $lang_dir/text
     fi
 
     cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
@@ -130,7 +142,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
       ./local/train_bbpe_model.py \
         --lang-dir $lang_dir \
         --vocab-size $vocab_size \
-        --transcript $lang_dir/transcript_chars.txt
+        --transcript $lang_dir/text
     fi
 
     if [ ! -f $lang_dir/L_disambig.pt ]; then
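Net effect of the stage-4 changes: prepare_for_bpe_model.py is skipped when its output transcript_chars.txt already exists, and train_bbpe_model.py now trains on the combined $lang_dir/text (the AISHELL-2 character text plus the LibriSpeech transcripts appended above) rather than on transcript_chars.txt.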


@@ -120,6 +120,7 @@ from lhotse.cut import Cut
 from multi_dataset import MultiDataset
 from train import add_model_arguments, get_model, get_params
 
+from icefall import byte_encode, byte_decode
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -193,14 +194,14 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_2000/bpe.model",
+        default="data/lang_bbpe_2000/bbpe.model",
         help="Path to the BPE model",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bpe_2000",
+        default="data/lang_bbpe_2000",
         help="The lang dir containing word table and LG graph",
     )
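For context, a minimal sketch of how a model at the new default path can be used, assuming the decode side mirrors the training-time preprocessing (byte_encode and byte_decode are the imports added above; sentencepiece is the standard library for BPE models; the sample text is illustrative):

    import sentencepiece as spm

    from icefall import byte_decode, byte_encode
    from icefall.utils import tokenize_by_CJK_char

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bbpe_2000/bbpe.model")  # new default from this commit

    text = "这是 a test"
    # Split CJK chars, map raw bytes to printable symbols, then tokenize.
    ids = sp.encode(byte_encode(tokenize_by_CJK_char(text)), out_type=int)

    # Invert for a hypothesis: BPE ids -> byte-level string -> readable text.
    print(byte_decode(sp.decode(ids)))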


@@ -80,7 +80,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
 
-from icefall import diagnostics
+from icefall import diagnostics, byte_encode
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -96,6 +96,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    tokenize_by_CJK_char,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1177,18 +1178,18 @@ def run(rank, world_size, args):
     train_cuts = multi_dataset.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 12 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 12.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 20.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+        if c.duration < 1.0 or c.duration > 12.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         # In pruned RNN-T, we require that T >= S
@@ -1213,8 +1214,17 @@ def run(rank, world_size, args):
 
         return True
 
+    def tokenize_and_encode_text(c: Cut):
+        # Normalize the text of each sample before byte-level BPE.
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_CJK_char(text))
+        c.supervisions[0].text = text
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.map(tokenize_and_encode_text)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
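A toy check of the per-cut mapping added above, assuming icefall is installed; the exact output of tokenize_by_CJK_char shown in the comment is an expectation based on its current behavior, and lhotse should apply the map lazily, as the cuts are iterated by the dataloader:

    from icefall import byte_decode, byte_encode
    from icefall.utils import tokenize_by_CJK_char

    raw = "你好world"
    tokenized = tokenize_by_CJK_char(raw)  # expected: "你 好 world"
    encoded = byte_encode(tokenized)       # printable byte-level string
    assert byte_decode(encoded) == tokenized  # round trip should be lossless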