use lr hours in librilight ssl

yifanyeung 2024-08-19 23:14:26 +08:00
parent 70a1713662
commit d025ce11ff
2 changed files with 17 additions and 14 deletions

View File

@@ -97,20 +97,20 @@ torchrun \
  zipformer/pretrain.py \
  --use-multi-node 1 \
  --master-port $master_port \
- --num-epochs 30 \
+ --num-epochs 20 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp_pretrain \
- --max-duration 600 \
+ --max-duration 350 \
  --quadratic-duration 1024 \
  --accum-grad 1 \
  --do-normalize 1 \
  --mask-prob 0.8 \
- --dropout-input 0.0 \
- --dropout-features 0.0 \
+ --dropout-input 0.1 \
+ --dropout-features 0.1 \
  --feature-grad-mult 1.0 \
- --num-encoder-layers 2,2,3,4,3,2 \
- --feedforward-dim 512,768,1024,1536,1024,768 \
- --encoder-dim 192,256,448,768,448,192 \
- --encoder-unmasked-dim 192,192,256,256,256,192 \
- --base-lr 0.045
+ --num-encoder-layers 2,2,4,5,4,2 \
+ --feedforward-dim 768,1536,2048,3072,2048,1536 \
+ --encoder-dim 256,512,768,1024,768,512 \
+ --encoder-unmasked-dim 192,192,256,320,256,192 \
+ --base-lr 0.045
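The comma-separated options above give one value per stack of the Zipformer encoder (six stacks in this recipe); the new settings substantially enlarge the encoder. A small sketch (plain Python, illustrative only, not part of the recipe) of how the updated strings expand into per-stack lists:

```python
# Illustrative only: expand the updated comma-separated settings into
# one integer per Zipformer encoder stack (six stacks in this recipe).
new_cfg = {
    "num-encoder-layers": "2,2,4,5,4,2",
    "feedforward-dim": "768,1536,2048,3072,2048,1536",
    "encoder-dim": "256,512,768,1024,768,512",
    "encoder-unmasked-dim": "192,192,256,320,256,192",
}
per_stack = {name: [int(v) for v in s.split(",")] for name, s in new_cfg.items()}
for name, values in per_stack.items():
    print(f"{name:>22}: {values}")
```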

View File

@@ -472,10 +472,10 @@ def get_parser():
  )
  parser.add_argument(
-     "--lr-epochs",
+     "--lr-hours",
      type=float,
-     default=0.2,
-     help="""Number of epochs that affects how rapidly the learning rate decreases.
+     default=30000,
+     help="""Number of hours that affects how rapidly the learning rate decreases.
      """,
  )
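The renamed option changes the unit of the Eden scheduler's slow decay axis from epochs to hours of processed speech. A rough sketch of the Eden decay factor with hours substituted for the epoch count, assuming the usual icefall form (the exact implementation lives in the recipe's optim.py; the lr-batches default of 7500 is an assumption, not part of this diff):

```python
def eden_factor(
    step: int, hours: float, lr_batches: float = 7500.0, lr_hours: float = 30000.0
) -> float:
    """Approximate Eden decay multiplier on the base learning rate (warmup omitted).

    `hours` now plays the role the epoch count used to play: the second term
    stays near 1.0 until roughly `lr_hours` hours of speech have been seen,
    then starts shrinking the learning rate.
    """
    step_term = ((step**2 + lr_batches**2) / lr_batches**2) ** -0.25
    hour_term = ((hours**2 + lr_hours**2) / lr_hours**2) ** -0.25
    return step_term * hour_term


# e.g. after 40k optimizer steps and 30k hours of speech the multiplier is
# about 0.43 * 0.84 ≈ 0.36 of the peak learning rate.
print(eden_factor(step=40_000, hours=30_000.0))
```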
@@ -951,6 +951,10 @@ def train_one_epoch(
  if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
      params.batch_idx_train += 1
      scheduler.step_batch(params.batch_idx_train)
+     # Use the number of hours of speech to adjust the learning rate
+     scheduler.step_epoch(
+         params.batch_idx_train * params.max_duration * params.world_size / 3600
+     )
      scaler.step(optimizer)
      scaler.update()
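The added scheduler.step_epoch(...) call converts the global optimizer-step count into hours of speech: each step consumes roughly --max-duration seconds of audio per GPU across world_size GPUs. A small worked example under the new settings (the 8-GPU world size is an assumption, not stated in this diff):

```python
# How fast do "hours of speech" accumulate, and how many optimizer steps
# does it take to reach the --lr-hours default of 30000?
max_duration = 350   # seconds of audio per GPU per step (new --max-duration)
world_size = 8       # ASSUMPTION: number of GPUs used for pretraining
lr_hours = 30000     # new --lr-hours default

hours_per_step = max_duration * world_size / 3600   # ≈ 0.78 hours per step
steps_to_lr_hours = lr_hours / hours_per_step        # ≈ 38,600 optimizer steps
print(f"{hours_per_step:.3f} h/step, ~{steps_to_lr_hours:,.0f} steps to {lr_hours} h")
```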
@@ -1142,7 +1146,7 @@ def run(rank, world_size, args):
  scheduler = Eden(
      optimizer,
      params.lr_batches,
-     params.lr_epochs,
+     params.lr_hours,
      params.warmup_batches,
      params.warmup_start,
  )
@@ -1246,7 +1250,6 @@ def run(rank, world_size, args):
  scaler.load_state_dict(checkpoints["grad_scaler"])
  for epoch in range(params.start_epoch, params.num_epochs + 1):
-     scheduler.step_epoch(epoch - 1)
      fix_random_seed(params.seed + epoch - 1)
      train_dl.sampler.set_epoch(epoch - 1)
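Because the hours-based step_epoch call now runs inside train_one_epoch on every optimizer step, the old per-epoch scheduler.step_epoch(epoch - 1) call is dropped from the epoch loop. A minimal executable sketch of the resulting control flow (hypothetical simplified names; FakeEden is only a stand-in to make the sketch runnable):

```python
class FakeEden:
    """Stand-in for the Eden scheduler, only to make the sketch runnable."""
    def step_batch(self, n: int) -> None:
        pass
    def step_epoch(self, hours: float) -> None:
        print(f"step_epoch(hours={hours:.2f})")

scheduler = FakeEden()
max_duration, world_size = 350, 8     # world_size=8 is an assumption
batch_idx_train = 0

for epoch in range(1, 3):             # no scheduler.step_epoch(epoch - 1) here any more
    for _batch in range(2):           # stand-in for iterating train_dl
        batch_idx_train += 1
        scheduler.step_batch(batch_idx_train)
        # hours of speech seen so far replace the epoch index for the slow decay:
        scheduler.step_epoch(batch_idx_train * max_duration * world_size / 3600)
```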