From c239417387a0dec3625741d17b2390a8b06f9398 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Thu, 2 Feb 2023 15:05:31 +0900 Subject: [PATCH] from local --- egs/aishell/ASR/conformer_ctc/.train.py.swp | Bin 45056 -> 45056 bytes egs/aishell/ASR/conformer_ctc/train.py | 29 ++++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/.train.py.swp b/egs/aishell/ASR/conformer_ctc/.train.py.swp index 9b6fa5db31b4c55500b519f2ce9991ae57c4d11f..3ea9807f7433d0c4c88e9cd0040545a6759a100f 100644 GIT binary patch delta 414 zcmb8qzbnLX9LMpubADW$?&u@~pZgjNaz6$R5yfx@2E!#OI;B`{bHfdH4DJ+z-iwsY zk%{YD76TbYVPH5vPFKpJj5bgC3%u)5?_RIr-8)lqWJ(TSJQN$BF#Um$FdmWARB@qP z@D)833H!^_lyFDMjpMIOs5etiztxI7B9A0yF^U#^*N8me8mCxB6c$R=BKNq%E|M67 z7cW&J&v?QaY#d<^^RV#h6uEKAob0pgAdVS~z=9uwk4ljf93zDohA;>pu6XAHhd4kQ z)9}MMwhL`RV~qTR{keMX)Xj}Xt&h0$cD+YkSxbZH>~m{8JFNL^vle^9+8t?E_ez&e dZiRF$7ocy@%bclOS~TH--;<`pe$Tab{s3E>Qdj^0 delta 371 zcmXBPy-PxI6vpwNw_6tWaz&I%62(dkNn4a!+5&?d8fpw`h#-uJq^T ziJ%vdyHwE7Ur=*G2u0LVI=1zV4t(J7oHLwLY-q)X=3h=NWaq+@<0(ZMN^Q>9S8i+m zy5Y(@EyK6b6D0q3{SWivExwvBol2dffJIEfgAcb-4{&gVETZ`7Aj2t6u!b4<&}dic z3YWM*1tsiY9Z9rwrLJ*|EnO|CRhkLJ(T7%>QcXPK1}1hfixG6ggEx)CaDW1~kwXkk zUi^FW%XL@_Ig!%6E_OK~nw5|-YgF8okQmhgF|BFI?FZz^(&c$CAfM@|#4}OpDkR0p b49QtK$bPSs?P2M!dgacJ=wZrvw+Catlm$^* diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index 0472d9d16..0e781f3e7 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -532,21 +532,22 @@ def train_one_epoch( ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, + #if batch_idx > 0 and batch_idx % params.valid_interval == 0: + if 1: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value