mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
updated
This commit is contained in:
parent
5047dd37ac
commit
b8ea509067
@ -13,6 +13,7 @@ dl_dir=$PWD/download
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
vocab_sizes=(
|
||||
500
|
||||
2000
|
||||
)
|
||||
|
||||
@ -76,25 +77,31 @@ fi
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Prepare Byte BPE based lang"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -d ../../aishell2/ASR/data/lang_char ]; then
|
||||
if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
|
||||
log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d ../../librispeech/ASR/data/lang_phone ]; then
|
||||
if [ ! -d ../../librispeech/ASR/data/lang_phone ] && [ ! -d ./data/lang_phone ]; then
|
||||
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ]; then
|
||||
if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
|
||||
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd data/
|
||||
ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
|
||||
ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
|
||||
ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
|
||||
if [ ! -d ./lang_char ]; then
|
||||
ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
|
||||
fi
|
||||
if [ ! -d ./lang_phone ]; then
|
||||
ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
|
||||
fi
|
||||
if [ ! -d ./lang_bpe_500 ]; then
|
||||
ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
|
||||
fi
|
||||
cd ../
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
@ -104,9 +111,11 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
|
||||
> $lang_dir/text
|
||||
|
||||
./local/prepare_for_bpe_model.py \
|
||||
--lang-dir ./$lang_dir \
|
||||
--text $lang_dir/text
|
||||
if [ ! -f $lang_dir/transcript_chars.txt ]; then
|
||||
./local/prepare_for_bpe_model.py \
|
||||
--lang-dir ./$lang_dir \
|
||||
--text $lang_dir/text
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/text_words_segmentation ]; then
|
||||
python3 ./local/text2segments.py \
|
||||
@ -115,6 +124,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
|
||||
cat ./data/lang_bpe_500/transcript_words.txt \
|
||||
>> $lang_dir/text_words_segmentation
|
||||
|
||||
cat ./data/lang_char/text \
|
||||
>> $lang_dir/text
|
||||
fi
|
||||
|
||||
cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
|
||||
@ -130,7 +142,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
./local/train_bbpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/transcript_chars.txt
|
||||
--transcript $lang_dir/text
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
|
@ -120,6 +120,7 @@ from lhotse.cut import Cut
|
||||
from multi_dataset import MultiDataset
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall import byte_encode, byte_decode
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -193,14 +194,14 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_2000/bpe.model",
|
||||
default="data/lang_bbpe_2000/bbpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_2000",
|
||||
default="data/lang_bbpe_2000",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
|
@ -80,7 +80,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall import diagnostics, byte_encode
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import (
|
||||
@ -96,6 +96,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
tokenize_by_CJK_char,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1177,18 +1178,18 @@ def run(rank, world_size, args):
|
||||
train_cuts = multi_dataset.train_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
# Keep only utterances with duration between 1 second and 12 seconds
|
||||
#
|
||||
# Caution: There is a reason to select 20.0 here. Please see
|
||||
# Caution: There is a reason to select 12.0 here. Please see
|
||||
# ../local/display_manifest_statistics.py
|
||||
#
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 1.0 or c.duration > 20.0:
|
||||
# logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
# )
|
||||
if c.duration < 1.0 or c.duration > 12.0:
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
)
|
||||
return False
|
||||
|
||||
# In pruned RNN-T, we require that T >= S
|
||||
@ -1213,8 +1214,17 @@ def run(rank, world_size, args):
|
||||
|
||||
return True
|
||||
|
||||
def tokenize_and_encode_text(c: Cut):
|
||||
# Text normalize for each sample
|
||||
text = c.supervisions[0].text
|
||||
text = byte_encode(tokenize_by_CJK_char(text))
|
||||
c.supervisions[0].text = text
|
||||
return c
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
train_cuts = train_cuts.map(tokenize_and_encode_text)
|
||||
|
||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||
# We only load the sampler's state dict when it loads a checkpoint
|
||||
# saved in the middle of an epoch
|
||||
|
Loading…
x
Reference in New Issue
Block a user