Only move necessary data to GPU for computation. (#23)

This commit is contained in:
Fangjun Kuang 2021-12-01 20:02:20 +08:00 committed by GitHub
parent d2652a2c49
commit d6274e7d41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 16 deletions

View File

@ -42,6 +42,8 @@ class OfflineFeature(nn.Module):
Kaldi, you should scale the samples to [-32768, 32767] before Kaldi, you should scale the samples to [-32768, 32767] before
calling this function. Note: You are not required to scale them if calling this function. Note: You are not required to scale them if
you don't care about the compatibility with Kaldi. you don't care about the compatibility with Kaldi.
**Note:** It does not have to be on the same device as
`self.opts.device`.
vtln_warp vtln_warp
The VTLN warping factor that the user wants to be applied when The VTLN warping factor that the user wants to be applied when
computing features for this utterance. Will normally be 1.0, computing features for this utterance. Will normally be 1.0,
@ -57,6 +59,8 @@ class OfflineFeature(nn.Module):
input is a list of 1-D tensors. The returned list has as many elements input is a list of 1-D tensors. The returned list has as many elements
as the input list. as the input list.
Return a single 2-D tensor if the input is a single tensor. Return a single 2-D tensor if the input is a single tensor.
Note: The returned `features` is on the same device as the input
waves.
""" """
if isinstance(waves, list): if isinstance(waves, list):
is_list = True is_list = True
@ -105,10 +109,15 @@ class OfflineFeature(nn.Module):
Return a 2-D tensor with as many rows as the input tensor. Its Return a 2-D tensor with as many rows as the input tensor. Its
number of columns is the number of mel bins. number of columns is the number of mel bins.
""" """
x_device = x.device
self_device = self.opts.device
assert x.ndim == 2 assert x.ndim == 2
assert x.dtype == torch.float32 assert x.dtype == torch.float32
if chunk_size is None: if chunk_size is None:
features = self.computer.compute_features(x, vtln_warp) features = self.computer.compute_features(
x.to(self_device), vtln_warp
)
features = features.to(x_device)
else: else:
assert chunk_size > 0 assert chunk_size > 0
num_chunks = x.size(0) // chunk_size num_chunks = x.size(0) // chunk_size
@ -118,12 +127,14 @@ class OfflineFeature(nn.Module):
start = i * chunk_size start = i * chunk_size
end = start + chunk_size end = start + chunk_size
this_chunk = self.computer.compute_features( this_chunk = self.computer.compute_features(
x[start:end], vtln_warp x[start:end].to(self_device), vtln_warp
) )
features.append(this_chunk) features.append(this_chunk.to(x_device))
if end < x.size(0): if end < x.size(0):
last_chunk = self.computer.compute_features(x[end:], vtln_warp) last_chunk = self.computer.compute_features(
features.append(last_chunk) x[end:].to(self_device), vtln_warp
)
features.append(last_chunk.to(x_device))
features = torch.cat(features, dim=0) features = torch.cat(features, dim=0)
return features return features

View File

@ -22,10 +22,16 @@ def test_fbank_default():
opts.frame_opts.dither = 0 opts.frame_opts.dither = 0
fbank = kaldifeat.Fbank(opts) fbank = kaldifeat.Fbank(opts)
filename = cur_dir / "test_data/test.wav" filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device) wave = read_wave(filename)
features = fbank(wave) features = fbank(wave)
assert features.device.type == "cpu"
gt = read_ark_txt(cur_dir / "test_data/test.txt") gt = read_ark_txt(cur_dir / "test_data/test.txt")
assert torch.allclose(features, gt, rtol=1e-1)
wave = wave.to(device)
features = fbank(wave)
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1) assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@ -41,10 +47,16 @@ def test_fbank_htk():
fbank = kaldifeat.Fbank(opts) fbank = kaldifeat.Fbank(opts)
filename = cur_dir / "test_data/test.wav" filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device) wave = read_wave(filename)
features = fbank(wave) features = fbank(wave)
assert features.device.type == "cpu"
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt") gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
assert torch.allclose(features, gt, rtol=1e-1)
wave = wave.to(device)
features = fbank(wave)
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1) assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@ -59,10 +71,16 @@ def test_fbank_with_energy():
fbank = kaldifeat.Fbank(opts) fbank = kaldifeat.Fbank(opts)
filename = cur_dir / "test_data/test.wav" filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device) wave = read_wave(filename)
features = fbank(wave) features = fbank(wave)
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt") gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
assert torch.allclose(features, gt, rtol=1e-1)
assert features.device.type == "cpu"
wave = wave.to(device)
features = fbank(wave)
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1) assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@ -77,10 +95,16 @@ def test_fbank_40_bins():
fbank = kaldifeat.Fbank(opts) fbank = kaldifeat.Fbank(opts)
filename = cur_dir / "test_data/test.wav" filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device) wave = read_wave(filename)
features = fbank(wave) features = fbank(wave)
assert features.device.type == "cpu"
gt = read_ark_txt(cur_dir / "test_data/test-40.txt") gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
assert torch.allclose(features, gt, rtol=1e-1)
wave = wave.to(device)
features = fbank(wave)
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1) assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@ -96,10 +120,16 @@ def test_fbank_40_bins_no_snip_edges():
fbank = kaldifeat.Fbank(opts) fbank = kaldifeat.Fbank(opts)
filename = cur_dir / "test_data/test.wav" filename = cur_dir / "test_data/test.wav"
wave = read_wave(filename).to(device) wave = read_wave(filename)
features = fbank(wave) features = fbank(wave)
assert features.device.type == "cpu"
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt") gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
assert torch.allclose(features, gt, rtol=1e-1)
wave = wave.to(device)
features = fbank(wave)
assert features.device == device
assert torch.allclose(features.cpu(), gt, rtol=1e-1) assert torch.allclose(features.cpu(), gt, rtol=1e-1)
@ -123,7 +153,7 @@ def test_fbank_chunk():
opts.frame_opts.snip_edges = False opts.frame_opts.snip_edges = False
fbank = kaldifeat.Fbank(opts) fbank = kaldifeat.Fbank(opts)
wave = read_wave(filename).to(device) wave = read_wave(filename)
# You can use # You can use
# #
@ -132,16 +162,27 @@ def test_fbank_chunk():
# to view memory consumption # to view memory consumption
# #
# 100 frames per chunk # 100 frames per chunk
features = fbank(wave, chunk_size=100 * 10) features1 = fbank(wave, chunk_size=100 * 10)
print(features.shape) features2 = fbank(wave)
assert torch.allclose(features1, features2)
assert features1.device == features2.device
assert features1.device.type == "cpu"
if device.type == "cuda":
wave = wave.to(device)
features1 = fbank(wave, chunk_size=100 * 10)
features2 = fbank(wave)
assert torch.allclose(features1, features2)
assert features1.device == features2.device
assert features1.device == device
def test_fbank_batch(): def test_fbank_batch():
print("=====test_fbank_chunk=====") print("=====test_fbank_batch=====")
for device in get_devices(): for device in get_devices():
print("device", device) print("device", device)
wave0 = read_wave(cur_dir / "test_data/test.wav").to(device) wave0 = read_wave(cur_dir / "test_data/test.wav")
wave1 = read_wave(cur_dir / "test_data/test2.wav").to(device) wave1 = read_wave(cur_dir / "test_data/test2.wav")
opts = kaldifeat.FbankOptions() opts = kaldifeat.FbankOptions()
opts.device = device opts.device = device
@ -156,6 +197,18 @@ def test_fbank_batch():
assert torch.allclose(features[0], features0) assert torch.allclose(features[0], features0)
assert torch.allclose(features[1], features1) assert torch.allclose(features[1], features1)
if device.type == "cuda":
wave0 = wave0.to(device)
wave1 = wave1.to(device)
features = fbank([wave0, wave1], chunk_size=10)
features0 = fbank(wave0)
features1 = fbank(wave1)
assert torch.allclose(features[0], features0)
assert torch.allclose(features[1], features1)
def test_pickle(): def test_pickle():
for device in get_devices(): for device in get_devices():