From 7c1b7b3d28b1704e40d143db63518e8f72bf2ff2 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Sat, 10 Dec 2022 14:03:07 +0900
Subject: [PATCH] from local

---
 .../.train.py.swp | Bin 102400 -> 102400 bytes
 .../train.py      |  36 +++++++++---------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index f60604e5296e037e30c3c8b539556f079e16ee74..a7e3822a39c1ddd820e8c900f9ec41916afd3653 100644
GIT binary patch
delta 523
[base85-encoded binary delta omitted]
delta 352
[base85-encoded binary delta omitted]

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index ee57ec55a..a7bea4a07 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -1064,24 +1064,24 @@ def train_one_epoch(
             wb.log({"train/loss": loss_info["loss"]*numel})
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+#if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        logging.info("Computing validation loss")
+        valid_info = compute_validation_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            valid_dl=valid_dl,
+            world_size=world_size,
+        )
+        model.train()
+        logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+        if tb_writer is not None:
+            valid_info.write_summary(
+                tb_writer, "train/valid_", params.batch_idx_train
+            )
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
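
The functional change in this hunk is that the valid_interval guard in train_one_epoch() is commented out and its body is dedented one level, so the validation pass runs on every batch instead of once every params.valid_interval batches. The minimal sketch below illustrates the gated behaviour the removed guard provided; it is an approximation, not the icefall code itself, and run_validation() is a hypothetical stand-in for compute_validation_loss() plus the TensorBoard/W&B summary writing.

# Minimal sketch of the interval-gated validation that this patch disables.
# Assumptions: `run_validation` and the simplified `params` object are
# placeholders, not the icefall API.
import logging
from types import SimpleNamespace

params = SimpleNamespace(valid_interval=3000, print_diagnostics=False)


def run_validation(batch_idx: int) -> None:
    # Stand-in for compute_validation_loss() + valid_info.write_summary().
    logging.info("Computing validation loss at batch %d", batch_idx)


def train_one_epoch(train_dl) -> None:
    for batch_idx, _batch in enumerate(train_dl):
        # ... forward/backward/optimizer step on _batch would go here ...

        # Before this patch: validate only every `valid_interval` batches.
        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            run_validation(batch_idx)
        # With the guard commented out, as in the + lines above, the
        # validation block would execute on every batch instead.


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    train_one_epoch(range(10))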