Update train.py

This commit is contained in:
Mingshuang Luo 2021-09-29 16:54:21 +08:00 committed by GitHub
parent a0994ee58a
commit de9b2a9cd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,11 +4,10 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple, List from typing import Optional, Tuple
import k2 import k2
import torch import torch
import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
@ -414,8 +413,8 @@ def train_one_epoch(
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]" f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"tot_loss[{tot_loss}], batch size: {batch_size}" f"tot_loss[{tot_loss}], batch size: {batch_size}")
)
if batch_idx % 10 == 0: if batch_idx % 10 == 0:
if tb_writer is not None: if tb_writer is not None: