From 322597879228911b5ff00b3f93844a78b8527309 Mon Sep 17 00:00:00 2001 From: zzasdf <15218404468@163.com> Date: Wed, 17 Jul 2024 17:07:30 +0800 Subject: [PATCH] fix error in accum_grad --- egs/librispeech/SSL/hubert/finetune.py | 2 +- egs/librispeech/SSL/hubert/finetune_ce.py | 2 +- egs/librispeech/SSL/hubert/pretrain.py | 2 +- egs/librispeech/SSL/hubert/pretrain_ce.py | 2 +- egs/librispeech/SSL/zipformer/finetune.py | 2 +- egs/librispeech/SSL/zipformer/pretrain.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 201847aed..17daa3c9d 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -948,7 +948,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index e69a5a8cd..2723cc770 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -948,7 +948,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index d9bda8857..f183d90fd 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -774,7 +774,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 24c0d4d3a..94948695d 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -774,7 +774,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index bbb445320..c907b41c5 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -1245,7 +1245,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 5f547e0b8..937fb382e 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -1072,7 +1072,7 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - if batch_idx % params.accum_grad != params.accum_grad - 1: + if sub_batch_idx % params.accum_grad != params.accum_grad - 1: optimizer.zero_grad() loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value