Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 12:02:21 +00:00
Update train.py
parent 8eb160e287
commit 729e86edb9
@@ -7,7 +7,7 @@ import json
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 
 import k2
 import numpy as np
@@ -344,7 +344,7 @@ def compute_validation_loss(
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
-) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+) -> MetricsTracker:
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@@ -392,23 +392,6 @@ def compute_validation_loss(
             # summary stats
             tot_loss = tot_loss + loss_info
 
-            # infer for first batch:
-            if batch_idx == 0 and rank == 0:
-                inner_model = model.module if isinstance(model, DDP) else model
-                audio_pred, _, duration = inner_model.inference(
-                    text=tokens[0, : tokens_lens[0].item()]
-                )
-                audio_pred = audio_pred.data.cpu().numpy()
-                audio_len_pred = (
-                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
-                )
-                assert audio_len_pred == len(audio_pred), (
-                    audio_len_pred,
-                    len(audio_pred),
-                )
-                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
-                returned_sample = (audio_pred, audio_gt)
-
         if world_size > 1:
             tot_loss.reduce(device)
 
@@ -417,7 +400,7 @@ def compute_validation_loss(
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
 
-    return tot_loss, returned_sample
+    return tot_loss
 
 
 def train_one_epoch(
@@ -577,7 +560,7 @@ def train_one_epoch(
 
         if params.batch_idx_train % params.valid_interval == 1:
             logging.info("Computing validation loss")
-            valid_info, (speech_hat, speech) = compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 tokenizer=tokenizer,
@@ -595,18 +578,6 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-                tb_writer.add_audio(
-                    "train/valid_speech_hat",
-                    speech_hat,
-                    params.batch_idx_train,
-                    params.sampling_rate,
-                )
-                tb_writer.add_audio(
-                    "train/valid_speech",
-                    speech,
-                    params.batch_idx_train,
-                    params.sampling_rate,
-                )
 
         loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
         params.train_loss = loss_value
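
The removed lines in compute_validation_loss synthesized audio for the first validation batch and handed an (audio_pred, audio_gt) pair back to train_one_epoch, which logged the waveforms with tb_writer.add_audio; after this commit the function returns only a MetricsTracker. For context, below is a minimal, self-contained sketch of the torch.utils.tensorboard add_audio call that the removed logging relied on; the log directory and the synthetic sine wave are illustrative assumptions, not taken from the recipe.

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

# Hypothetical log directory; any writable path works.
tb_writer = SummaryWriter(log_dir="tb_logs_demo")

# One second of a 440 Hz sine wave stands in for the speech_hat / speech
# waveforms; add_audio expects sample values in [-1, 1].
sampling_rate = 22050
t = np.arange(sampling_rate) / sampling_rate
audio = torch.from_numpy(0.5 * np.sin(2.0 * np.pi * 440.0 * t)).float()

# Same call pattern as the removed train.py lines:
# tag, waveform, global step, sample rate.
tb_writer.add_audio("train/valid_speech_hat", audio, 0, sampling_rate)
tb_writer.close()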