mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
refacotring
This commit is contained in:
parent
10c099ac90
commit
8cb1cda040
@ -38,7 +38,7 @@ from lhotse.audio import RecordingSet
|
||||
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||
from lhotse.supervision import SupervisionSet
|
||||
from lhotse.utils import Seconds, compute_num_frames
|
||||
from matcha.utils.audio import mel_spectrogram
|
||||
from matcha.audio import mel_spectrogram
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
|
@ -73,8 +73,6 @@ class ModelWrapper(torch.nn.Module):
|
||||
)["mel"]
|
||||
# mel: (batch_size, feat_dim, num_frames)
|
||||
|
||||
# audio = self.vocoder(mel).clamp(-1, 1).squeeze(1)
|
||||
|
||||
return mel
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=2810,
|
||||
default=4000,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
@ -44,6 +44,13 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocoder",
|
||||
type=Path,
|
||||
default="./generator_v1",
|
||||
help="Path to the vocoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=Path,
|
||||
@ -61,6 +68,7 @@ def get_parser():
|
||||
|
||||
|
||||
def load_vocoder(checkpoint_path):
|
||||
checkpoint_path = str(checkpoint_path)
|
||||
if checkpoint_path.endswith("v1"):
|
||||
h = AttributeDict(v1)
|
||||
elif checkpoint_path.endswith("v2"):
|
||||
@ -142,10 +150,17 @@ def main():
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
|
||||
raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
|
||||
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.eval()
|
||||
|
||||
vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
|
||||
if not Path(params.vocoder).is_file():
|
||||
raise ValueError(f"{params.vocoder} does not exist")
|
||||
|
||||
vocoder = load_vocoder(params.vocoder)
|
||||
denoiser = Denoiser(vocoder, mode="zeros")
|
||||
|
||||
texts = [
|
||||
|
@ -90,7 +90,7 @@ def save_checkpoint(
|
||||
|
||||
if params:
|
||||
for k, v in params.items():
|
||||
assert k not in checkpoint
|
||||
assert k not in checkpoint, k
|
||||
checkpoint[k] = v
|
||||
|
||||
torch.save(checkpoint, filename)
|
||||
|
Loading…
x
Reference in New Issue
Block a user