Update train.py

This commit is contained in:
Mingshuang Luo 2021-09-29 16:54:21 +08:00 committed by GitHub
parent a0994ee58a
commit de9b2a9cd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,11 +4,10 @@ import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple, List
from typing import Optional, Tuple
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
@ -414,8 +413,8 @@ def train_one_epoch(
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}]"
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
f"tot_loss[{tot_loss}], batch size: {batch_size}")
if batch_idx % 10 == 0:
if tb_writer is not None: