Update train.py

This commit is contained in:
zr_jin 2024-11-04 14:17:04 +08:00
parent ce07ceeffb
commit 8eb160e287

View File

@ -15,9 +15,9 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from matcha.model import fix_len_compatibility from model import fix_len_compatibility
from matcha.models.matcha_tts import MatchaTTS from models.matcha_tts import MatchaTTS
from matcha.tokenizer import Tokenizer from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer