Update train.py

This commit is contained in:
zr_jin 2024-11-04 14:26:51 +08:00
parent 8eb160e287
commit 729e86edb9

View File

@ -7,7 +7,7 @@ import json
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import k2
import numpy as np
@ -344,7 +344,7 @@ def compute_validation_loss(
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@ -392,23 +392,6 @@ def compute_validation_loss(
# summary stats
tot_loss = tot_loss + loss_info
# infer for first batch:
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
text=tokens[0, : tokens_lens[0].item()]
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
)
assert audio_len_pred == len(audio_pred), (
audio_len_pred,
len(audio_pred),
)
audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
returned_sample = (audio_pred, audio_gt)
if world_size > 1:
tot_loss.reduce(device)
@ -417,7 +400,7 @@ def compute_validation_loss(
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss, returned_sample
return tot_loss
def train_one_epoch(
@ -577,7 +560,7 @@ def train_one_epoch(
if params.batch_idx_train % params.valid_interval == 1:
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
@ -595,18 +578,6 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
"train/valid_speech_hat",
speech_hat,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/valid_speech",
speech,
params.batch_idx_train,
params.sampling_rate,
)
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
params.train_loss = loss_value