update the usage; set a very large batch count

marcoyang 2024-02-06 17:05:58 +08:00
parent 4590da3a83
commit d32314cc1f


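In the usage hunk below, the two fine-tuning commands differ only in --use-mux, i.e. whether the fine-tuning cuts are mixed with the original training cuts (a common way to limit forgetting of the original domain). A rough sketch of that kind of mixing with lhotse's CutSet.mux follows; the paths, weights, and variable names are placeholders rather than values from the recipe:

# Illustrative sketch only: mixing the original training cuts with the
# fine-tuning cuts, lhotse-style.  Paths and weights are placeholders.
from lhotse import CutSet

original_cuts = CutSet.from_file("data/fbank/original_cuts_train.jsonl.gz")
finetune_cuts = CutSet.from_file("data/fbank/finetune_cuts_train.jsonl.gz")

use_mux = True  # corresponds to --use-mux 1

if use_mux:
    # Lazily interleave the two corpora; weights control how often each one is sampled.
    train_cuts = CutSet.mux(original_cuts, finetune_cuts, weights=[0.9, 0.1])
else:
    train_cuts = finetune_cuts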
@@ -24,31 +24,32 @@ Usage:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-# For non-streaming model training:
+# Fine-tune without mux (i.e. not mixing with original training data):
 ./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir zipformer/exp \
-  --full-libri 1 \
+  --do-finetune 1 \
+  --finetune-ckpt path/to/ckpt \
+  --base-lr 0.0045 \
+  --use-mux 0 \
+  --exp-dir zipformer/exp_finetune \
   --max-duration 1000

-# For streaming model training:
+# Fine-tune with mux (i.e. mixing with original training data):
 ./zipformer/finetune.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir zipformer/exp \
-  --causal 1 \
-  --full-libri 1 \
+  --do-finetune 1 \
+  --finetune-ckpt path/to/ckpt \
+  --base-lr 0.0045 \
+  --use-mux 1 \
+  --exp-dir zipformer/exp_finetune \
   --max-duration 1000

-It supports training with:
-  - transducer loss (default), with `--use-transducer True --use-ctc False`
-  - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
-  - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
 """
@@ -106,11 +107,13 @@ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

 def get_adjusted_batch_count(params: AttributeDict) -> float:
     # returns the number of batches we would have used so far if we had used the reference
     # duration. This is for purposes of set_batch_count().
+    # Note that we add a very large constant here so that the ScheduledFloat
+    # parameters are already at their end values during fine-tuning.
     return (
         params.batch_idx_train
         * (params.max_duration * params.world_size)
         / params.ref_duration
-    )
+    ) + 100000


 def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
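Background for the +100000 offset above: set_batch_count() pushes the adjusted count into the model's submodules, and several Zipformer hyper-parameters are ScheduledFloat values that interpolate between (batch_count, value) breakpoints as training progresses. Adding a very large constant moves the count past every breakpoint, so those parameters sit at their end values from the first fine-tuning step. A standalone sketch of the mechanism, with made-up breakpoints:

# Standalone sketch mimicking a batch-count-keyed schedule such as ScheduledFloat.
from bisect import bisect_right
from typing import List, Tuple

def scheduled_value(batch_count: float, schedule: List[Tuple[float, float]]) -> float:
    """Piecewise-linear interpolation over (batch_count, value) breakpoints."""
    xs = [x for x, _ in schedule]
    if batch_count <= xs[0]:
        return schedule[0][1]
    if batch_count >= xs[-1]:
        return schedule[-1][1]
    i = bisect_right(xs, batch_count)
    (x0, y0), (x1, y1) = schedule[i - 1], schedule[i]
    return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

# E.g. a dropout-like value that decays from 0.3 to 0.1 over the first 20k batches.
schedule = [(0.0, 0.3), (20000.0, 0.1)]

print(scheduled_value(500.0, schedule))           # early training: 0.295, close to 0.3
print(scheduled_value(500.0 + 100000, schedule))  # with the offset: pinned at the end value 0.1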
@@ -370,22 +373,29 @@ def get_parser():
     )

     parser.add_argument(
-        "--base-lr", type=float, default=0.045, help="The base learning rate."
+        "--base-lr",
+        type=float,
+        default=0.0045,
+        help="""The base learning rate.
+        It is set to a very small value as we are doing fine-tuning""",
     )

     parser.add_argument(
         "--lr-batches",
         type=float,
-        default=7500,
+        default=100000.0,
         help="""Number of steps that affects how rapidly the learning rate
-        decreases. We suggest not to change this.""",
+        decreases. It is set to a very large value here to prevent the lr from decaying too fast
+        during fine-tuning.""",
     )

     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=3.5,
+        default=100.0,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
+        It is set to a very large value here to prevent the lr from decaying too fast
+        during fine-tuning.
         """,
     )
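The three options above are consumed by the Eden learning-rate scheduler. Eden multiplies the base lr by batch- and epoch-dependent decay factors of the form ((t**2 + T**2) / T**2) ** -0.25 (this is my reading of icefall's Eden; treat the exact formula as an assumption, and warmup is omitted). The sketch below shows why the fine-tuning defaults keep the lr essentially flat at the new base lr, whereas the original defaults would have decayed it substantially by the same point:

# Sketch of Eden-style lr decay, using the assumed formula described above.
def eden_lr(base_lr: float, batch: int, epoch: float, lr_batches: float, lr_epochs: float) -> float:
    batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    return base_lr * batch_factor * epoch_factor

# Original training defaults: by 30k batches / 6 epochs the lr has dropped to ~0.016,
# roughly a third of the 0.045 base lr.
print(eden_lr(0.045, batch=30000, epoch=6, lr_batches=7500, lr_epochs=3.5))

# Fine-tuning defaults: both factors stay close to 1, so the lr remains near 0.0045.
print(eden_lr(0.0045, batch=30000, epoch=6, lr_batches=100000.0, lr_epochs=100.0))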