diff --git a/kaldifeat/python/kaldifeat/offline_feature.py b/kaldifeat/python/kaldifeat/offline_feature.py index 0d2a8c3..dc07ee0 100644 --- a/kaldifeat/python/kaldifeat/offline_feature.py +++ b/kaldifeat/python/kaldifeat/offline_feature.py @@ -69,7 +69,7 @@ class OfflineFeature(nn.Module): for w in waves ] - strided = [self.convert_samples_to_frames(w) for w in waves] + strided = [self.convert_samples_to_frames(w.to(self.opts.device)) for w in waves] strided = torch.cat(strided, dim=0) features = self.compute(strided, vtln_warp)