mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00
Update train.py
This commit is contained in:
parent
8eb160e287
commit
729e86edb9
@ -7,7 +7,7 @@ import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
@ -344,7 +344,7 @@ def compute_validation_loss(
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
@ -392,23 +392,6 @@ def compute_validation_loss(
|
||||
# summary stats
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
# infer for first batch:
|
||||
if batch_idx == 0 and rank == 0:
|
||||
inner_model = model.module if isinstance(model, DDP) else model
|
||||
audio_pred, _, duration = inner_model.inference(
|
||||
text=tokens[0, : tokens_lens[0].item()]
|
||||
)
|
||||
audio_pred = audio_pred.data.cpu().numpy()
|
||||
audio_len_pred = (
|
||||
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
|
||||
)
|
||||
assert audio_len_pred == len(audio_pred), (
|
||||
audio_len_pred,
|
||||
len(audio_pred),
|
||||
)
|
||||
audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
|
||||
returned_sample = (audio_pred, audio_gt)
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
|
||||
@ -417,7 +400,7 @@ def compute_validation_loss(
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss, returned_sample
|
||||
return tot_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
@ -577,7 +560,7 @@ def train_one_epoch(
|
||||
|
||||
if params.batch_idx_train % params.valid_interval == 1:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info, (speech_hat, speech) = compute_validation_loss(
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@ -595,18 +578,6 @@ def train_one_epoch(
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
"train/valid_speech_hat",
|
||||
speech_hat,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
"train/valid_speech",
|
||||
speech,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
|
||||
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||
params.train_loss = loss_value
|
||||
|
Loading…
x
Reference in New Issue
Block a user