mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-16 04:32:19 +00:00
add tensorboard samples
This commit is contained in:
parent
d874bdedac
commit
5d79953f22
2
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
2
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
@ -56,7 +56,7 @@ function infer() {
|
|||||||
|
|
||||||
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
|
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
|
||||||
|
|
||||||
./matcha/inference.py \
|
./matcha/infer.py \
|
||||||
--epoch 1 \
|
--epoch 1 \
|
||||||
--exp-dir ./matcha/exp \
|
--exp-dir ./matcha/exp \
|
||||||
--tokens data/tokens.txt \
|
--tokens data/tokens.txt \
|
||||||
|
@ -7,9 +7,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -343,7 +344,7 @@ def compute_validation_loss(
|
|||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
) -> MetricsTracker:
|
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
|
||||||
"""Run the validation process."""
|
"""Run the validation process."""
|
||||||
model.eval()
|
model.eval()
|
||||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
@ -391,6 +392,23 @@ def compute_validation_loss(
|
|||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = tot_loss + loss_info
|
tot_loss = tot_loss + loss_info
|
||||||
|
|
||||||
|
# infer for first batch:
|
||||||
|
if batch_idx == 0 and rank == 0:
|
||||||
|
inner_model = model.module if isinstance(model, DDP) else model
|
||||||
|
audio_pred, _, duration = inner_model.inference(
|
||||||
|
text=tokens[0, : tokens_lens[0].item()]
|
||||||
|
)
|
||||||
|
audio_pred = audio_pred.data.cpu().numpy()
|
||||||
|
audio_len_pred = (
|
||||||
|
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
|
||||||
|
)
|
||||||
|
assert audio_len_pred == len(audio_pred), (
|
||||||
|
audio_len_pred,
|
||||||
|
len(audio_pred),
|
||||||
|
)
|
||||||
|
audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
|
||||||
|
returned_sample = (audio_pred, audio_gt)
|
||||||
|
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
tot_loss.reduce(device)
|
tot_loss.reduce(device)
|
||||||
|
|
||||||
@ -399,7 +417,7 @@ def compute_validation_loss(
|
|||||||
params.best_valid_epoch = params.cur_epoch
|
params.best_valid_epoch = params.cur_epoch
|
||||||
params.best_valid_loss = loss_value
|
params.best_valid_loss = loss_value
|
||||||
|
|
||||||
return tot_loss
|
return tot_loss, returned_sample
|
||||||
|
|
||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
@ -559,7 +577,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
if params.batch_idx_train % params.valid_interval == 1:
|
if params.batch_idx_train % params.valid_interval == 1:
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
valid_info = compute_validation_loss(
|
valid_info, (speech_hat, speech) = compute_validation_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -577,6 +595,18 @@ def train_one_epoch(
|
|||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
tb_writer.add_audio(
|
||||||
|
"train/valid_speech_hat",
|
||||||
|
speech_hat,
|
||||||
|
params.batch_idx_train,
|
||||||
|
params.sampling_rate,
|
||||||
|
)
|
||||||
|
tb_writer.add_audio(
|
||||||
|
"train/valid_speech",
|
||||||
|
speech,
|
||||||
|
params.batch_idx_train,
|
||||||
|
params.sampling_rate,
|
||||||
|
)
|
||||||
|
|
||||||
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
Loading…
x
Reference in New Issue
Block a user