mirror of
https://github.com/csukuangfj/kaldifeat.git
synced 2025-08-10 18:42:17 +00:00
Only move necessary data to GPU for computation. (#23)
This commit is contained in:
parent
d2652a2c49
commit
d6274e7d41
@ -42,6 +42,8 @@ class OfflineFeature(nn.Module):
|
|||||||
Kaldi, you should scale the samples to [-32768, 32767] before
|
Kaldi, you should scale the samples to [-32768, 32767] before
|
||||||
calling this function. Note: You are not required to scale them if
|
calling this function. Note: You are not required to scale them if
|
||||||
you don't care about the compatibility with Kaldi.
|
you don't care about the compatibility with Kaldi.
|
||||||
|
**Note:** It does not have to be on the same device as
|
||||||
|
`self.opts.device`.
|
||||||
vtln_warp
|
vtln_warp
|
||||||
The VTLN warping factor that the user wants to be applied when
|
The VTLN warping factor that the user wants to be applied when
|
||||||
computing features for this utterance. Will normally be 1.0,
|
computing features for this utterance. Will normally be 1.0,
|
||||||
@ -57,6 +59,8 @@ class OfflineFeature(nn.Module):
|
|||||||
input is a list of 1-D tensors. The returned list has as many elements
|
input is a list of 1-D tensors. The returned list has as many elements
|
||||||
as the input list.
|
as the input list.
|
||||||
Return a single 2-D tensor if the input is a single tensor.
|
Return a single 2-D tensor if the input is a single tensor.
|
||||||
|
Note: The returned `features` is on the same device as the input
|
||||||
|
waves.
|
||||||
"""
|
"""
|
||||||
if isinstance(waves, list):
|
if isinstance(waves, list):
|
||||||
is_list = True
|
is_list = True
|
||||||
@ -105,10 +109,15 @@ class OfflineFeature(nn.Module):
|
|||||||
Return a 2-D tensor with as many rows as the input tensor. Its
|
Return a 2-D tensor with as many rows as the input tensor. Its
|
||||||
number of columns is the number of mel bins.
|
number of columns is the number of mel bins.
|
||||||
"""
|
"""
|
||||||
|
x_device = x.device
|
||||||
|
self_device = self.opts.device
|
||||||
assert x.ndim == 2
|
assert x.ndim == 2
|
||||||
assert x.dtype == torch.float32
|
assert x.dtype == torch.float32
|
||||||
if chunk_size is None:
|
if chunk_size is None:
|
||||||
features = self.computer.compute_features(x, vtln_warp)
|
features = self.computer.compute_features(
|
||||||
|
x.to(self_device), vtln_warp
|
||||||
|
)
|
||||||
|
features = features.to(x_device)
|
||||||
else:
|
else:
|
||||||
assert chunk_size > 0
|
assert chunk_size > 0
|
||||||
num_chunks = x.size(0) // chunk_size
|
num_chunks = x.size(0) // chunk_size
|
||||||
@ -118,12 +127,14 @@ class OfflineFeature(nn.Module):
|
|||||||
start = i * chunk_size
|
start = i * chunk_size
|
||||||
end = start + chunk_size
|
end = start + chunk_size
|
||||||
this_chunk = self.computer.compute_features(
|
this_chunk = self.computer.compute_features(
|
||||||
x[start:end], vtln_warp
|
x[start:end].to(self_device), vtln_warp
|
||||||
)
|
)
|
||||||
features.append(this_chunk)
|
features.append(this_chunk.to(x_device))
|
||||||
if end < x.size(0):
|
if end < x.size(0):
|
||||||
last_chunk = self.computer.compute_features(x[end:], vtln_warp)
|
last_chunk = self.computer.compute_features(
|
||||||
features.append(last_chunk)
|
x[end:].to(self_device), vtln_warp
|
||||||
|
)
|
||||||
|
features.append(last_chunk.to(x_device))
|
||||||
features = torch.cat(features, dim=0)
|
features = torch.cat(features, dim=0)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
@ -22,10 +22,16 @@ def test_fbank_default():
|
|||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
filename = cur_dir / "test_data/test.wav"
|
||||||
wave = read_wave(filename).to(device)
|
wave = read_wave(filename)
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
gt = read_ark_txt(cur_dir / "test_data/test.txt")
|
||||||
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
|
||||||
|
wave = wave.to(device)
|
||||||
|
features = fbank(wave)
|
||||||
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
@ -41,10 +47,16 @@ def test_fbank_htk():
|
|||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
filename = cur_dir / "test_data/test.wav"
|
||||||
wave = read_wave(filename).to(device)
|
wave = read_wave(filename)
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
gt = read_ark_txt(cur_dir / "test_data/test-htk.txt")
|
||||||
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
|
||||||
|
wave = wave.to(device)
|
||||||
|
features = fbank(wave)
|
||||||
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
@ -59,10 +71,16 @@ def test_fbank_with_energy():
|
|||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
filename = cur_dir / "test_data/test.wav"
|
||||||
wave = read_wave(filename).to(device)
|
wave = read_wave(filename)
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
gt = read_ark_txt(cur_dir / "test_data/test-with-energy.txt")
|
||||||
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
assert features.device.type == "cpu"
|
||||||
|
|
||||||
|
wave = wave.to(device)
|
||||||
|
features = fbank(wave)
|
||||||
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
@ -77,10 +95,16 @@ def test_fbank_40_bins():
|
|||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
filename = cur_dir / "test_data/test.wav"
|
||||||
wave = read_wave(filename).to(device)
|
wave = read_wave(filename)
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
gt = read_ark_txt(cur_dir / "test_data/test-40.txt")
|
||||||
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
|
||||||
|
wave = wave.to(device)
|
||||||
|
features = fbank(wave)
|
||||||
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
@ -96,10 +120,16 @@ def test_fbank_40_bins_no_snip_edges():
|
|||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
filename = cur_dir / "test_data/test.wav"
|
filename = cur_dir / "test_data/test.wav"
|
||||||
wave = read_wave(filename).to(device)
|
wave = read_wave(filename)
|
||||||
|
|
||||||
features = fbank(wave)
|
features = fbank(wave)
|
||||||
|
assert features.device.type == "cpu"
|
||||||
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
gt = read_ark_txt(cur_dir / "test_data/test-40-no-snip-edges.txt")
|
||||||
|
assert torch.allclose(features, gt, rtol=1e-1)
|
||||||
|
|
||||||
|
wave = wave.to(device)
|
||||||
|
features = fbank(wave)
|
||||||
|
assert features.device == device
|
||||||
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
assert torch.allclose(features.cpu(), gt, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
@ -123,7 +153,7 @@ def test_fbank_chunk():
|
|||||||
opts.frame_opts.snip_edges = False
|
opts.frame_opts.snip_edges = False
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
wave = read_wave(filename).to(device)
|
wave = read_wave(filename)
|
||||||
|
|
||||||
# You can use
|
# You can use
|
||||||
#
|
#
|
||||||
@ -132,16 +162,27 @@ def test_fbank_chunk():
|
|||||||
# to view memory consumption
|
# to view memory consumption
|
||||||
#
|
#
|
||||||
# 100 frames per chunk
|
# 100 frames per chunk
|
||||||
features = fbank(wave, chunk_size=100 * 10)
|
features1 = fbank(wave, chunk_size=100 * 10)
|
||||||
print(features.shape)
|
features2 = fbank(wave)
|
||||||
|
assert torch.allclose(features1, features2)
|
||||||
|
assert features1.device == features2.device
|
||||||
|
assert features1.device.type == "cpu"
|
||||||
|
|
||||||
|
if device.type == "cuda":
|
||||||
|
wave = wave.to(device)
|
||||||
|
features1 = fbank(wave, chunk_size=100 * 10)
|
||||||
|
features2 = fbank(wave)
|
||||||
|
assert torch.allclose(features1, features2)
|
||||||
|
assert features1.device == features2.device
|
||||||
|
assert features1.device == device
|
||||||
|
|
||||||
|
|
||||||
def test_fbank_batch():
|
def test_fbank_batch():
|
||||||
print("=====test_fbank_chunk=====")
|
print("=====test_fbank_batch=====")
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
wave0 = read_wave(cur_dir / "test_data/test.wav").to(device)
|
wave0 = read_wave(cur_dir / "test_data/test.wav")
|
||||||
wave1 = read_wave(cur_dir / "test_data/test2.wav").to(device)
|
wave1 = read_wave(cur_dir / "test_data/test2.wav")
|
||||||
|
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
opts.device = device
|
opts.device = device
|
||||||
@ -156,6 +197,18 @@ def test_fbank_batch():
|
|||||||
assert torch.allclose(features[0], features0)
|
assert torch.allclose(features[0], features0)
|
||||||
assert torch.allclose(features[1], features1)
|
assert torch.allclose(features[1], features1)
|
||||||
|
|
||||||
|
if device.type == "cuda":
|
||||||
|
wave0 = wave0.to(device)
|
||||||
|
wave1 = wave1.to(device)
|
||||||
|
|
||||||
|
features = fbank([wave0, wave1], chunk_size=10)
|
||||||
|
|
||||||
|
features0 = fbank(wave0)
|
||||||
|
features1 = fbank(wave1)
|
||||||
|
|
||||||
|
assert torch.allclose(features[0], features0)
|
||||||
|
assert torch.allclose(features[1], features1)
|
||||||
|
|
||||||
|
|
||||||
def test_pickle():
|
def test_pickle():
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user