diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py
index cf62b33c6..843d103cc 100755
--- a/egs/librispeech/ASR/zipformer/finetune.py
+++ b/egs/librispeech/ASR/zipformer/finetune.py
@@ -24,31 +24,32 @@ Usage:
 
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-# For non-streaming model training:
+# Fine-tune without mux (i.e. not mixing with original training data):
 ./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir zipformer/exp \
-  --full-libri 1 \
+  --do-finetune 1 \
+  --finetune-ckpt path/to/ckpt \
+  --base-lr 0.0045 \
+  --use-mux 0 \
+  --exp-dir zipformer/exp_finetune \
   --max-duration 1000
 
-# For streaming model training:
+# Fine-tune with mux (i.e. mixing with original training data):
 ./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir zipformer/exp \
-  --causal 1 \
-  --full-libri 1 \
+  --do-finetune 1 \
+  --finetune-ckpt path/to/ckpt \
+  --base-lr 0.0045 \
+  --use-mux 1 \
+  --exp-dir zipformer/exp_finetune \
   --max-duration 1000
 
-It supports training with:
-  - transducer loss (default), with `--use-transducer True --use-ctc False`
-  - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
-  - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
 """
 
 
@@ -106,11 +107,13 @@ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 def get_adjusted_batch_count(params: AttributeDict) -> float:
     # returns the number of batches we would have used so far if we had used the reference
     # duration. This is for purposes of set_batch_count().
+    # Note that we add a very large constant here so that the ScheduledFloat
+    # variables are pushed to their end values.
     return (
         params.batch_idx_train
         * (params.max_duration * params.world_size)
         / params.ref_duration
-    )
+    ) + 100000
 
 
 def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
@@ -370,22 +373,29 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--base-lr", type=float, default=0.045, help="The base learning rate."
+        "--base-lr",
+        type=float,
+        default=0.0045,
+        help="""The base learning rate.
+        It is set to a small value because we are fine-tuning.""",
    )
 
     parser.add_argument(
         "--lr-batches",
         type=float,
-        default=7500,
+        default=100000.0,
         help="""Number of steps that affects how rapidly the learning rate
-        decreases. We suggest not to change this.""",
+        decreases. It is set to a very large value here to prevent the learning
+        rate from decaying too fast during fine-tuning.""",
     )
 
     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=3.5,
+        default=100.0,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
+        It is set to a very large value here to prevent the learning rate from
+        decaying too fast during fine-tuning.
         """,
     )
 
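Note on the `+ 100000` offset in `get_adjusted_batch_count`: icefall's `ScheduledFloat` hyperparameters (defined in `zipformer/scaling.py`) interpolate piecewise-linearly over the model's batch count, so adding a large constant pins every scheduled value at its final breakpoint from the very first fine-tuning step. A minimal sketch of that behavior, using a simplified stand-in for `ScheduledFloat` (the helper below is hypothetical, not the library class):

```python
# Sketch: why a large batch-count offset drives scheduled values to their
# end points. scheduled_float() mimics (in simplified form) icefall's
# ScheduledFloat, which interpolates between (batch_count, value) points.

def scheduled_float(batch_count: float, points: list[tuple[float, float]]) -> float:
    """Piecewise-linear interpolation over (batch_count, value) points."""
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)


# A typical schedule: a value decays from 0.3 to 0.1 over the first 20k batches.
schedule = [(0.0, 0.3), (20000.0, 0.1)]

# Training from scratch, early batches see roughly the initial value:
print(scheduled_float(500.0, schedule))             # 0.295

# With the +100000 offset from get_adjusted_batch_count(), every scheduled
# value is already at its end point when fine-tuning starts:
print(scheduled_float(500.0 + 100000.0, schedule))  # 0.1
```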
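Similarly, the much larger `--lr-batches`/`--lr-epochs` defaults follow from the Eden schedule documented in `zipformer/optim.py`, where the learning rate scales as `((batch^2 + lr_batches^2)/lr_batches^2)^-0.25 * ((epoch^2 + lr_epochs^2)/lr_epochs^2)^-0.25`. A quick sketch of the effect, assuming that formula and omitting Eden's warmup factor:

```python
# Sketch: effect of the new --lr-batches/--lr-epochs defaults on the Eden
# schedule (formula as documented in zipformer/optim.py; warmup omitted).

def eden_lr(base_lr: float, batch: int, epoch: float,
            lr_batches: float, lr_epochs: float) -> float:
    batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    return base_lr * batch_factor * epoch_factor


# From-scratch defaults: by batch 10k / epoch 5 the lr has decayed to ~59%
# of its base value.
print(eden_lr(0.045, 10_000, 5.0, lr_batches=7500.0, lr_epochs=3.5))       # ~0.026

# Fine-tuning defaults from this patch: the lr is still ~99.7% of base-lr,
# i.e. effectively constant over a typical fine-tuning run.
print(eden_lr(0.0045, 10_000, 5.0, lr_batches=100000.0, lr_epochs=100.0))  # ~0.00449
```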