Adjust batch count w.r.t. reference duration
This commit is contained in:
parent
5e1bf8b8ec
commit
f439399ced
@ -90,6 +90,15 @@ LRSchedulerType = Union[
|
||||
]
|
||||
|
||||
|
||||
|
||||
def get_adjusted_batch_count(
|
||||
params: AttributeDict) -> float:
|
||||
# returns the number of batches we would have used so far if we had used the reference
|
||||
# duration. This is for purposes of set_batch_count().
|
||||
return (params.batch_idx_train * params.ref_duration /
|
||||
(params.max_duration * params.world_size))
|
||||
|
||||
|
||||
def set_batch_count(
|
||||
model: Union[nn.Module, DDP], batch_count: float
|
||||
) -> None:
|
||||
@ -302,6 +311,15 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ref-duration",
|
||||
type=float,
|
||||
default=600,
|
||||
help="Reference batch duration for purposes of adjusting batch counts for setting various "
|
||||
"schedules inside the model"
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
@ -858,8 +876,8 @@ def train_one_epoch(
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
if int(params.batch_idx_train) % 10 == 0:
|
||||
set_batch_count(model, params.batch_idx_train)
|
||||
if int(params.batch_idx_train) % 10 == 1:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
|
||||
scaler.step(optimizer)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user