mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
refacotring
This commit is contained in:
parent
10c099ac90
commit
8cb1cda040
@ -38,7 +38,7 @@ from lhotse.audio import RecordingSet
|
|||||||
from lhotse.features.base import FeatureExtractor, register_extractor
|
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||||
from lhotse.supervision import SupervisionSet
|
from lhotse.supervision import SupervisionSet
|
||||||
from lhotse.utils import Seconds, compute_num_frames
|
from lhotse.utils import Seconds, compute_num_frames
|
||||||
from matcha.utils.audio import mel_spectrogram
|
from matcha.audio import mel_spectrogram
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
|
|
||||||
|
@ -73,8 +73,6 @@ class ModelWrapper(torch.nn.Module):
|
|||||||
)["mel"]
|
)["mel"]
|
||||||
# mel: (batch_size, feat_dim, num_frames)
|
# mel: (batch_size, feat_dim, num_frames)
|
||||||
|
|
||||||
# audio = self.vocoder(mel).clamp(-1, 1).squeeze(1)
|
|
||||||
|
|
||||||
return mel
|
return mel
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=2810,
|
default=4000,
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
Note: Epoch counts from 1.
|
Note: Epoch counts from 1.
|
||||||
""",
|
""",
|
||||||
@ -44,6 +44,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocoder",
|
||||||
|
type=Path,
|
||||||
|
default="./generator_v1",
|
||||||
|
help="Path to the vocoder",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokens",
|
"--tokens",
|
||||||
type=Path,
|
type=Path,
|
||||||
@ -61,6 +68,7 @@ def get_parser():
|
|||||||
|
|
||||||
|
|
||||||
def load_vocoder(checkpoint_path):
|
def load_vocoder(checkpoint_path):
|
||||||
|
checkpoint_path = str(checkpoint_path)
|
||||||
if checkpoint_path.endswith("v1"):
|
if checkpoint_path.endswith("v1"):
|
||||||
h = AttributeDict(v1)
|
h = AttributeDict(v1)
|
||||||
elif checkpoint_path.endswith("v2"):
|
elif checkpoint_path.endswith("v2"):
|
||||||
@ -142,10 +150,17 @@ def main():
|
|||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_model(params)
|
model = get_model(params)
|
||||||
|
|
||||||
|
if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
|
||||||
|
raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
|
||||||
|
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
|
if not Path(params.vocoder).is_file():
|
||||||
|
raise ValueError(f"{params.vocoder} does not exist")
|
||||||
|
|
||||||
|
vocoder = load_vocoder(params.vocoder)
|
||||||
denoiser = Denoiser(vocoder, mode="zeros")
|
denoiser = Denoiser(vocoder, mode="zeros")
|
||||||
|
|
||||||
texts = [
|
texts = [
|
||||||
|
@ -90,7 +90,7 @@ def save_checkpoint(
|
|||||||
|
|
||||||
if params:
|
if params:
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
assert k not in checkpoint
|
assert k not in checkpoint, k
|
||||||
checkpoint[k] = v
|
checkpoint[k] = v
|
||||||
|
|
||||||
torch.save(checkpoint, filename)
|
torch.save(checkpoint, filename)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user