added licensing info

This commit is contained in:
zr_jin 2024-10-21 10:55:27 +08:00
parent 283157268a
commit d752230287
18 changed files with 88 additions and 29 deletions

View File

@ -1,3 +1,21 @@
#!/usr/bin/env python3
# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple from typing import List, Tuple
import torch import torch
@ -222,17 +240,12 @@ class DiscriminatorSTFT(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
fmap = [] fmap = []
# print('x ', x.shape)
z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
# print('z ', z.shape)
z = torch.cat([z.real, z.imag], dim=1) z = torch.cat([z.real, z.imag], dim=1)
# print('cat_z ', z.shape)
z = rearrange(z, "b c w t -> b c t w") z = rearrange(z, "b c w t -> b c t w")
for i, layer in enumerate(self.convs): for i, layer in enumerate(self.convs):
z = layer(z) z = layer(z)
z = self.activation(z) z = self.activation(z)
# print('z i', i, z.shape)
fmap.append(z) fmap.append(z)
z = self.conv_post(z) z = self.conv_post(z)
# print('logit ', z.shape)
return z, fmap return z, fmap

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" """Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
import io import io
@ -132,7 +132,7 @@ def test():
for rep in range(4): for rep in range(4):
length: int = torch.randint(10, 2_000, (1,)).item() length: int = torch.randint(10, 2_000, (1,)).item()
bits: int = torch.randint(1, 16, (1,)).item() bits: int = torch.randint(1, 16, (1,)).item()
tokens: List[int] = torch.randint(2**bits, (length,)).tolist() tokens: List[int] = torch.randint(2 ** bits, (length,)).tolist()
rebuilt: List[int] = [] rebuilt: List[int] = []
buf = io.BytesIO() buf = io.BytesIO()
packer = BitPacker(bits, buf) packer = BitPacker(bits, buf)

View File

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List from typing import List
import torch import torch

View File

@ -1,3 +1,20 @@
#!/usr/bin/env python3
# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import random import random
from typing import List, Optional from typing import List, Optional

View File

@ -235,7 +235,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
super().__init__() super().__init__()
self.wav_to_specs = [] self.wav_to_specs = []
for i in range(5, 12): for i in range(5, 12):
s = 2**i s = 2 ** i
self.wav_to_specs.append( self.wav_to_specs.append(
MelSpectrogram( MelSpectrogram(
sample_rate=sampling_rate, sample_rate=sampling_rate,

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Torch modules.""" """Torch modules."""
# flake8: noqa # flake8: noqa
from .conv import ( from .conv import (

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Convolutional layers wrappers and utilities.""" """Convolutional layers wrappers and utilities."""
import logging import logging
import math import math

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""LSTM layers module.""" """LSTM layers module."""
from torch import nn from torch import nn

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Normalization modules.""" """Normalization modules."""
from typing import List, Union from typing import List, Union

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Encodec SEANet-based encoder and decoder implementation.""" """Encodec SEANet-based encoder and decoder implementation."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -161,7 +161,7 @@ class SEANetEncoder(nn.Module):
SEANetResnetBlock( SEANetResnetBlock(
mult * n_filters, mult * n_filters,
kernel_sizes=[residual_kernel_size, 1], kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base**j, 1], dilations=[dilation_base ** j, 1],
norm=norm, norm=norm,
norm_params=norm_params, norm_params=norm_params,
activation=activation, activation=activation,
@ -311,7 +311,7 @@ class SEANetDecoder(nn.Module):
SEANetResnetBlock( SEANetResnetBlock(
mult * n_filters // 2, mult * n_filters // 2,
kernel_sizes=[residual_kernel_size, 1], kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base**j, 1], dilations=[dilation_base ** j, 1],
activation=activation, activation=activation,
activation_params=activation_params, activation_params=activation_params,
norm=norm, norm=norm,

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""A streamable transformer.""" """A streamable transformer."""
import typing as tp import typing as tp
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union

View File

@ -2,6 +2,6 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
# flake8: noqa # flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer from .vq import QuantizedResult, ResidualVectorQuantizer

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Arithmetic coder.""" """Arithmetic coder."""
import io import io
import math import math
@ -41,7 +41,7 @@ def build_stable_quantized_cdf(
if roundoff: if roundoff:
pdf = (pdf / roundoff).floor() * roundoff pdf = (pdf / roundoff).floor() * roundoff
# interpolate with uniform distribution to achieve desired minimum probability. # interpolate with uniform distribution to achieve desired minimum probability.
total_range = 2**total_range_bits total_range = 2 ** total_range_bits
cardinality = len(pdf) cardinality = len(pdf)
alpha = min_range * cardinality / total_range alpha = min_range * cardinality / total_range
assert alpha <= 1, "you must reduce min_range" assert alpha <= 1, "you must reduce min_range"
@ -51,7 +51,7 @@ def build_stable_quantized_cdf(
if min_range < 2: if min_range < 2:
raise ValueError("min_range must be at least 2.") raise ValueError("min_range must be at least 2.")
if check: if check:
assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
if ( if (
(quantized_cdf[1:] - quantized_cdf[:-1]) < min_range (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
).any() or quantized_cdf[0] < min_range: ).any() or quantized_cdf[0] < min_range:
@ -142,7 +142,7 @@ class ArithmeticCoder:
quantized_cdf (Tensor): use `build_stable_quantized_cdf` quantized_cdf (Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate. to build this from your pdf estimate.
""" """
while self.delta < 2**self.total_range_bits: while self.delta < 2 ** self.total_range_bits:
self.low *= 2 self.low *= 2
self.high = self.high * 2 + 1 self.high = self.high * 2 + 1
self.max_bit += 1 self.max_bit += 1
@ -150,10 +150,10 @@ class ArithmeticCoder:
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1 range_high = quantized_cdf[symbol].item() - 1
effective_low = int( effective_low = int(
math.ceil(range_low * (self.delta / (2**self.total_range_bits))) math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
) )
effective_high = int( effective_high = int(
math.floor(range_high * (self.delta / (2**self.total_range_bits))) math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
) )
assert self.low <= self.high assert self.low <= self.high
self.high = self.low + effective_high self.high = self.low + effective_high
@ -238,7 +238,7 @@ class ArithmeticDecoder:
to build this from your pdf estimate. This must be **exactly** to build this from your pdf estimate. This must be **exactly**
the same cdf as the one used at encoding time. the same cdf as the one used at encoding time.
""" """
while self.delta < 2**self.total_range_bits: while self.delta < 2 ** self.total_range_bits:
bit = self.unpacker.pull() bit = self.unpacker.pull()
if bit is None: if bit is None:
return None return None
@ -255,10 +255,10 @@ class ArithmeticDecoder:
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1 range_high = quantized_cdf[mid].item() - 1
effective_low = int( effective_low = int(
math.ceil(range_low * (self.delta / (2**self.total_range_bits))) math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
) )
effective_high = int( effective_high = int(
math.floor(range_high * (self.delta / (2**self.total_range_bits))) math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
) )
low = effective_low + self.low low = effective_low + self.low
high = effective_high + self.low high = effective_high + self.low

View File

@ -76,7 +76,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
for _ in range(num_iters): for _ in range(num_iters):
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1) dists = -(diffs ** 2).sum(dim=-1)
buckets = dists.max(dim=-1).indices buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters) bins = torch.bincount(buckets, minlength=num_clusters)

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Torch distributed utilities.""" """Torch distributed utilities."""
from typing import Dict, Iterable, List from typing import Dict, Iterable, List

View File

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Residual vector quantizer implementation.""" """Residual vector quantizer implementation."""
import math import math
from dataclasses import dataclass, field from dataclasses import dataclass, field

View File

@ -1,3 +1,8 @@
# original implementation is from https://github.com/ZhikangNiu/encodec-pytorch/blob/main/scheduler.py
# Copyright 2024 Zhi-Kang Niu
# MIT License
import math import math
from bisect import bisect_right from bisect import bisect_right

View File

@ -1,3 +1,21 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Zengwei Yao)
# 2024 The Chinese University of HK (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import itertools import itertools
import logging import logging