Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
Add adaption recipe for pruned_transducer_stateless7 (#1059)
* Add mux for finetune
* Add comments
* Fix for black
* Update finetune.py
parent bccd20d978
commit 562bda91e4
@@ -79,6 +79,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
       --use-averaged-model True \
       --beam-size 4 \
       --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+      --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
       --max-duration 400 \
       --decoding-method $m
   done
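The decode stage now passes the BPE model bundled with the downloaded pretrained checkpoint, presumably so that decoding the fine-tuned model does not depend on a locally prepared data/lang_bpe_500 directory.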
pruned_transducer_stateless7/decode_gigaspeech.py:

@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
 #                                                 Zengwei Yao,
-#                                                 Xiaoyu Yang)
+#                                                 Xiaoyu Yang,
+#                                                 Yifan Yang,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,36 +21,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method greedy_search

 (2) beam search (not recommended)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search (one best)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 20.0 \
@@ -57,10 +58,10 @@ Usage:
     --max-states 64

 (5) fast beam search (nbest)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest \
     --beam 20.0 \
@@ -70,10 +71,10 @@ Usage:
     --nbest-scale 0.5

 (6) fast beam search (nbest oracle WER)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_oracle \
     --beam 20.0 \
@@ -83,10 +84,10 @@ Usage:
     --nbest-scale 0.5

 (7) fast beam search (with LG)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless7/decode_gigaspeech.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_LG \
     --beam 20.0 \
@@ -187,7 +188,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless7/exp_giga_finetune",
         help="The experiment dir",
     )

pruned_transducer_stateless7/finetune.py:

@@ -1,8 +1,10 @@
 #!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+#                                            Mingshuang Luo,
+#                                            Zengwei Yao,
+#                                            Xiaoyu Yang,
+#                                            Yifan Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,27 +22,23 @@
 """
 Usage:

-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0,1"

-./pruned_transducer_stateless7/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
-
-# For mix precision training:
-
-./pruned_transducer_stateless7/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
+./pruned_transducer_stateless7/finetune.py \
+  --world-size 2 \
+  --num-epochs 20 \
   --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+  --subset S \
   --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
-
+  --base-lr 0.005 \
+  --lr-epochs 100 \
+  --lr-batches 100000 \
+  --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
+  --do-finetune True \
+  --use-mux True \
+  --finetune-ckpt icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt \
+  --max-duration 500
 """
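Note the learning-rate flags in the new usage example: --base-lr 0.005 is an order of magnitude below the usual from-scratch default, and the large --lr-epochs / --lr-batches values keep the Eden decay factors close to 1, so the learning rate stays small and nearly constant throughout fine-tuning (see the schedule sketch after the scheduler hunk below).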
@@ -59,9 +57,10 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from decoder import Decoder
+from asr_datamodule import LibriSpeechAsrDataModule
 from gigaspeech import GigaSpeechAsrDataModule
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, CutSet
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -103,7 +102,21 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:


 def add_finetune_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument("--do-finetune", type=str2bool, default=False)
+    parser.add_argument(
+        "--do-finetune",
+        type=str2bool,
+        default=True,
+        help="Whether to fine-tune.",
+    )
+    parser.add_argument(
+        "--use-mux",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to adapt. If true, we will mix 5% of the new data
+        with 95% of the original data to fine-tune.
+        """,
+    )

     parser.add_argument(
         "--init-modules",
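For reference, a minimal standalone sketch of how the two new flags parse; str2bool here is a local stand-in for the helper that finetune.py imports from icefall.utils:

    import argparse


    def str2bool(v):
        # Stand-in for icefall.utils.str2bool: accept common true/false spellings.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")


    parser = argparse.ArgumentParser()
    parser.add_argument("--do-finetune", type=str2bool, default=True)
    parser.add_argument("--use-mux", type=str2bool, default=False)

    args = parser.parse_args(["--use-mux", "True"])
    assert args.do_finetune is True and args.use_mux is True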
@@ -907,7 +920,11 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
+            # Skip the warmup by adding a huge number to batch_count
+            if params.do_finetune:
+                set_batch_count(model, params.batch_idx_train + 100000)
+            else:
+                set_batch_count(model, params.batch_idx_train)
             scheduler.step_batch(params.batch_idx_train)

             scaler.step(optimizer)
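The batch count matters because the recipe's Zipformer-style modules key their internal warm-up behaviour off a batch_count attribute; pushing it far past the warm-up horizon makes the pretrained model behave as fully warmed up from the first fine-tuning step. A rough sketch of what set_batch_count does (the real helper lives in this same file, per the hunk header above):

    import torch.nn as nn


    def set_batch_count_sketch(model: nn.Module, batch_count: float) -> None:
        # Stamp the count onto every submodule that tracks it; layers such as
        # the recipe's scaled/activation-balancing modules read this attribute
        # to decide whether they are still in their warm-up phase.
        for module in model.modules():
            if hasattr(module, "batch_count"):
                module.batch_count = batch_count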
@@ -1104,7 +1121,12 @@ def run(rank, world_size, args):
         parameters_names=parameters_names,
     )

-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+    scheduler = Eden(
+        optimizer=optimizer,
+        lr_batches=params.lr_batches,
+        lr_epochs=params.lr_epochs,
+        warmup_batches=0,
+    )

     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
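Passing warmup_batches=0 disables Eden's initial linear ramp-up, which is unwanted when starting from an already converged checkpoint. Roughly, a sketch of Eden's decay factors (see icefall/optim.py for the actual implementation):

    def eden_factor(step: int, epoch: float, lr_batches: float, lr_epochs: float) -> float:
        # Both factors stay close to 1.0 while step << lr_batches and
        # epoch << lr_epochs, which is exactly the regime the fine-tuning
        # flags (--lr-batches 100000, --lr-epochs 100) put the run in.
        batch_factor = ((step**2 + lr_batches**2) / lr_batches**2) ** -0.25
        epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
        return batch_factor * epoch_factor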
@@ -1129,7 +1151,15 @@ def run(rank, world_size, args):

     gigaspeech = GigaSpeechAsrDataModule(args)

-    train_cuts = gigaspeech.train_cuts()
+    if params.use_mux:
+        librispeech = LibriSpeechAsrDataModule(args)
+        train_cuts = CutSet.mux(
+            librispeech.train_all_shuf_cuts(),
+            gigaspeech.train_cuts(),
+            weights=[0.95, 0.05],
+        )
+    else:
+        train_cuts = gigaspeech.train_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
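lhotse's CutSet.mux interleaves several lazy cut sets, drawing from each with the given probabilities, so the 95/5 LibriSpeech/GigaSpeech mix never has to be materialized up front. A self-contained sketch of the same idea (the manifest paths are placeholders, not the recipe's actual files):

    from lhotse import CutSet, load_manifest_lazy

    # Placeholder manifest paths, for illustration only.
    libri_cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")
    giga_cuts = load_manifest_lazy("data/fbank/gigaspeech_cuts_S.jsonl.gz")

    # Sample ~95% of cuts from the original data and ~5% from the new data.
    train_cuts = CutSet.mux(libri_cuts, giga_cuts, weights=[0.95, 0.05])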
@@ -1141,9 +1171,9 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False

         # In pruned RNN-T, we require that T >= S
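Silencing the per-cut warning makes sense here: with a lazily muxed CutSet the filter runs while training streams the data, so logging every excluded cut would flood the logs. Continuing the mux sketch above, the predicate is applied the usual lhotse way, roughly:

    from lhotse.cut import Cut


    def keep_medium_length_utt(c: Cut) -> bool:
        # Keep only utterances with duration between 1 second and 20 seconds.
        return 1.0 <= c.duration <= 20.0


    train_cuts = train_cuts.filter(keep_medium_length_utt)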