Adjust batch count w.r.t. reference duration

Daniel Povey 2022-12-18 14:25:23 +08:00
parent 5e1bf8b8ec
commit f439399ced


@@ -90,6 +90,15 @@ LRSchedulerType = Union[
 ]
+def get_adjusted_batch_count(
+    params: AttributeDict) -> float:
+    # returns the number of batches we would have used so far if we had used the reference
+    # duration.  This is for purposes of set_batch_count().
+    return (params.batch_idx_train * params.ref_duration /
+            (params.max_duration * params.world_size))
 def set_batch_count(
     model: Union[nn.Module, DDP], batch_count: float
 ) -> None:
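The helper rescales the raw training batch index by the ratio of the reference duration to the amount of audio actually consumed per optimizer step across all workers. A minimal standalone restatement of the committed formula, with scalar arguments in place of the params AttributeDict and purely illustrative numbers:

```python
def adjusted_batch_count(batch_idx_train: float,
                         ref_duration: float,
                         max_duration: float,
                         world_size: int) -> float:
    # Number of batches we would have processed so far if every optimizer step
    # had covered `ref_duration` seconds of audio instead of the
    # `max_duration * world_size` seconds it actually covers.
    return batch_idx_train * ref_duration / (max_duration * world_size)


# Illustrative: with max_duration=1200 on 4 GPUs and the default
# ref_duration=600, each step consumes 8x the reference amount of audio,
# so 10_000 real steps count as 1_250 "reference" batches.
print(adjusted_batch_count(10_000, 600, 1200, 4))  # 1250.0
```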
@@ -302,6 +311,15 @@ def get_parser():
         """,
     )
+    parser.add_argument(
+        "--ref-duration",
+        type=float,
+        default=600,
+        help="Reference batch duration for purposes of adjusting batch counts for setting various "
+        "schedules inside the model"
+    )
     parser.add_argument(
         "--context-size",
         type=int,
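For context, a minimal argparse sketch (not the recipe's full get_parser()) showing how the new flag relates to the per-step batch duration; everything other than --ref-duration and its default, including the --max-duration default shown, is an assumption for illustration:

```python
import argparse

parser = argparse.ArgumentParser()
# Existing-style option controlling how many seconds of audio go into one
# batch per worker; included only so the example is self-contained.
parser.add_argument("--max-duration", type=float, default=1200)
# New option from this commit: the batch duration that the model's internal
# schedules are calibrated against.
parser.add_argument("--ref-duration", type=float, default=600,
                    help="Reference batch duration for adjusting batch counts")

args = parser.parse_args(["--max-duration", "1200"])
# With one GPU, schedules would advance at half the raw batch rate here.
print(args.ref_duration / args.max_duration)  # 0.5
```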
@@ -858,8 +876,8 @@ def train_one_epoch(
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
-        if int(params.batch_idx_train) % 10 == 0:
-            set_batch_count(model, params.batch_idx_train)
+        if int(params.batch_idx_train) % 10 == 1:
+            set_batch_count(model, get_adjusted_batch_count(params))
         scheduler.step_batch(params.batch_idx_train)
         scaler.step(optimizer)
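With this change, every tenth step the model is told how many batches it would have seen at the reference duration, rather than the raw step count. set_batch_count itself is defined earlier in the file; a plausible minimal sketch of the pattern (an assumption about its body, not a copy of it) is that it writes the count onto every submodule that exposes a batch_count attribute, which batch-count-dependent schedules inside the model then read:

```python
from typing import Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    # Sketch only: unwrap DDP, then record the (possibly adjusted) batch count
    # on every submodule that declares a `batch_count` attribute, so schedules
    # inside the model see a value expressed in "reference duration" units
    # rather than raw optimizer steps.
    if isinstance(model, DDP):
        model = model.module
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
```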