From 23e9d7edf0949f161b20fdaf37cfc7a6d659a0ba Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Fri, 9 Dec 2022 17:23:11 +0900 Subject: [PATCH] from local --- .../.train.py.swp | Bin 57344 -> 61440 bytes .../train.py | 10 +++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp index 7676e6c3dcfea508ccf8e0c619169e9df0b58165..0d95a93be4ce750d5069223b97cbd6d323094f14 100644 GIT binary patch delta 668 zcmajdJxjwt7zgl+so>DscCfCuwM7g?2eDAF*hNIeMa5DirdN%&Nw}na=AQrPws8~Qe3|m1GABdp%L0IMF@(I zkg3-LrR%H`c=igOlISZ7N~z-O^tac-C>`~RoY?2Sptnk0)Kd@$=@#)FqZQjfpbN9d zE<$!;8eDLW@D3DV0czc=unnWo2LWh>+ZIA>*oQ6HgcJne0l%V7vG3x;f~)pgDkNfS zEURW{hRIf*Gf7LUIU0-6-e|KoZN_G@OfzYWS~^$eB?}FmhZ^IYSeQ8J8xOXuFvE)e z8ZF6&!L)S9%o~g^sivxH(i&%q&eN(kCu#ZIj49`gEHi)SviQpXi7YNFi)J#Cl9aq9 z{gXALb6TIpxHAo^YK^wjG?PIj38tdoOZvG`Y(g^SME=;ho~V717(85%(}X;^aVb3Y z<87Y)rVDdBJ8uhkVAURu%?^79b7>VlyB%1!8U>zQfGGunvgV0&zVM zdjhc$5DNk^0}x+f0?Gq%2N2f*aV-!>0I?ttgRJ<-xH;M1& diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py index 376fa999d..40a3f3183 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py @@ -676,7 +676,15 @@ def compute_loss( feature = feature.to(device) supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) + if feature.ndim == 2: + feature_lens = [] + for supervision in supervisions['cut']: + try: feature_lens.append(supervision.tracks[0].cut.recording.num_samples) + except: feature_lens.append(supervision.recording.num_samples) + feature_lens = torch.tensor(feature_lens) + + elif feature.ndim == 3: + feature_lens = supervisions["num_frames"].to(device) batch_idx_train = params.batch_idx_train warm_step = params.warm_step