diff --git a/kaldifeat/python/kaldifeat/offline_feature.py b/kaldifeat/python/kaldifeat/offline_feature.py index 68948a0..d12550c 100644 --- a/kaldifeat/python/kaldifeat/offline_feature.py +++ b/kaldifeat/python/kaldifeat/offline_feature.py @@ -42,6 +42,8 @@ class OfflineFeature(nn.Module): Kaldi, you should scale the samples to [-32768, 32767] before calling this function. Note: You are not required to scale them if you don't care about the compatibility with Kaldi. + **Note:** It does not have to be on the same device as + `self.opts.device`. vtln_warp The VTLN warping factor that the user wants to be applied when computing features for this utterance. Will normally be 1.0, @@ -57,6 +59,8 @@ class OfflineFeature(nn.Module): input is a list of 1-D tensors. The returned list has as many elements as the input list. Return a single 2-D tensor if the input is a single tensor. + Note: The returned `features` is on the same device as the input + waves. """ if isinstance(waves, list): is_list = True @@ -105,10 +109,15 @@ class OfflineFeature(nn.Module): Return a 2-D tensor with as many rows as the input tensor. Its number of columns is the number mel bins. """ + x_device = x.device + self_device = self.opts.device assert x.ndim == 2 assert x.dtype == torch.float32 if chunk_size is None: - features = self.computer.compute_features(x, vtln_warp) + features = self.computer.compute_features( + x.to(self_device), vtln_warp + ) + features = features.to(x_device) else: assert chunk_size > 0 num_chunks = x.size(0) // chunk_size @@ -118,12 +127,14 @@ class OfflineFeature(nn.Module): start = i * chunk_size end = start + chunk_size this_chunk = self.computer.compute_features( - x[start:end], vtln_warp + x[start:end].to(self_device), vtln_warp ) - features.append(this_chunk) + features.append(this_chunk.to(x_device)) if end < x.size(0): - last_chunk = self.computer.compute_features(x[end:], vtln_warp) - features.append(last_chunk) + last_chunk = self.computer.compute_features( + x[end:].to(self_device), vtln_warp + ) + features.append(last_chunk.to(x_device)) features = torch.cat(features, dim=0) return features diff --git a/kaldifeat/python/tests/test_fbank.py b/kaldifeat/python/tests/test_fbank.py index 9e6f8a3..57092b6 100755 --- a/kaldifeat/python/tests/test_fbank.py +++ b/kaldifeat/python/tests/test_fbank.py @@ -22,10 +22,16 @@ def test_fbank_default(): opts.frame_opts.dither = 0 fbank = kaldifeat.Fbank(opts) filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) + wave = read_wave(filename) features = fbank(wave) + assert features.device.type == "cpu" gt = read_ark_txt(cur_dir / "test_data/test.txt") + assert torch.allclose(features, gt, rtol=1e-1) + + wave = wave.to(device) + features = fbank(wave) + assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) @@ -41,10 +47,16 @@ def test_fbank_htk(): fbank = kaldifeat.Fbank(opts) filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) + wave = read_wave(filename) features = fbank(wave) + assert features.device.type == "cpu" gt = read_ark_txt(cur_dir / "test_data/test-htk.txt") + assert torch.allclose(features, gt, rtol=1e-1) + + wave = wave.to(device) + features = fbank(wave) + assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) @@ -59,10 +71,16 @@ def test_fbank_with_energy(): fbank = kaldifeat.Fbank(opts) filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) + wave = read_wave(filename) features = fbank(wave) gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt") + assert torch.allclose(features, gt, rtol=1e-1) + assert features.device.type == "cpu" + + wave = wave.to(device) + features = fbank(wave) + assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) @@ -77,10 +95,16 @@ def test_fbank_40_bins(): fbank = kaldifeat.Fbank(opts) filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) + wave = read_wave(filename) features = fbank(wave) + assert features.device.type == "cpu" gt = read_ark_txt(cur_dir / "test_data/test-40.txt") + assert torch.allclose(features, gt, rtol=1e-1) + + wave = wave.to(device) + features = fbank(wave) + assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) @@ -96,10 +120,16 @@ def test_fbank_40_bins_no_snip_edges(): fbank = kaldifeat.Fbank(opts) filename = cur_dir / "test_data/test.wav" - wave = read_wave(filename).to(device) + wave = read_wave(filename) features = fbank(wave) + assert features.device.type == "cpu" gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt") + assert torch.allclose(features, gt, rtol=1e-1) + + wave = wave.to(device) + features = fbank(wave) + assert features.device == device assert torch.allclose(features.cpu(), gt, rtol=1e-1) @@ -123,7 +153,7 @@ def test_fbank_chunk(): opts.frame_opts.snip_edges = False fbank = kaldifeat.Fbank(opts) - wave = read_wave(filename).to(device) + wave = read_wave(filename) # You can use # @@ -132,16 +162,27 @@ def test_fbank_chunk(): # to view memory consumption # # 100 frames per chunk - features = fbank(wave, chunk_size=100 * 10) - print(features.shape) + features1 = fbank(wave, chunk_size=100 * 10) + features2 = fbank(wave) + assert torch.allclose(features1, features2) + assert features1.device == features2.device + assert features1.device.type == "cpu" + + if device.type == "cuda": + wave = wave.to(device) + features1 = fbank(wave, chunk_size=100 * 10) + features2 = fbank(wave) + assert torch.allclose(features1, features2) + assert features1.device == features2.device + assert features1.device == device def test_fbank_batch(): - print("=====test_fbank_chunk=====") + print("=====test_fbank_batch=====") for device in get_devices(): print("device", device) - wave0 = read_wave(cur_dir / "test_data/test.wav").to(device) - wave1 = read_wave(cur_dir / "test_data/test2.wav").to(device) + wave0 = read_wave(cur_dir / "test_data/test.wav") + wave1 = read_wave(cur_dir / "test_data/test2.wav") opts = kaldifeat.FbankOptions() opts.device = device @@ -156,6 +197,18 @@ def test_fbank_batch(): assert torch.allclose(features[0], features0) assert torch.allclose(features[1], features1) + if device.type == "cuda": + wave0 = wave0.to(device) + wave1 = wave1.to(device) + + features = fbank([wave0, wave1], chunk_size=10) + + features0 = fbank(wave0) + features1 = fbank(wave1) + + assert torch.allclose(features[0], features0) + assert torch.allclose(features[1], features1) + def test_pickle(): for device in get_devices():