mirror of https://github.com/k2-fsa/icefall.git

commit 01a003a675 (parent d752230287)

    black formatted
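Every hunk below is the same whitespace-only change: recent versions of black hug the power operator when both operands are simple (names, numeric literals, attribute access), so `2 ** bits` becomes `2**bits`. As a minimal sketch of checking this locally, using black's programmatic entry point (the sample strings are illustrative, not taken from the repo):

# Sketch: verify black's power-operator hugging on a tiny snippet.
# Assumes `pip install black`; format_str/Mode are black's public API,
# equivalent to running `black` on a file with default settings.
import black

src = "total_range = 2 ** total_range_bits\ndists = -(diffs ** 2).sum(dim=-1)\n"
print(black.format_str(src, mode=black.Mode()))
# Expected output (simple operands, so ** is hugged):
#   total_range = 2**total_range_bits
#   dists = -(diffs**2).sum(dim=-1)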
@@ -132,7 +132,7 @@ def test():
     for rep in range(4):
         length: int = torch.randint(10, 2_000, (1,)).item()
         bits: int = torch.randint(1, 16, (1,)).item()
-        tokens: List[int] = torch.randint(2 ** bits, (length,)).tolist()
+        tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
         rebuilt: List[int] = []
         buf = io.BytesIO()
         packer = BitPacker(bits, buf)
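The surrounding test is a round-trip check: random tokens that fit in `bits` bits are pushed through a `BitPacker` into an in-memory buffer and read back with the matching unpacker. A standalone sketch of the same idea, assuming the `BitPacker`/`BitUnpacker` classes defined in this module with their `push`/`flush`/`pull` methods (the import path is not shown in the diff, so it is left as a placeholder):

# Sketch of the round-trip exercised by the test above.
import io
from typing import List

import torch

# from <this module> import BitPacker, BitUnpacker  # defined in the file being diffed

def roundtrip(bits: int, length: int) -> None:
    tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
    buf = io.BytesIO()
    packer = BitPacker(bits, buf)       # writes `bits`-wide values
    for t in tokens:
        packer.push(t)
    packer.flush()                      # emit the last, possibly partial, byte
    buf.seek(0)
    unpacker = BitUnpacker(bits, buf)
    rebuilt: List[int] = []
    while (value := unpacker.pull()) is not None:
        rebuilt.append(value)
    # trailing byte padding can append extra zero tokens, so compare a prefix
    assert rebuilt[: len(tokens)] == tokens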
@@ -235,7 +235,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
         super().__init__()
         self.wav_to_specs = []
         for i in range(5, 12):
-            s = 2 ** i
+            s = 2**i
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
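For context, this loop builds one `MelSpectrogram` per FFT size 2**i for i in 5..11 (windows of 32 up to 2048 samples), i.e. a multi-scale mel reconstruction loss. A rough standalone sketch with torchaudio; the hop lengths and mel-bin counts below are illustrative guesses, not necessarily the values used in this file:

# Multi-resolution mel spectrograms, analogous to the loop in the hunk above.
import torch
from torchaudio.transforms import MelSpectrogram

sampling_rate = 24000  # assumed sample rate for the sketch
wav_to_specs = []
for i in range(5, 12):
    s = 2**i  # FFT/window size: 32, 64, ..., 2048
    wav_to_specs.append(
        MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=s,
            hop_length=s // 4,
            n_mels=min(64, s // 4),
        )
    )

wav = torch.randn(1, sampling_rate)  # one second of dummy audio
specs = [m(wav) for m in wav_to_specs]
print([tuple(spec.shape) for spec in specs])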
@@ -161,7 +161,7 @@ class SEANetEncoder(nn.Module):
                     SEANetResnetBlock(
                         mult * n_filters,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base ** j, 1],
+                        dilations=[dilation_base**j, 1],
                         norm=norm,
                         norm_params=norm_params,
                         activation=activation,
@@ -311,7 +311,7 @@ class SEANetDecoder(nn.Module):
                     SEANetResnetBlock(
                         mult * n_filters // 2,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base ** j, 1],
+                        dilations=[dilation_base**j, 1],
                         activation=activation,
                         activation_params=activation_params,
                         norm=norm,
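In both the encoder and decoder hunks, the j-th residual block of each stage uses a dilation of dilation_base**j, so dilations grow geometrically with block depth and widen the receptive field without adding parameters. A tiny illustration of that schedule; the parameter values are made up, not the recipe's defaults:

# Dilation schedule used by the residual blocks above: dilation_base**j per block j.
dilation_base = 2        # illustrative
n_residual_layers = 3    # illustrative
kernel_size = 3          # stands in for residual_kernel_size

for j in range(n_residual_layers):
    dilation = dilation_base**j
    # receptive field of a single dilated conv with this kernel and dilation
    rf = (kernel_size - 1) * dilation + 1
    print(f"block {j}: dilations=[{dilation}, 1], conv receptive field={rf}")
# block 0: dilations=[1, 1], conv receptive field=3
# block 1: dilations=[2, 1], conv receptive field=5
# block 2: dilations=[4, 1], conv receptive field=9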
@@ -41,7 +41,7 @@ def build_stable_quantized_cdf(
     if roundoff:
         pdf = (pdf / roundoff).floor() * roundoff
     # interpolate with uniform distribution to achieve desired minimum probability.
-    total_range = 2 ** total_range_bits
+    total_range = 2**total_range_bits
     cardinality = len(pdf)
     alpha = min_range * cardinality / total_range
     assert alpha <= 1, "you must reduce min_range"
@@ -51,7 +51,7 @@ def build_stable_quantized_cdf(
     if min_range < 2:
         raise ValueError("min_range must be at least 2.")
     if check:
-        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
         if (
             (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
         ).any() or quantized_cdf[0] < min_range:
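These two hunks sit inside the routine that turns a pdf estimate into an integer CDF over a range of 2**total_range_bits while guaranteeing every symbol a slice of at least min_range. A rough sketch of that construction under the same constraints; it mirrors the idea, not necessarily every rounding detail of the real build_stable_quantized_cdf:

# Sketch: give every symbol min_range, then split the remaining (1 - alpha)
# share of the total range proportionally to the pdf and take the cumsum.
import torch

def sketch_quantized_cdf(pdf: torch.Tensor, total_range_bits: int = 24,
                         min_range: int = 2) -> torch.Tensor:
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long() + min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    # flooring keeps the total within 2**total_range_bits
    assert quantized_cdf[-1] <= total_range
    return quantized_cdf

print(sketch_quantized_cdf(torch.tensor([0.7, 0.2, 0.1])))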
@@ -142,7 +142,7 @@ class ArithmeticCoder:
             quantized_cdf (Tensor): use `build_stable_quantized_cdf`
                 to build this from your pdf estimate.
         """
-        while self.delta < 2 ** self.total_range_bits:
+        while self.delta < 2**self.total_range_bits:
             self.low *= 2
             self.high = self.high * 2 + 1
             self.max_bit += 1
@@ -150,10 +150,10 @@ class ArithmeticCoder:
         range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
         range_high = quantized_cdf[symbol].item() - 1
         effective_low = int(
-            math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
+            math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
         )
         effective_high = int(
-            math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
+            math.floor(range_high * (self.delta / (2**self.total_range_bits)))
         )
         assert self.low <= self.high
         self.high = self.low + effective_high
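The changed lines are the core interval update of the arithmetic coder: the symbol's slice [range_low, range_high] of the quantized CDF is rescaled into the current window of width delta, with ceil/floor keeping the sub-interval strictly inside it. A small numeric illustration with made-up values:

# Worked example of the effective_low / effective_high computation above.
import math

total_range_bits = 8              # CDF lives in [0, 256)
quantized_cdf = [100, 180, 256]   # cumulative ranges for 3 symbols
low, high = 0, 511                # current coder window
delta = high - low + 1            # 512

symbol = 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1]
range_high = quantized_cdf[symbol] - 1
effective_low = int(math.ceil(range_low * (delta / (2**total_range_bits))))
effective_high = int(math.floor(range_high * (delta / (2**total_range_bits))))

# The window shrinks to the slice of [low, high] owned by `symbol`.
print(low + effective_low, low + effective_high)  # 200 358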
@@ -238,7 +238,7 @@ class ArithmeticDecoder:
                 to build this from your pdf estimate. This must be **exatly**
                 the same cdf as the one used at encoding time.
         """
-        while self.delta < 2 ** self.total_range_bits:
+        while self.delta < 2**self.total_range_bits:
             bit = self.unpacker.pull()
             if bit is None:
                 return None
@@ -255,10 +255,10 @@ class ArithmeticDecoder:
             range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
             range_high = quantized_cdf[mid].item() - 1
             effective_low = int(
-                math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
+                math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
             )
             effective_high = int(
-                math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
+                math.floor(range_high * (self.delta / (2**self.total_range_bits)))
             )
             low = effective_low + self.low
             high = effective_high + self.low
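On the decoder side the same rescaling runs inside a binary search: `mid` is a candidate symbol, and its rescaled CDF slice is compared against the value read from the bit stream to decide which half of the alphabet to keep. A simplified, stateless sketch of that search (the real ArithmeticDecoder also maintains low/high/current registers and handles boundaries between slices):

# Sketch of decoding one symbol by binary search over a quantized CDF.
import math
from typing import List

def find_symbol(quantized_cdf: List[int], value: int, delta: int,
                total_range_bits: int) -> int:
    low_idx, up_idx = 0, len(quantized_cdf) - 1
    while low_idx <= up_idx:
        mid = (low_idx + up_idx) // 2
        range_low = quantized_cdf[mid - 1] if mid > 0 else 0
        range_high = quantized_cdf[mid] - 1
        lo = int(math.ceil(range_low * (delta / (2**total_range_bits))))
        hi = int(math.floor(range_high * (delta / (2**total_range_bits))))
        if value < lo:
            up_idx = mid - 1   # symbol lies below this candidate
        elif value > hi:
            low_idx = mid + 1  # symbol lies above this candidate
        else:
            return mid         # value falls inside this symbol's slice
    raise RuntimeError("binary search failed")

print(find_symbol([100, 180, 256], value=250, delta=512, total_range_bits=8))  # 1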
@@ -76,7 +76,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
 
     for _ in range(num_iters):
         diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-        dists = -(diffs ** 2).sum(dim=-1)
+        dists = -(diffs**2).sum(dim=-1)
 
         buckets = dists.max(dim=-1).indices
         bins = torch.bincount(buckets, minlength=num_clusters)
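This last hunk is the k-means assignment step: squared distances between every sample and every mean are computed by broadcasting, negated so that `max` picks the nearest cluster, and `bincount` counts the cluster sizes. A self-contained version of that single step (the random data and cluster count are illustrative):

# One k-means assignment step, as in the hunk above.
import torch
from einops import rearrange

num_clusters = 4
samples = torch.randn(100, 16)                        # (n, d)
means = samples[torch.randperm(100)[:num_clusters]]   # (c, d) initial means

# (n, 1, d) - (1, c, d) -> (n, c, d) pairwise differences
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)                        # negative squared distance, (n, c)

buckets = dists.max(dim=-1).indices                    # nearest cluster per sample, (n,)
bins = torch.bincount(buckets, minlength=num_clusters) # samples assigned to each cluster
print(bins)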