Fix train.py for new optimizer

This commit is contained in:
Daniel Povey 2022-07-09 10:09:53 +08:00
parent 6810849058
commit 50ee414486

View File

@ -66,7 +66,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, NeutralGradient from optim import Eden, PrAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -926,9 +926,8 @@ def run(rank, world_size, args):
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank]) model = DDP(model, device_ids=[rank])
optimizer = NeutralGradient(model.parameters(), optimizer = PrAdam(model.parameters(),
lr=params.initial_lr, lr=params.initial_lr)
lr_for_speedup=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)