Move train.py changes to the right dir

This commit is contained in:
Daniel Povey 2022-05-15 22:17:08 +08:00
parent cee5396058
commit 995371ad95
2 changed files with 7 additions and 7 deletions

View File

@ -701,11 +701,6 @@ def train_one_epoch(
continue continue
cur_batch_idx = batch_idx cur_batch_idx = batch_idx
if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0:
mmodel = model.module if hasattr(model, 'module') else model
mmodel.encoder.orthogonalize()
optimizer.reset()
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])

View File

@ -66,7 +66,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve, Abel
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -701,6 +701,11 @@ def train_one_epoch(
continue continue
cur_batch_idx = batch_idx cur_batch_idx = batch_idx
if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0:
mmodel = model.module if hasattr(model, 'module') else model
mmodel.encoder.orthogonalize()
optimizer.reset()
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -871,7 +876,7 @@ def run(rank, world_size, args):
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank]) model = DDP(model, device_ids=[rank])
optimizer = Eve(model.parameters(), lr=params.initial_lr) optimizer = Abel(model.parameters(), lr=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)