From 1e29fe07196ea325db60b3e3a92c1ab2a34f2b28 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Mon, 9 Jan 2023 20:26:14 +0900
Subject: [PATCH] from local

---
 .../incremental_transf/.identity_train.py.swp | Bin 65536 -> 73728 bytes
 .../ASR/incremental_transf/identity_train.py  | 33 +++++-------------
 2 files changed, 8 insertions(+), 25 deletions(-)

diff --git a/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp b/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp
index fd33ee62242c219f9054f10024e2aa2af9dd40ad..65b825032275da595035268d43164405afd2ad5b 100644
Binary files a/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp and b/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp differ

diff --git a/egs/librispeech/ASR/incremental_transf/identity_train.py b/egs/librispeech/ASR/incremental_transf/identity_train.py
index 66e59465c..45a5d8e08 100755
--- a/egs/librispeech/ASR/incremental_transf/identity_train.py
+++ b/egs/librispeech/ASR/incremental_transf/identity_train.py
@@ -665,6 +665,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
+        '''
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
@@ -678,36 +679,18 @@
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
-        is_finite = simple_loss_is_finite & pruned_loss_is_finite
-        if not torch.all(is_finite):
-            logging.info(
-                "Not all losses are finite!\n"
-                f"simple_loss: {simple_loss}\n"
-                f"pruned_loss: {pruned_loss}"
-            )
-            display_and_save_batch(batch, params=params, sp=sp)
-            simple_loss = simple_loss[simple_loss_is_finite]
-            pruned_loss = pruned_loss[pruned_loss_is_finite]
+        '''
+        mse_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+        )
 
-        # If the batch contains more than 10 utterances AND
-        # if either all simple_loss or pruned_loss is inf or nan,
-        # we stop the training process by raising an exception
-        if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
-            raise ValueError(
-                "There are too many utterances in this batch "
-                "leading to inf or nan losses."
-            )
-
-        simple_loss = simple_loss.sum()
-        pruned_loss = pruned_loss.sum()
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        # loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        loss = mse_loss
 
         assert loss.requires_grad == is_training
 
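
Note on the new call site: after this patch, compute_loss() expects
model(x=feature, x_lens=feature_lens) to return a single scalar loss that
carries gradients during training, instead of the (simple_loss, pruned_loss)
pair of the pruned-transducer recipe (the warmup argument is also dropped).
A minimal sketch of a module matching that contract, assuming an
identity-reconstruction MSE objective; the class name, layer sizes, and
internals are illustrative assumptions, not taken from this repository:

    import torch
    import torch.nn as nn

    class IdentityTransformer(nn.Module):
        # Hypothetical model whose forward() matches the patched call site:
        # padded features plus their lengths in, one scalar MSE loss out.
        def __init__(self, feature_dim: int = 80, d_model: int = 256):
            super().__init__()
            self.proj_in = nn.Linear(feature_dim, d_model)
            layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
            self.encoder = nn.TransformerEncoder(layer, num_layers=2)
            self.proj_out = nn.Linear(d_model, feature_dim)

        def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
            # x: (N, T, feature_dim) padded features; x_lens: (N,) valid frame counts.
            out = self.proj_out(self.encoder(self.proj_in(x)))
            # Mask padded frames so they do not contribute to the loss.
            mask = torch.arange(x.size(1), device=x.device)[None, :] < x_lens[:, None]
            sq_err = (out - x).pow(2) * mask.unsqueeze(-1)
            # Mean squared error over valid frames and feature bins only.
            return sq_err.sum() / (mask.sum() * x.size(-1))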