From a1277c9ae97d4358ef0e35afb5d1b497f946674d Mon Sep 17 00:00:00 2001 From: k2-fsa Date: Tue, 1 Jul 2025 00:05:08 +0800 Subject: [PATCH] fix grad scaler --- .github/workflows/librispeech.yml | 4 ++-- egs/librispeech/ASR/conformer_ctc2/train.py | 10 +++++----- egs/librispeech/ASR/conformer_ctc3/train.py | 8 ++++---- .../train.py | 8 ++++---- .../train.py | 8 ++++---- .../ASR/lstm_transducer_stateless/train.py | 8 ++++---- .../ASR/lstm_transducer_stateless2/train.py | 10 +++++----- .../ASR/lstm_transducer_stateless3/train.py | 10 +++++----- .../ASR/pruned2_knowledge/sampling.py | 6 +++--- .../ASR/pruned2_knowledge/train.py | 8 ++++---- .../pruned_stateless_emformer_rnnt2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 10 +++++----- .../ASR/pruned_transducer_stateless3/train.py | 10 +++++----- .../ASR/pruned_transducer_stateless4/train.py | 10 +++++----- .../ASR/pruned_transducer_stateless5/train.py | 10 +++++----- .../ASR/pruned_transducer_stateless6/train.py | 10 +++++----- .../pruned_transducer_stateless7/finetune.py | 7 +++---- .../ASR/pruned_transducer_stateless7/train.py | 10 +++++----- .../pruned_transducer_stateless7_ctc/train.py | 8 ++++---- .../train.py | 7 +++---- .../train.py | 8 ++++---- .../train.py | 8 ++++---- .../ASR/pruned_transducer_stateless8/train.py | 8 ++++---- .../ASR/tiny_transducer_ctc/train.py | 10 +++++----- egs/librispeech/ASR/zipformer/finetune.py | 8 ++++---- egs/librispeech/ASR/zipformer/train.py | 10 +++++----- .../ASR/zipformer_adapter/train.py | 8 ++++---- egs/librispeech/ASR/zipformer_ctc/train.py | 8 ++++---- .../ASR/zipformer_lora/finetune.py | 8 ++++---- egs/librispeech/ASR/zipformer_lora/train.py | 8 ++++---- egs/librispeech/ASR/zipformer_mmi/train.py | 8 ++++---- icefall/checkpoint.py | 7 +++---- icefall/utils.py | 19 +++++++++++++++++++ 33 files changed, 152 insertions(+), 136 deletions(-) diff --git a/.github/workflows/librispeech.yml b/.github/workflows/librispeech.yml index 4b8021254..19037c11b 100644 --- a/.github/workflows/librispeech.yml +++ b/.github/workflows/librispeech.yml @@ -30,8 +30,8 @@ jobs: run: | # outputting for debugging purposes python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" - MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10") - # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0") + # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10") + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0") echo "::set-output name=matrix::${MATRIX}" librispeech: needs: generate_build_matrix diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 112147249..14c132ada 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,12 +81,13 @@ from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, 
encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -421,7 +421,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -630,7 +630,7 @@ def train_one_epoch( scheduler: LRSchedulerType, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -966,7 +966,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index baf75b1de..7ea639e48 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -93,6 +92,7 @@ from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( + create_grad_scaler, torch_autocast, AttributeDict, MetricsTracker, @@ -494,7 +494,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -695,7 +695,7 @@ def train_one_epoch( graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1005,7 +1005,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index be1407302..fc33f9512 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,6 +95,7 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, @@ -566,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -733,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1008,7 +1008,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index f4a980637..b00cc6cc6 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,6 +95,7 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, @@ -566,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -733,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1007,7 +1007,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index 9087adfb9..e23da3b56 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,6 +81,7 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, @@ -522,7 +522,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -718,7 +718,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1024,7 +1024,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 934c0cfb0..1b31b5485 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -74,7 +74,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -88,12 +87,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -561,7 +561,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -773,7 +773,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1177,7 +1177,7 @@ def run(rank, world_size, args): else: logging.info("Skip scan_pessimistic_batches_for_oom") - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index ea4a682f1..e169b499f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -80,12 +79,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -552,7 +552,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -748,7 +748,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1068,7 +1068,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7e86ceda3..5850555cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -10,10 +10,10 @@ from typing import Optional, Tuple import torch from scaling import ScaledLinear from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +from torch.cuda.amp import custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined -from icefall.utils import torch_autocast +from icefall.utils import create_grad_scaler, torch_autocast # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. 
@@ -332,7 +332,7 @@ def _test_knowledge_base_lookup_autocast(): optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) - scaler = GradScaler(enabled=True) + scaler = create_grad_scaler(enabled=True) start = timeit.default_timer() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 83d5446e4..0611fd8cb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -77,6 +76,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( + create_grad_scaler, AttributeDict, MetricsTracker, setup_logger, @@ -459,7 +459,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -614,7 +614,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -874,7 +874,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 192e104f7..2af8f3f4c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from noam import Noam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -71,6 +70,7 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, @@ -502,7 +502,7 @@ def save_checkpoint( model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, and training stats to file. 
@@ -656,7 +656,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -945,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a33b1f4c1..ce6c89614 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -89,12 +88,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -524,7 +524,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -717,7 +717,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -1001,7 +1001,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 3dab2eb8b..50670d1b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -74,7 +74,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -85,12 +84,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -547,7 +547,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -756,7 +756,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -1127,7 +1127,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index bfc9fc144..c35f52309 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -94,12 +93,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -549,7 +549,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -745,7 +745,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1048,7 +1048,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 748f4826d..6f9f92623 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,12 +81,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -572,7 +572,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -769,7 +769,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1079,7 +1079,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index d3e373321..35ee74f15 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,10 +95,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, - torch_autocast, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -520,7 +520,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -737,7 +737,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1040,7 +1040,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 072aa274c..832dde49d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -679,7 +678,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -858,7 +857,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1220,7 +1219,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 0c3852455..f94da9788 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -84,13 +83,14 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, filter_uneven_sized_batch, setup_logger, str2bool, symlink_or_copy, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -582,7 +582,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -764,7 +764,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1107,7 +1107,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 88fd43ee4..a26f11c82 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,6 +85,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, @@ -589,7 +589,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -788,7 +788,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1129,7 +1129,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 6112c4a0b..85f50c1e0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -63,7 +63,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -582,7 +581,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -779,7 +778,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1119,7 +1118,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index c6dd8dd4c..4d8a2644d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -85,6 +84,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, @@ -603,7 +603,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -770,7 +770,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1130,7 +1130,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index 3b5178851..4b97575e6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -70,7 +70,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -89,6 +88,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, @@ -621,7 +621,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -801,7 +801,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1224,7 +1224,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 32b899f52..ad14ec9dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -75,7 +75,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -94,6 +93,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, @@ -614,7 +614,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -796,7 +796,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1225,7 +1225,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index dea2b096a..368bd20fa 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -51,7 +51,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW from torch.optim.lr_scheduler import StepLR @@ -70,12 +69,13 @@ from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import UniqLexicon from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -551,7 +551,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -758,7 +758,7 @@ def train_one_epoch( phone_lexicon: UniqLexicon, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1093,7 +1093,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index c143fbd76..94e8b273a 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -95,6 +94,7 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + create_grad_scaler, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -809,7 +809,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -986,7 +986,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1374,7 +1374,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9e05ca638..42ae9b9f2 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -79,7 +79,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -96,12 +95,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( - torch_autocast, AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -830,7 +830,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -1035,7 +1035,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, @@ -1448,7 +1448,7 @@ def run(rank, world_size, args): spec_augment=spec_augment, ) - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 99c852844..fcd7272e9 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -67,7 +67,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -86,6 +85,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, @@ -806,7 +806,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -983,7 +983,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1398,7 +1398,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index 7c772bb3b..bd3bfa332 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -46,7 +46,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, LRScheduler, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -68,6 +67,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, setup_logger, str2bool, torch_autocast, @@ -539,7 +539,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -693,7 +693,7 @@ def train_one_epoch( graph_compiler: BpeCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -993,7 +993,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index ea6e2877b..c26a2f5cc 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -96,6 +95,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, @@ -819,7 +819,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -996,7 +996,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1407,7 +1407,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index add50fa25..2b83d58ef 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -76,7 +76,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -94,6 +93,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, @@ -708,7 +708,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -884,7 +884,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1253,7 +1253,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index 0b99ab64e..c33263be8 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -91,6 +90,7 @@ from icefall.utils import ( setup_logger, str2bool, torch_autocast, + create_grad_scaler, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -515,7 +515,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -697,7 +697,7 @@ def train_one_epoch( mmi_graph_compiler: MmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1038,7 +1038,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index f98045d29..4ab685684 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -27,7 +27,6 @@ import torch import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -43,7 +42,7 @@ def save_checkpoint( params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -102,7 +101,7 @@ def load_checkpoint( model_avg: Optional[nn.Module] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, strict: bool = False, ) -> Dict[str, Any]: @@ -201,7 +200,7 @@ def save_checkpoint_with_global_batch_idx( params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ): diff --git a/icefall/utils.py b/icefall/utils.py index 405b6b327..4017d9e9e 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -57,6 +57,25 @@ Pathlike = Union[str, Path] TORCH_VERSION = version.parse(torch.__version__) +def create_grad_scaler(device="cuda", **kwargs): + """ + Creates a GradScaler compatible with both torch < 2.0 and >= 2.0. + Accepts all kwargs like: enabled, init_scale, growth_factor, etc. + + /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning: + `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use + `torch.amp.GradScaler('cuda', args...)` instead. + """ + if TORCH_VERSION >= version.parse("2.0.0"): + from torch.amp import GradScaler + + return GradScaler(device=device, **kwargs) + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + return torch.cuda.amp.GradScaler(**kwargs) + + @contextmanager def torch_autocast(device_type="cuda", **kwargs): """
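
Usage sketch: a minimal example of how the new `create_grad_scaler` helper and the existing `torch_autocast` wrapper fit together in the training loops touched above. This is an illustrative sketch, not code taken from any recipe; it assumes a CUDA device and a torch build that provides `torch.amp.GradScaler`, and it uses placeholder names (`model`, `optimizer`, `features`, `targets`, `use_fp16`), with `checkpoints` standing in for the dict returned by `icefall.checkpoint.load_checkpoint`.

    import logging

    import torch

    from icefall.utils import create_grad_scaler, torch_autocast

    # Placeholder setup; in the real recipes these come from the training script.
    device = torch.device("cuda")
    model = torch.nn.Linear(80, 500).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    use_fp16 = True
    checkpoints = None  # or the dict returned by icefall.checkpoint.load_checkpoint

    # Version-agnostic replacement for torch.cuda.amp.GradScaler(...).
    scaler = create_grad_scaler(enabled=use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    features = torch.randn(8, 100, 80, device=device)
    targets = torch.randn(8, 100, 500, device=device)

    optimizer.zero_grad()
    # Like create_grad_scaler, torch_autocast selects the autocast API that
    # matches the installed torch version.
    with torch_autocast(enabled=use_fp16):
        loss = torch.nn.functional.mse_loss(model(features), targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # scaler.state_dict() is what the checkpoint code stores under "grad_scaler".

The helper keeps every recipe on a single code path: on newer torch it constructs `torch.amp.GradScaler(device="cuda", ...)`, while on older torch it falls back to `torch.cuda.amp.GradScaler(...)` and suppresses the FutureWarning quoted in its docstring.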