mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00
style check
This commit is contained in:
parent
065cbefdc1
commit
77b4224686
@ -830,7 +830,7 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
@ -1115,16 +1115,16 @@ def display_and_save_batch(
|
||||
logging.info(f"Saving batch to {filename}")
|
||||
torch.save(batch, filename)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
texts = batch["supervisions"]["text"]
|
||||
features = batch["inputs"]
|
||||
|
||||
logging.info(f"features shape: {features.shape}")
|
||||
|
||||
y = graph_compiler.texts_to_ids(texts)
|
||||
if type(y) == list:
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
y = k2.RaggedTensor(y)
|
||||
else:
|
||||
y = y.to(device)
|
||||
y = y
|
||||
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user