diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp index 7676e6c3d..0d95a93be 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py index 376fa999d..40a3f3183 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py @@ -676,7 +676,15 @@ def compute_loss( feature = feature.to(device) supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) + if feature.ndim == 2: + feature_lens = [] + for supervision in supervisions['cut']: + try: feature_lens.append(supervision.tracks[0].cut.recording.num_samples) + except: feature_lens.append(supervision.recording.num_samples) + feature_lens = torch.tensor(feature_lens) + + elif feature.ndim == 3: + feature_lens = supervisions["num_frames"].to(device) batch_idx_train = params.batch_idx_train warm_step = params.warm_step