Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 08:34:19 +00:00.

Commit b8ea509067 (parent 5047dd37ac)
@@ -13,6 +13,7 @@ dl_dir=$PWD/download
 . shared/parse_options.sh || exit 1
 
 vocab_sizes=(
+  500
   2000
 )
 
@@ -76,25 +77,31 @@ fi
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Prepare Byte BPE based lang"
   mkdir -p data/fbank
-  if [ ! -d ../../aishell2/ASR/data/lang_char ]; then
+  if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
     log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
     exit 1
   fi
 
-  if [ ! -d ../../librispeech/ASR/data/lang_phone ]; then
+  if [ ! -d ../../librispeech/ASR/data/lang_phone ] && [ ! -d ./data/lang_phone ]; then
     log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5"
     exit 1
   fi
 
-  if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ]; then
+  if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
     log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6"
     exit 1
   fi
 
   cd data/
-  ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
-  ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
-  ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+  if [ ! -d ./lang_char ]; then
+    ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
+  fi
+  if [ ! -d ./lang_phone ]; then
+    ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
+  fi
+  if [ ! -d ./lang_bpe_500 ]; then
+    ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+  fi
   cd ../
 
   for vocab_size in ${vocab_sizes[@]}; do
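Note: the abort checks and symlinks in this hunk are now guarded, so re-running stage 4 on an already prepared tree is a no-op instead of failing or relinking. A minimal Python sketch of the same guard-then-link pattern, not part of the recipe; `link_if_missing` is a hypothetical helper and the paths assume the script's working directory of data/:

    import os

    def link_if_missing(src: str, dst: str) -> None:
        # Mirrors `if [ ! -d $dst ]; then ln -svf $(realpath $src) .; fi`:
        # only create the symlink when the destination is not already there.
        if not os.path.exists(dst):
            os.symlink(os.path.realpath(src), dst)

    link_if_missing("../../../aishell2/ASR/data/lang_char", "lang_char")
    link_if_missing("../../../librispeech/ASR/data/lang_phone", "lang_phone")
    link_if_missing("../../../librispeech/ASR/data/lang_bpe_500", "lang_bpe_500")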
@@ -104,9 +111,11 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
       > $lang_dir/text
 
-    ./local/prepare_for_bpe_model.py \
-      --lang-dir ./$lang_dir \
-      --text $lang_dir/text
+    if [ ! -f $lang_dir/transcript_chars.txt ]; then
+      ./local/prepare_for_bpe_model.py \
+        --lang-dir ./$lang_dir \
+        --text $lang_dir/text
+    fi
 
     if [ ! -f $lang_dir/text_words_segmentation ]; then
       python3 ./local/text2segments.py \
@@ -115,6 +124,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 
       cat ./data/lang_bpe_500/transcript_words.txt \
         >> $lang_dir/text_words_segmentation
+
+      cat ./data/lang_char/text \
+        >> $lang_dir/text
     fi
 
     cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
@@ -130,7 +142,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
       ./local/train_bbpe_model.py \
         --lang-dir $lang_dir \
         --vocab-size $vocab_size \
-        --transcript $lang_dir/transcript_chars.txt
+        --transcript $lang_dir/text
     fi
 
     if [ ! -f $lang_dir/L_disambig.pt ]; then
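With this change the byte BPE model is trained on the combined $lang_dir/text (AISHELL-2 character text plus LibriSpeech words) instead of the intermediate transcript_chars.txt. As a rough sketch of what byte-level BPE training involves, not the icefall script itself: `byte_encode` and `tokenize_by_CJK_char` are the icefall helpers imported elsewhere in this commit, `bbpe_input.txt` is a scratch file for this sketch, and the sentencepiece options are assumptions:

    import sentencepiece as spm

    from icefall import byte_encode
    from icefall.utils import tokenize_by_CJK_char

    # Byte-encode the mixed Chinese/English transcript, then train an
    # ordinary sentencepiece model on the encoded text; training on the
    # byte-level symbols is what makes the result a "byte BPE" model.
    with open("data/lang_bbpe_2000/text") as fin, open("bbpe_input.txt", "w") as fout:
        for line in fin:
            fout.write(byte_encode(tokenize_by_CJK_char(line.strip())) + "\n")

    spm.SentencePieceTrainer.train(
        input="bbpe_input.txt",
        model_prefix="data/lang_bbpe_2000/bbpe",
        vocab_size=2000,       # matches one entry of vocab_sizes above
        model_type="unigram",  # assumption; the real script sets its own options
    )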
@@ -120,6 +120,7 @@ from lhotse.cut import Cut
 from multi_dataset import MultiDataset
 from train import add_model_arguments, get_model, get_params
 
+from icefall import byte_encode, byte_decode
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
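The decoding script now imports both `byte_encode` and `byte_decode`: text is byte-encoded before it meets the byte BPE model, and hypotheses are byte-decoded back to real characters. A minimal sketch of the intended round trip, assuming the two functions are inverses on well-formed input:

    from icefall import byte_encode, byte_decode
    from icefall.utils import tokenize_by_CJK_char

    ref = tokenize_by_CJK_char("我们 KNOW 你好")  # "我 们 KNOW 你 好"
    encoded = byte_encode(ref)
    # `encoded` uses only byte-level stand-in symbols, which is what the
    # bbpe.model above tokenizes; decoding inverts the mapping.
    assert byte_decode(encoded) == ref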
@@ -193,14 +194,14 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_2000/bpe.model",
+        default="data/lang_bbpe_2000/bbpe.model",
         help="Path to the BPE model",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bpe_2000",
+        default="data/lang_bbpe_2000",
         help="The lang dir containing word table and LG graph",
     )
 
@@ -80,7 +80,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
 
-from icefall import diagnostics
+from icefall import diagnostics, byte_encode
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -96,6 +96,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    tokenize_by_CJK_char,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1177,18 +1178,18 @@ def run(rank, world_size, args):
     train_cuts = multi_dataset.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 12 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 12.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 20.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+        if c.duration < 1.0 or c.duration > 12.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         # In pruned RNN-T, we require that T >= S
@@ -1213,8 +1214,17 @@ def run(rank, world_size, args):
 
         return True
 
+    def tokenize_and_encode_text(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_CJK_char(text))
+        c.supervisions[0].text = text
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
+    train_cuts = train_cuts.map(tokenize_and_encode_text)
+
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
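The new `tokenize_and_encode_text` map rewrites each cut's supervision text on the fly: CJK characters are split into space-separated tokens, then the result is byte-encoded to match the byte BPE vocabulary trained above. A minimal sketch of the transform on one string, assuming `tokenize_by_CJK_char` space-separates CJK characters; the input value is illustrative:

    from icefall import byte_encode
    from icefall.utils import tokenize_by_CJK_char

    text = "你好world"
    # Step 1: put spaces around every CJK character so each is its own token.
    tokens = tokenize_by_CJK_char(text)  # expected: "你 好 world"
    # Step 2: map the UTF-8 bytes to printable stand-in symbols, i.e. the
    # alphabet the BBPE model was trained on.
    print(byte_encode(tokens))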