style check

This commit is contained in:
luomingshuang 2022-06-26 20:37:11 +08:00
parent 065cbefdc1
commit 77b4224686

View File

@ -830,7 +830,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, sp=sp)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@ -1115,16 +1115,16 @@ def display_and_save_batch(
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = graph_compiler.texts_to_ids(texts)
if type(y) == list:
y = k2.RaggedTensor(y).to(device)
y = k2.RaggedTensor(y)
else:
y = y.to(device)
y = y
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")