Only move necessary data to GPU for computation. (#23)

Fangjun Kuang 2021-12-01 20:02:20 +08:00 committed by GitHub
parent d2652a2c49
commit d6274e7d41
2 changed files with 80 additions and 16 deletions
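In practice this change means the input waveform no longer has to live on `self.opts.device`: only the data needed for the computation is moved there, and the returned features come back on the same device as the input. A minimal usage sketch of that behavior (the synthetic waveform and the CUDA device index are placeholders, not part of this commit):

```python
import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.device = torch.device("cuda", 0)  # placeholder; features are computed on this device

fbank = kaldifeat.Fbank(opts)

# The input wave may stay on the CPU; it does not have to be on opts.device.
wave = torch.randn(16000, dtype=torch.float32)  # placeholder for 1 second of 16 kHz audio

features = fbank(wave)

# The returned features are on the same device as the input wave.
assert features.device.type == "cpu"
```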


@@ -42,6 +42,8 @@ class OfflineFeature(nn.Module):
             Kaldi, you should scale the samples to [-32768, 32767] before
             calling this function. Note: You are not required to scale them if
             you don't care about the compatibility with Kaldi.
+            **Note:** It does not have to be on the same device as
+            `self.opts.device`.
           vtln_warp
             The VTLN warping factor that the user wants to be applied when
             computing features for this utterance. Will normally be 1.0,
@@ -57,6 +59,8 @@ class OfflineFeature(nn.Module):
           input is a list of 1-D tensors. The returned list has as many elements
           as the input list.
           Return a single 2-D tensor if the input is a single tensor.
+          Note: The returned `features` is on the same device as the input
+          waves.
         """
         if isinstance(waves, list):
             is_list = True
@@ -105,10 +109,15 @@ class OfflineFeature(nn.Module):
           Return a 2-D tensor with as many rows as the input tensor. Its
           number of columns is the number mel bins.
         """
+        x_device = x.device
+        self_device = self.opts.device
         assert x.ndim == 2
         assert x.dtype == torch.float32
         if chunk_size is None:
-            features = self.computer.compute_features(x, vtln_warp)
+            features = self.computer.compute_features(
+                x.to(self_device), vtln_warp
+            )
+            features = features.to(x_device)
         else:
             assert chunk_size > 0
             num_chunks = x.size(0) // chunk_size
@@ -118,12 +127,14 @@ class OfflineFeature(nn.Module):
                 start = i * chunk_size
                 end = start + chunk_size
                 this_chunk = self.computer.compute_features(
-                    x[start:end], vtln_warp
+                    x[start:end].to(self_device), vtln_warp
                 )
-                features.append(this_chunk)
+                features.append(this_chunk.to(x_device))
             if end < x.size(0):
-                last_chunk = self.computer.compute_features(x[end:], vtln_warp)
-                features.append(last_chunk)
+                last_chunk = self.computer.compute_features(
+                    x[end:].to(self_device), vtln_warp
+                )
+                features.append(last_chunk.to(x_device))
             features = torch.cat(features, dim=0)
         return features
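The chunked branch above follows a common PyTorch pattern: move only the slice currently being processed to the compute device, and move each partial result back to the caller's device before concatenating. A generic, hedged sketch of that pattern (the function and argument names here are illustrative, not part of kaldifeat):

```python
import torch
from typing import Callable


def compute_in_chunks(
    x: torch.Tensor,
    chunk_size: int,
    device: torch.device,
    fn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Apply ``fn`` on ``device`` chunk by chunk; return results on x's device."""
    outputs = []
    for start in range(0, x.size(0), chunk_size):
        chunk = x[start : start + chunk_size].to(device)  # move only this chunk
        outputs.append(fn(chunk).to(x.device))  # bring the result back
    return torch.cat(outputs, dim=0)
```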


@@ -22,10 +22,16 @@ def test_fbank_default():
         opts.frame_opts.dither = 0
         fbank = kaldifeat.Fbank(opts)
         filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename).to(device)
+        wave = read_wave(filename)
         features = fbank(wave)
+        assert features.device.type == "cpu"
         gt = read_ark_txt(cur_dir / "test_data/test.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        wave = wave.to(device)
+        features = fbank(wave)
+        assert features.device == device
+        assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@@ -41,10 +47,16 @@ def test_fbank_htk():
         fbank = kaldifeat.Fbank(opts)
         filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename).to(device)
+        wave = read_wave(filename)
         features = fbank(wave)
+        assert features.device.type == "cpu"
         gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        wave = wave.to(device)
+        features = fbank(wave)
+        assert features.device == device
+        assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@@ -59,10 +71,16 @@ def test_fbank_with_energy():
         fbank = kaldifeat.Fbank(opts)
         filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename).to(device)
+        wave = read_wave(filename)
         features = fbank(wave)
         gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        assert features.device.type == "cpu"
+        wave = wave.to(device)
+        features = fbank(wave)
+        assert features.device == device
+        assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@@ -77,10 +95,16 @@ def test_fbank_40_bins():
         fbank = kaldifeat.Fbank(opts)
         filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename).to(device)
+        wave = read_wave(filename)
         features = fbank(wave)
+        assert features.device.type == "cpu"
         gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        wave = wave.to(device)
+        features = fbank(wave)
+        assert features.device == device
+        assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@@ -96,10 +120,16 @@ def test_fbank_40_bins_no_snip_edges():
         fbank = kaldifeat.Fbank(opts)
         filename = cur_dir / "test_data/test.wav"
-        wave = read_wave(filename).to(device)
+        wave = read_wave(filename)
         features = fbank(wave)
+        assert features.device.type == "cpu"
         gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
         assert torch.allclose(features, gt, rtol=1e-1)
+        wave = wave.to(device)
+        features = fbank(wave)
+        assert features.device == device
+        assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@@ -123,7 +153,7 @@ def test_fbank_chunk():
         opts.frame_opts.snip_edges = False
         fbank = kaldifeat.Fbank(opts)
-        wave = read_wave(filename).to(device)
+        wave = read_wave(filename)
         # You can use
         #
@@ -132,16 +162,27 @@ def test_fbank_chunk():
         # to view memory consumption
         #
         # 100 frames per chunk
-        features = fbank(wave, chunk_size=100 * 10)
-        print(features.shape)
+        features1 = fbank(wave, chunk_size=100 * 10)
+        features2 = fbank(wave)
+        assert torch.allclose(features1, features2)
+        assert features1.device == features2.device
+        assert features1.device.type == "cpu"
+        if device.type == "cuda":
+            wave = wave.to(device)
+            features1 = fbank(wave, chunk_size=100 * 10)
+            features2 = fbank(wave)
+            assert torch.allclose(features1, features2)
+            assert features1.device == features2.device
+            assert features1.device == device


 def test_fbank_batch():
-    print("=====test_fbank_chunk=====")
+    print("=====test_fbank_batch=====")
     for device in get_devices():
         print("device", device)
-        wave0 = read_wave(cur_dir / "test_data/test.wav").to(device)
-        wave1 = read_wave(cur_dir / "test_data/test2.wav").to(device)
+        wave0 = read_wave(cur_dir / "test_data/test.wav")
+        wave1 = read_wave(cur_dir / "test_data/test2.wav")
         opts = kaldifeat.FbankOptions()
         opts.device = device
@@ -156,6 +197,18 @@ def test_fbank_batch():
         assert torch.allclose(features[0], features0)
         assert torch.allclose(features[1], features1)
+        if device.type == "cuda":
+            wave0 = wave0.to(device)
+            wave1 = wave1.to(device)
+            features = fbank([wave0, wave1], chunk_size=10)
+            features0 = fbank(wave0)
+            features1 = fbank(wave1)
+            assert torch.allclose(features[0], features0)
+            assert torch.allclose(features[1], features1)


 def test_pickle():
     for device in get_devices():
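The tests above iterate over `get_devices()`, a helper from the test utilities that is not shown in this diff. A plausible reconstruction (an assumption, not the actual helper): it yields the CPU, plus the first CUDA device when one is available.

```python
import torch


def get_devices():
    # Hypothetical sketch of the helper used by these tests.
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda", 0))
    return devices
```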