diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 03cd98e41..0dade163b 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -830,7 +830,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, sp=sp) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -1115,16 +1115,16 @@ def display_and_save_batch( logging.info(f"Saving batch to {filename}") torch.save(batch, filename) - supervisions = batch["supervisions"] + texts = batch["supervisions"]["text"] features = batch["inputs"] logging.info(f"features shape: {features.shape}") y = graph_compiler.texts_to_ids(texts) if type(y) == list: - y = k2.RaggedTensor(y).to(device) + y = k2.RaggedTensor(y) else: - y = y num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}")