diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py index 66ef620e9..936baf658 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py @@ -632,7 +632,7 @@ class Cain(Optimizer): rev_dims_order.append(ndim-1) for i in range(dim+1, ndim): dims_order.append(i) - rev_dims_order.append(i) + rev_dims_order.append(i-1) dims_order.append(dim) # e.g. ndim=4, dim=1, dims_order=(0,2,3,1), rev_dims_order=(0,3,1,2) new_grad = grad.permute(*dims_order) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim2.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim2.py index 5cc56ebe9..ca75875a4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim2.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim2.py @@ -628,7 +628,7 @@ class Eve(Optimizer): rev_dims_order.append(ndim-1) for i in range(dim+1, ndim): dims_order.append(i) - rev_dims_order.append(i) + rev_dims_order.append(i-1) dims_order.append(dim) # e.g. ndim=4, dim=1, dims_order=(0,2,3,1), rev_dims_order=(0,3,1,2) new_grad = grad.permute(*dims_order)