refacotring

This commit is contained in:
Fangjun Kuang 2024-10-28 19:59:38 +08:00
parent 10c099ac90
commit 8cb1cda040
4 changed files with 19 additions and 6 deletions

View File

@ -38,7 +38,7 @@ from lhotse.audio import RecordingSet
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.supervision import SupervisionSet
from lhotse.utils import Seconds, compute_num_frames
from matcha.utils.audio import mel_spectrogram
from matcha.audio import mel_spectrogram
from icefall.utils import get_executor

View File

@ -73,8 +73,6 @@ class ModelWrapper(torch.nn.Module):
)["mel"]
# mel: (batch_size, feat_dim, num_frames)
# audio = self.vocoder(mel).clamp(-1, 1).squeeze(1)
return mel

View File

@ -28,7 +28,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=2810,
default=4000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
@ -44,6 +44,13 @@ def get_parser():
""",
)
parser.add_argument(
"--vocoder",
type=Path,
default="./generator_v1",
help="Path to the vocoder",
)
parser.add_argument(
"--tokens",
type=Path,
@ -61,6 +68,7 @@ def get_parser():
def load_vocoder(checkpoint_path):
checkpoint_path = str(checkpoint_path)
if checkpoint_path.endswith("v1"):
h = AttributeDict(v1)
elif checkpoint_path.endswith("v2"):
@ -142,10 +150,17 @@ def main():
logging.info("About to create model")
model = get_model(params)
if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.eval()
vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
if not Path(params.vocoder).is_file():
raise ValueError(f"{params.vocoder} does not exist")
vocoder = load_vocoder(params.vocoder)
denoiser = Denoiser(vocoder, mode="zeros")
texts = [

View File

@ -90,7 +90,7 @@ def save_checkpoint(
if params:
for k, v in params.items():
assert k not in checkpoint
assert k not in checkpoint, k
checkpoint[k] = v
torch.save(checkpoint, filename)