mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
style check
This commit is contained in:
parent
065cbefdc1
commit
77b4224686
@ -830,7 +830,7 @@ def train_one_epoch(
|
|||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
except: # noqa
|
except: # noqa
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
@ -1115,16 +1115,16 @@ def display_and_save_batch(
|
|||||||
logging.info(f"Saving batch to {filename}")
|
logging.info(f"Saving batch to {filename}")
|
||||||
torch.save(batch, filename)
|
torch.save(batch, filename)
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
texts = batch["supervisions"]["text"]
|
||||||
features = batch["inputs"]
|
features = batch["inputs"]
|
||||||
|
|
||||||
logging.info(f"features shape: {features.shape}")
|
logging.info(f"features shape: {features.shape}")
|
||||||
|
|
||||||
y = graph_compiler.texts_to_ids(texts)
|
y = graph_compiler.texts_to_ids(texts)
|
||||||
if type(y) == list:
|
if type(y) == list:
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y)
|
||||||
else:
|
else:
|
||||||
y = y.to(device)
|
y = y
|
||||||
|
|
||||||
num_tokens = sum(len(i) for i in y)
|
num_tokens = sum(len(i) for i in y)
|
||||||
logging.info(f"num tokens: {num_tokens}")
|
logging.info(f"num tokens: {num_tokens}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user