From 1b89c6dac4ed4ad472917e0cd873733ba1ed4fdd Mon Sep 17 00:00:00 2001 From: Your Name <> Date: Mon, 28 Oct 2024 22:58:40 -0700 Subject: [PATCH] skipping batch counts hurts performance --- egs/librispeech/SSL/hubert/finetune.py | 2 +- egs/librispeech/SSL/hubert/finetune_ce.py | 2 +- egs/librispeech/SSL/zipformer/finetune.py | 2 +- egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 582771dee..05b942f63 100755 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -99,7 +99,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index cec42ea12..1081313f1 100755 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -99,7 +99,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index 336c35813..2e521f177 100755 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -99,7 +99,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: diff --git a/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py b/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py index 572e040bd..d5dd8d71f 100755 --- a/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py +++ b/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py @@ -93,7 +93,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: