From 995371ad95022fec9cb9f1e709f311a1d1d408db Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 15 May 2022 22:17:08 +0800
Subject: [PATCH] Move train.py changes to the right dir

---
 .../ASR/pruned_transducer_stateless4/train.py  | 5 -----
 .../ASR/pruned_transducer_stateless4b/train.py | 9 +++++++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 2eb9de4f3..a79f29c30 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -701,11 +701,6 @@ def train_one_epoch(
             continue
         cur_batch_idx = batch_idx

-        if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0:
-            mmodel = model.module if hasattr(model, 'module') else model
-            mmodel.encoder.orthogonalize()
-            optimizer.reset()
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/train.py
index 4ff69d521..2eb9de4f3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/train.py
@@ -66,7 +66,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eden, Eve
+from optim import Eden, Eve, Abel
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -701,6 +701,11 @@ def train_one_epoch(
             continue
         cur_batch_idx = batch_idx

+        if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0:
+            mmodel = model.module if hasattr(model, 'module') else model
+            mmodel.encoder.orthogonalize()
+            optimizer.reset()
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -871,7 +876,7 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])

-    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+    optimizer = Abel(model.parameters(), lr=params.initial_lr)

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
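
For reference, below is a self-contained Python sketch of the hook this patch relocates: every
2000 training batches it unwraps the DDP-wrapped model if needed, re-orthogonalizes the encoder,
and resets the optimizer state. Only the hook condition and the three calls inside it come from
the patch itself; the StubEncoder, StubTransducer, and AbelLike classes, the QR-based body of
orthogonalize(), and the state-clearing body of reset() are illustrative assumptions, not
icefall's real implementations (Abel lives in the recipe's optim.py, and orthogonalize() in the
recipe's encoder module).

    # Illustrative sketch only -- the Stub* classes and the bodies of
    # orthogonalize()/reset() are assumptions, not icefall's real code.
    import torch
    import torch.nn as nn


    class StubEncoder(nn.Module):
        """Stand-in encoder exposing the orthogonalize() hook the patch calls."""

        def __init__(self, dim: int = 16):
            super().__init__()
            self.proj = nn.Linear(dim, dim, bias=False)

        def orthogonalize(self) -> None:
            # Assumption: re-orthogonalize the square weight via QR; the real
            # method in the recipe's encoder may do something different.
            with torch.no_grad():
                q, _ = torch.linalg.qr(self.proj.weight)
                self.proj.weight.copy_(q)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    class StubTransducer(nn.Module):
        """Minimal model with an .encoder attribute, as train.py expects."""

        def __init__(self):
            super().__init__()
            self.encoder = StubEncoder()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.encoder(x)


    class AbelLike(torch.optim.Adam):
        """Stand-in for optim.Abel, sketching only the reset() hook.

        Assumption: reset() drops per-parameter state (moment estimates) so
        the optimizer restarts cleanly after weights are re-orthogonalized.
        """

        def reset(self) -> None:
            self.state.clear()


    def train_loop_sketch(num_batches: int = 4001) -> None:
        model = StubTransducer()  # wrap in DDP for multi-GPU; the hook unwraps it
        optimizer = AbelLike(model.parameters(), lr=1e-3)
        batch_idx_train = 0
        for _ in range(num_batches):
            # The hook this patch moves into pruned_transducer_stateless4b:
            if batch_idx_train % 2000 == 0 and batch_idx_train > 0:
                mmodel = model.module if hasattr(model, "module") else model
                mmodel.encoder.orthogonalize()
                optimizer.reset()
            batch_idx_train += 1
            x = torch.randn(8, 16)
            loss = model(x).pow(2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


    if __name__ == "__main__":
        train_loop_sketch()

Resetting the optimizer together with the orthogonalization is the natural pairing here: stale
moment estimates accumulated for the old weights would otherwise pull the freshly orthogonalized
parameters back toward their previous configuration.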