From ce8fb6fea04f7345aebbdaf0143f054ceaa879c3 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Sat, 10 Dec 2022 11:08:49 +0900 Subject: [PATCH] from local --- egs/librispeech/ASR/.run_v2.sh.swp | Bin 12288 -> 12288 bytes .../.train.py.swp | Bin 61440 -> 61440 bytes .../train.py | 64 +++++++++--------- 3 files changed, 33 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/.run_v2.sh.swp b/egs/librispeech/ASR/.run_v2.sh.swp index c1ec03e7c03c47ae111b90faa71ada8d648cd9a9..17cc2732262b1eac16d197a491e91ac662f157dc 100644 GIT binary patch delta 32 mcmZojXh;xGG6?hZRj|}EU;qLE28Qb|CntxUxx7*Qr9J?c-3q7x delta 32 mcmZojXh;xGG6?hZRj|}EU;qLE28NhVlakNwxU^CHr9J?asS1|> diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp index 468ac8578e17bcb0f44f92b9c2c94e04a1037aa9..b5713ff2bdb9acdebd51653e38523708c4c18b8b 100644 GIT binary patch delta 658 zcmYk)Pe{{Y7{Kvoe_XD)KaSXJDQp-uC2CYESd5@4*k);T==VoZfLLp3ZzqFa4`(V+)E@AL4!@8RWn-s+04y5f8ABIZfe z>ajh$oHegV#V;Z;5u4=3r>MwB%wYoE@ZwcOL@|v%#NmezAHpKDn8H!y@Uumvfv@)C=< zgDTG8G!A39*%4`{Sl=M>29I$MS22xKIDj@JA$S)Md5ha9qkuLf;e!_*G@3-}xQ{EC z!~jzGM#o>{ZkWYw{{+VU59{+@T`uhOq_r{HUTchYHq~|ea+p zM59JlGo>CKwBmZ!*sHgU4qY_5^}W%h>!t15WoG`K9W#6NiM3hh%%IMjoAi;H)qok- zFJ@l1THAEr#e^QWGMb+9>jkTyyHQ=TLga($x3fA^%xcN*Q^(Hhf<368k8BASncs7` u|0boaBPW9HwAb9-T4kiCCEz}G@}#yJ$^Xh=KAr0q_egIRJ8JQ<)!9Fau8cPT delta 534 zcmYk(!E4fC9LMp`uiwvfrkfL*9$i8b40UZ4co~rhv0xUZ#%hP$Zpv0P*oe|UKs@wg zhde6Lp~D70I&`T)AyFO5t?TeO*UAn%L|u9d9rlIS^Zh;#JP&-oPqk}SyXMxHfDv}S z2Bs2~j45?2Ndf5&X%M{;X&*08Mhav2IV^2r1M^5Bi3EaZ`=uu8SV0QYm_i7@e9|X8 zM-2s^D!M^2j-yl3A>QH*ns|a7ESUJ|Y$J1^6H%ZS?7XD>Q4`L^MS5V2THc5Qajxpe!H>3*=Z7`eADyET`w zf6Rsab!O*fxPuQ2+oRdy&|vq6)7Q@9>;>QP3&CjQ%5r&S_2E*a`2S^R^lty@Kg6wG ADF6Tf diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py index ffc257b56..7d62876a8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py @@ -719,41 +719,44 @@ def compute_loss( loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - ''' - # Compute ctc loss + info = MetricsTracker() + + if params.ctc_loss_scale > 0: + # Compute ctc loss - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - supervision_segments, token_ids = encode_supervisions( - supervisions, - subsampling_factor=params.subsampling_factor, - token_ids=token_ids, + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + # Works with a BPE model + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, ) - # Works with a BPE model - decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) - dense_fsa_vec = k2.DenseFsaVec( - ctc_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction="sum", - use_double_scores=params.use_double_scores, - ) - assert ctc_loss.requires_grad == is_training - loss += params.ctc_loss_scale * ctc_loss - ''' + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + loss += params.ctc_loss_scale * ctc_loss + + info["ctc_loss"] = ctc_loss.detach().cpu().item() + assert loss.requires_grad == is_training - info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") info["frames"] = (feature_lens // params.subsampling_factor).sum().item() @@ -762,7 +765,6 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - #info["ctc_loss"] = ctc_loss.detach().cpu().item() return loss, info