mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-09 18:12:17 +00:00
Only move necessary data to GPU for computation. (#23)
This commit is contained in:
parent
d2652a2c49
commit
d6274e7d41
@ -42,6 +42,8 @@ class OfflineFeature(nn.Module):
|
||||
Kaldi, you should scale the samples to [-32768, 32767] before
|
||||
calling this function. Note: You are not required to scale them if
|
||||
you don't care about the compatibility with Kaldi.
|
||||
**Note:** It does not have to be on the same device as
|
||||
`self.opts.device`.
|
||||
vtln_warp
|
||||
The VTLN warping factor that the user wants to be applied when
|
||||
computing features for this utterance. Will normally be 1.0,
|
||||
@ -57,6 +59,8 @@ class OfflineFeature(nn.Module):
|
||||
input is a list of 1-D tensors. The returned list has as many elements
|
||||
as the input list.
|
||||
Return a single 2-D tensor if the input is a single tensor.
|
||||
Note: The returned `features` is on the same device as the input
|
||||
waves.
|
||||
"""
|
||||
if isinstance(waves, list):
|
||||
is_list = True
|
||||
@ -105,10 +109,15 @@ class OfflineFeature(nn.Module):
|
||||
Return a 2-D tensor with as many rows as the input tensor. Its
|
||||
number of columns is the number mel bins.
|
||||
"""
|
||||
x_device = x.device
|
||||
self_device = self.opts.device
|
||||
assert x.ndim == 2
|
||||
assert x.dtype == torch.float32
|
||||
if chunk_size is None:
|
||||
features = self.computer.compute_features(x, vtln_warp)
|
||||
features = self.computer.compute_features(
|
||||
x.to(self_device), vtln_warp
|
||||
)
|
||||
features = features.to(x_device)
|
||||
else:
|
||||
assert chunk_size > 0
|
||||
num_chunks = x.size(0) // chunk_size
|
||||
@ -118,12 +127,14 @@ class OfflineFeature(nn.Module):
|
||||
start = i * chunk_size
|
||||
end = start + chunk_size
|
||||
this_chunk = self.computer.compute_features(
|
||||
x[start:end], vtln_warp
|
||||
x[start:end].to(self_device), vtln_warp
|
||||
)
|
||||
features.append(this_chunk)
|
||||
features.append(this_chunk.to(x_device))
|
||||
if end < x.size(0):
|
||||
last_chunk = self.computer.compute_features(x[end:], vtln_warp)
|
||||
features.append(last_chunk)
|
||||
last_chunk = self.computer.compute_features(
|
||||
x[end:].to(self_device), vtln_warp
|
||||
)
|
||||
features.append(last_chunk.to(x_device))
|
||||
features = torch.cat(features, dim=0)
|
||||
|
||||
return features
|
||||
|
@ -22,10 +22,16 @@ def test_fbank_default():
|
||||
opts.frame_opts.dither = 0
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
|
||||
@ -41,10 +47,16 @@ def test_fbank_htk():
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
|
||||
@ -59,10 +71,16 @@ def test_fbank_with_energy():
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
assert features.device.type == "cpu"
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
|
||||
@ -77,10 +95,16 @@ def test_fbank_40_bins():
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
|
||||
@ -96,10 +120,16 @@ def test_fbank_40_bins_no_snip_edges():
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
filename = cur_dir / "test_data/test.wav"
|
||||
wave = read_wave(filename).to(device)
|
||||
wave = read_wave(filename)
|
||||
|
||||
features = fbank(wave)
|
||||
assert features.device.type == "cpu"
|
||||
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
||||
assert torch.allclose(features, gt, rtol=1e-1)
|
||||
|
||||
wave = wave.to(device)
|
||||
features = fbank(wave)
|
||||
assert features.device == device
|
||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||
|
||||
|
||||
@ -123,7 +153,7 @@ def test_fbank_chunk():
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
wave = read_wave(filename).to(device)
|
||||
wave = read_wave(filename)
|
||||
|
||||
# You can use
|
||||
#
|
||||
@ -132,16 +162,27 @@ def test_fbank_chunk():
|
||||
# to view memory consumption
|
||||
#
|
||||
# 100 frames per chunk
|
||||
features = fbank(wave, chunk_size=100 * 10)
|
||||
print(features.shape)
|
||||
features1 = fbank(wave, chunk_size=100 * 10)
|
||||
features2 = fbank(wave)
|
||||
assert torch.allclose(features1, features2)
|
||||
assert features1.device == features2.device
|
||||
assert features1.device.type == "cpu"
|
||||
|
||||
if device.type == "cuda":
|
||||
wave = wave.to(device)
|
||||
features1 = fbank(wave, chunk_size=100 * 10)
|
||||
features2 = fbank(wave)
|
||||
assert torch.allclose(features1, features2)
|
||||
assert features1.device == features2.device
|
||||
assert features1.device == device
|
||||
|
||||
|
||||
def test_fbank_batch():
|
||||
print("=====test_fbank_chunk=====")
|
||||
print("=====test_fbank_batch=====")
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
wave0 = read_wave(cur_dir / "test_data/test.wav").to(device)
|
||||
wave1 = read_wave(cur_dir / "test_data/test2.wav").to(device)
|
||||
wave0 = read_wave(cur_dir / "test_data/test.wav")
|
||||
wave1 = read_wave(cur_dir / "test_data/test2.wav")
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
@ -156,6 +197,18 @@ def test_fbank_batch():
|
||||
assert torch.allclose(features[0], features0)
|
||||
assert torch.allclose(features[1], features1)
|
||||
|
||||
if device.type == "cuda":
|
||||
wave0 = wave0.to(device)
|
||||
wave1 = wave1.to(device)
|
||||
|
||||
features = fbank([wave0, wave1], chunk_size=10)
|
||||
|
||||
features0 = fbank(wave0)
|
||||
features1 = fbank(wave1)
|
||||
|
||||
assert torch.allclose(features[0], features0)
|
||||
assert torch.allclose(features[1], features1)
|
||||
|
||||
|
||||
def test_pickle():
|
||||
for device in get_devices():
|
||||
|
Loading…
x
Reference in New Issue
Block a user