Add adaption recipe for pruned_transducer_stateless7 (#1059)

* Add mux for finetune

* Add comments

* Fix for black

* Update finetune.py
Yifan Yang 2023-05-17 16:02:27 +08:00 committed by GitHub
parent bccd20d978
commit 562bda91e4
3 changed files with 78 additions and 46 deletions

View File

@@ -79,6 +79,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
     --use-averaged-model True \
     --beam-size 4 \
     --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+    --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
     --max-duration 400 \
     --decoding-method $m
 done

pruned_transducer_stateless7/decode_gigaspeech.py

@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
 #                                                 Zengwei Yao,
-#                                                 Xiaoyu Yang)
+#                                                 Xiaoyu Yang,
+#                                                 Yifan Yang,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,36 +21,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method greedy_search
 (2) beam search (not recommended)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
 (3) modified beam search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
 (4) fast beam search (one best)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 20.0 \
@@ -57,10 +58,10 @@ Usage:
     --max-states 64
 (5) fast beam search (nbest)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest \
     --beam 20.0 \
@@ -70,10 +71,10 @@ Usage:
     --nbest-scale 0.5
 (6) fast beam search (nbest oracle WER)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_oracle \
     --beam 20.0 \
@@ -83,10 +84,10 @@ Usage:
     --nbest-scale 0.5
 (7) fast beam search (with LG)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_LG \
     --beam 20.0 \
@@ -187,7 +188,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless7/exp_giga_finetune",
         help="The experiment dir",
     )

pruned_transducer_stateless7/finetune.py

@@ -1,8 +1,10 @@
 #!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+#                                            Mingshuang Luo,
+#                                            Zengwei Yao,
+#                                            Xiaoyu Yang,
+#                                            Yifan Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,27 +22,23 @@
 """
 Usage:
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
-./pruned_transducer_stateless7/train.py \
-    --world-size 4 \
-    --num-epochs 30 \
-    --start-epoch 1 \
-    --exp-dir pruned_transducer_stateless7/exp \
-    --full-libri 1 \
-    --max-duration 300
-# For mix precision training:
-./pruned_transducer_stateless7/train.py \
-    --world-size 4 \
-    --num-epochs 30 \
+export CUDA_VISIBLE_DEVICES="0,1"
+./pruned_transducer_stateless7/finetune.py \
+    --world-size 2 \
+    --num-epochs 20 \
     --start-epoch 1 \
+    --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+    --subset S \
     --use-fp16 1 \
-    --exp-dir pruned_transducer_stateless7/exp \
-    --full-libri 1 \
-    --max-duration 550
+    --base-lr 0.005 \
+    --lr-epochs 100 \
+    --lr-batches 100000 \
+    --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
+    --do-finetune True \
+    --use-mux True \
+    --finetune-ckpt icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt \
+    --max-duration 500
 """
@@ -59,9 +57,10 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from decoder import Decoder
+from asr_datamodule import LibriSpeechAsrDataModule
 from gigaspeech import GigaSpeechAsrDataModule
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, CutSet
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -103,7 +102,21 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
 def add_finetune_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument("--do-finetune", type=str2bool, default=False)
+    parser.add_argument(
+        "--do-finetune",
+        type=str2bool,
+        default=True,
+        help="Whether to fine-tune.",
+    )
+
+    parser.add_argument(
+        "--use-mux",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to adapt. If true, we will mix 5% of the new data
+        with 95% of the original data to fine-tune.
+        """,
+    )
 
     parser.add_argument(
         "--init-modules",
@@ -907,7 +920,11 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
+            # Skip the warmup by adding a huge number to batch_count
+            if params.do_finetune:
+                set_batch_count(model, params.batch_idx_train + 100000)
+            else:
+                set_batch_count(model, params.batch_idx_train)
             scheduler.step_batch(params.batch_idx_train)
 
             scaler.step(optimizer)
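The added branch above replaces the plain set_batch_count call: during fine-tuning the batch counter is offset by 100000 so that submodules which schedule their warmup behaviour from a batch_count attribute act as if training were long past the warmup phase. The toy module below only illustrates that mechanism; the ramp function, the 4000-batch horizon and the helper here are stand-ins for illustration, not the actual Zipformer/icefall code.

    import torch
    import torch.nn as nn

    class ToyWarmupModule(nn.Module):
        """A stand-in for a module whose behaviour ramps up with batch_count."""

        def __init__(self, warmup_batches: float = 4000.0):
            super().__init__()
            self.batch_count = 0.0          # set externally, as in the recipe
            self.warmup_batches = warmup_batches

        def warmup_scale(self) -> float:
            # Ramps from 0.1 to 1.0 over `warmup_batches`, then stays at 1.0
            frac = min(self.batch_count / self.warmup_batches, 1.0)
            return 0.1 + 0.9 * frac

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.warmup_scale() * x

    def set_batch_count(model: nn.Module, batch_count: float) -> None:
        # Push the global counter into every submodule that tracks it
        for module in model.modules():
            if hasattr(module, "batch_count"):
                module.batch_count = batch_count

    model = ToyWarmupModule()
    set_batch_count(model, 10)           # fresh training: still warming up
    print(model.warmup_scale())          # ~0.10
    set_batch_count(model, 10 + 100000)  # fine-tuning offset: fully warmed up
    print(model.warmup_scale())          # 1.0

With the pretrained checkpoint already converged, there is nothing to warm up, so the offset simply makes every such schedule start in its "mature" regime.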
@@ -1104,7 +1121,12 @@ def run(rank, world_size, args):
         parameters_names=parameters_names,
     )
 
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+    scheduler = Eden(
+        optimizer=optimizer,
+        lr_batches=params.lr_batches,
+        lr_epochs=params.lr_epochs,
+        warmup_batches=0,
+    )
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
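Passing warmup_batches=0 keeps the learning rate from ramping up again for an already-trained model, and together with --base-lr 0.005, --lr-epochs 100 and --lr-batches 100000 from the usage string it effectively flattens the Eden schedule for a short fine-tuning run. The sketch below paraphrases the Eden decay formula to show that effect; the exact expression and the "default-style" numbers are assumptions for illustration, not taken from this diff.

    # Back-of-the-envelope Eden schedule (formula paraphrased from icefall's
    # optim.py; treat the exact expression as an assumption):
    #   lr = base_lr
    #        * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    #        * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    def eden_lr(base_lr, batch, epoch, lr_batches, lr_epochs):
        batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
        epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
        return base_lr * batch_factor * epoch_factor

    # Default-style training settings decay the LR substantially ...
    print(eden_lr(0.05, batch=30000, epoch=20, lr_batches=5000, lr_epochs=3.5))
    # ... while the fine-tuning settings keep it close to --base-lr 0.005
    print(eden_lr(0.005, batch=30000, epoch=20, lr_batches=100000, lr_epochs=100))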
@@ -1129,7 +1151,15 @@ def run(rank, world_size, args):
     gigaspeech = GigaSpeechAsrDataModule(args)
 
-    train_cuts = gigaspeech.train_cuts()
+    if params.use_mux:
+        librispeech = LibriSpeechAsrDataModule(args)
+        train_cuts = CutSet.mux(
+            librispeech.train_all_shuf_cuts(),
+            gigaspeech.train_cuts(),
+            weights=[0.95, 0.05],
+        )
+    else:
+        train_cuts = gigaspeech.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
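The branch above implements the --use-mux behaviour described in add_finetune_arguments: lhotse's CutSet.mux lazily interleaves the two corpora, drawing roughly 95% of cuts from the original LibriSpeech stream and 5% from the new GigaSpeech stream. The snippet below is a plain-Python sketch of that weighted interleaving, with toy string items standing in for lhotse cuts; it is illustrative only and does not use the lhotse API.

    import random

    def mux(streams, weights, seed=0):
        """Lazily interleave items, drawing the next item from stream i
        with probability proportional to weights[i] until all streams end."""
        rng = random.Random(seed)
        iters = [iter(s) for s in streams]
        alive = list(range(len(iters)))
        while alive:
            i = rng.choices(alive, weights=[weights[j] for j in alive])[0]
            try:
                yield next(iters[i])
            except StopIteration:
                alive.remove(i)

    # 95% "original" (LibriSpeech-like) items, 5% "new" (GigaSpeech-like) items
    original = (f"libri-{n}" for n in range(1000))
    new = (f"giga-{n}" for n in range(1000))
    mixed = mux([original, new], weights=[0.95, 0.05])
    print([next(mixed) for _ in range(20)])

Keeping most of each batch on the original data is what makes this an adaptation recipe: the small share of new data nudges the model towards GigaSpeech without letting it forget LibriSpeech.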
@@ -1141,9 +1171,9 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False
 
     # In pruned RNN-T, we require that T >= S
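For context, remove_short_and_long_utt is a predicate over cuts; recipes of this kind apply it lazily to the training CutSet (roughly train_cuts = train_cuts.filter(remove_short_and_long_utt)). The self-contained sketch below shows the same duration filtering with toy objects standing in for lhotse cuts; the ToyCut class and the sample durations are made up for illustration.

    from dataclasses import dataclass

    @dataclass
    class ToyCut:
        id: str
        duration: float  # seconds

    def remove_short_and_long_utt(c: ToyCut) -> bool:
        # Keep only utterances with duration between 1 second and 20 seconds
        return 1.0 <= c.duration <= 20.0

    cuts = [ToyCut("a", 0.4), ToyCut("b", 7.3), ToyCut("c", 25.0), ToyCut("d", 12.1)]
    kept = [c for c in cuts if remove_short_and_long_utt(c)]
    print([c.id for c in kept])  # ['b', 'd']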