fix grad scaler

k2-fsa 2025-07-01 00:05:08 +08:00
parent f186e1d427
commit a1277c9ae9
33 changed files with 152 additions and 136 deletions
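
The change is mechanical across the training scripts: drop the module-level `from torch.cuda.amp import GradScaler` import, quote the `GradScaler` type annotations so the class is only referenced as a string, and construct the scaler through a new `create_grad_scaler` helper in `icefall.utils` (shown in the last file below). A minimal before/after sketch of the pattern, not taken verbatim from any single file:

    # Before: emits a FutureWarning on recent torch releases
    from torch.cuda.amp import GradScaler
    scaler = GradScaler(enabled=params.use_fp16)

    # After: version-agnostic helper; GradScaler itself is no longer imported
    from icefall.utils import create_grad_scaler
    scaler = create_grad_scaler(enabled=params.use_fp16)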

View File

@ -30,8 +30,8 @@ jobs:
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
echo "::set-output name=matrix::${MATRIX}"
librispeech:
needs: generate_build_matrix

View File

@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -82,12 +81,13 @@ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -421,7 +421,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -630,7 +630,7 @@ def train_one_epoch(
scheduler: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -966,7 +966,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
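
Checkpoint resumption is unaffected by the new helper: whichever branch `create_grad_scaler` takes, it returns an ordinary `GradScaler` instance, so scaler state still round-trips through `state_dict()` / `load_state_dict()` exactly as above. A minimal sketch (the dict layout mirrors what `save_checkpoint` stores under "grad_scaler"):

    scaler = create_grad_scaler(enabled=True)
    saved = {"grad_scaler": scaler.state_dict()}

    resumed = create_grad_scaler(enabled=True)
    resumed.load_state_dict(saved["grad_scaler"])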

View File

@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed
from model import CTCModel
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -93,6 +92,7 @@ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
create_grad_scaler,
torch_autocast,
AttributeDict,
MetricsTracker,
@ -494,7 +494,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -695,7 +695,7 @@ def train_one_epoch(
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1005,7 +1005,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -96,6 +95,7 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
setup_logger,
str2bool,
torch_autocast,
@ -566,7 +566,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -733,7 +733,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1008,7 +1008,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -96,6 +95,7 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
setup_logger,
str2bool,
torch_autocast,
@ -566,7 +566,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -733,7 +733,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1007,7 +1007,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -66,7 +66,6 @@ from lstm import RNN
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -82,6 +81,7 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
@ -522,7 +522,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -718,7 +718,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1024,7 +1024,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -74,7 +74,6 @@ from lstm import RNN
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -88,12 +87,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -561,7 +561,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -773,7 +773,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1177,7 +1177,7 @@ def run(rank, world_size, args):
else:
logging.info("Skip scan_pessimistic_batches_for_oom")
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -66,7 +66,6 @@ from lstm import RNN
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -80,12 +79,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -552,7 +552,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -748,7 +748,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1068,7 +1068,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -10,10 +10,10 @@ from typing import Optional, Tuple
import torch
from scaling import ScaledLinear
from torch import Tensor, nn
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
from torch.cuda.amp import custom_bwd, custom_fwd
from torch_scheduled_sampling import sample_combined
from icefall.utils import torch_autocast
from icefall.utils import create_grad_scaler, torch_autocast
# The main exports of this file are the module KnowledgeBaseLookup and the
# function create_knowledge_base.
@ -332,7 +332,7 @@ def _test_knowledge_base_lookup_autocast():
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
m = m.to(device)
scaler = GradScaler(enabled=True)
scaler = create_grad_scaler(enabled=True)
start = timeit.default_timer()

View File

@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -77,6 +76,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
create_grad_scaler,
AttributeDict,
MetricsTracker,
setup_logger,
@ -459,7 +459,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -614,7 +614,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@ -874,7 +874,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from noam import Noam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -71,6 +70,7 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
setup_logger,
str2bool,
torch_autocast,
@ -502,7 +502,7 @@ def save_checkpoint(
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, and training stats to file.
@ -656,7 +656,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -945,7 +945,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -89,12 +88,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -524,7 +524,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -717,7 +717,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@ -1001,7 +1001,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -74,7 +74,6 @@ from librispeech import LibriSpeech
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -85,12 +84,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -547,7 +547,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -756,7 +756,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: GradScaler,
scaler: "GradScaler",
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@ -1127,7 +1127,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -94,12 +93,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -549,7 +549,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -745,7 +745,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1048,7 +1048,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -82,12 +81,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -572,7 +572,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -769,7 +769,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1079,7 +1079,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -96,10 +95,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
torch_autocast,
create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -520,7 +520,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -737,7 +737,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1040,7 +1040,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -679,7 +678,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -858,7 +857,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1220,7 +1219,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -84,13 +83,14 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
filter_uneven_sized_batch,
setup_logger,
str2bool,
symlink_or_copy,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -582,7 +582,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -764,7 +764,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1107,7 +1107,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -86,6 +85,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
@ -589,7 +589,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -788,7 +788,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1129,7 +1129,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -63,7 +63,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -582,7 +581,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -779,7 +778,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1119,7 +1118,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -85,6 +84,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
setup_logger,
str2bool,
torch_autocast,
@ -603,7 +603,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -770,7 +770,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1130,7 +1130,7 @@ def run(rank, world_size, args):
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -70,7 +70,6 @@ from librispeech import LibriSpeech
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -89,6 +88,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
setup_logger,
str2bool,
torch_autocast,
@ -621,7 +621,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -801,7 +801,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1224,7 +1224,7 @@ def run(rank, world_size, args):
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -75,7 +75,6 @@ from librispeech import LibriSpeech
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -94,6 +93,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
setup_logger,
str2bool,
torch_autocast,
@ -614,7 +614,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -796,7 +796,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1225,7 +1225,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -51,7 +51,6 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
@ -70,12 +69,13 @@ from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import UniqLexicon
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@ -551,7 +551,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -758,7 +758,7 @@ def train_one_epoch(
phone_lexicon: UniqLexicon,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1093,7 +1093,7 @@ def run(rank, world_size, args):
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@ -95,6 +94,7 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
create_grad_scaler,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
@ -809,7 +809,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -986,7 +986,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dls: torch.utils.data.DataLoader,
valid_sets: List[str],
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1374,7 +1374,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -79,7 +79,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@ -96,12 +95,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -830,7 +830,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -1035,7 +1035,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
spec_augment: Optional[SpecAugment] = None,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
@ -1448,7 +1448,7 @@ def run(rank, world_size, args):
spec_augment=spec_augment,
)
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -67,7 +67,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@ -86,6 +85,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
@ -806,7 +806,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -983,7 +983,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dls: torch.utils.data.DataLoader,
valid_sets: List[str],
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1398,7 +1398,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -46,7 +46,6 @@ from lhotse.utils import fix_random_seed
from model import CTCModel
from optim import Eden, LRScheduler, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
@ -68,6 +67,7 @@ from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
setup_logger,
str2bool,
torch_autocast,
@ -539,7 +539,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -693,7 +693,7 @@ def train_one_epoch(
graph_compiler: BpeCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -993,7 +993,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@ -96,6 +95,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
@ -819,7 +819,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -996,7 +996,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dls: torch.utils.data.DataLoader,
valid_sets: List[str],
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1407,7 +1407,7 @@ def run(rank, world_size, args):
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -76,7 +76,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@ -94,6 +93,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
@ -708,7 +708,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -884,7 +884,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1253,7 +1253,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed
from model import CTCModel
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@ -91,6 +90,7 @@ from icefall.utils import (
setup_logger,
str2bool,
torch_autocast,
create_grad_scaler,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -515,7 +515,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -697,7 +697,7 @@ def train_one_epoch(
mmi_graph_compiler: MmiTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -1038,7 +1038,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -27,7 +27,6 @@ import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@ -43,7 +42,7 @@ def save_checkpoint(
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
@ -102,7 +101,7 @@ def load_checkpoint(
model_avg: Optional[nn.Module] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
sampler: Optional[CutSampler] = None,
strict: bool = False,
) -> Dict[str, Any]:
@ -201,7 +200,7 @@ def save_checkpoint_with_global_batch_idx(
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
scaler: Optional["GradScaler"] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):

View File

@ -57,6 +57,25 @@ Pathlike = Union[str, Path]
TORCH_VERSION = version.parse(torch.__version__)
def create_grad_scaler(device="cuda", **kwargs):
"""
Creates a GradScaler compatible with both torch < 2.0 and >= 2.0.
Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
/icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
`torch.amp.GradScaler('cuda', args...)` instead.
"""
if TORCH_VERSION >= version.parse("2.0.0"):
from torch.amp import GradScaler
return GradScaler(device=device, **kwargs)
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
return torch.cuda.amp.GradScaler(**kwargs)
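
# Example usage (a sketch; `compute_loss`, `batch`, and `optimizer` are
# placeholders, and `torch_autocast` is the companion helper defined below):
#
#     scaler = create_grad_scaler(enabled=True, init_scale=1.0)
#     with torch_autocast(device_type="cuda"):
#         loss = compute_loss(batch)
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()
#     optimizer.zero_grad()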
@contextmanager
def torch_autocast(device_type="cuda", **kwargs):
"""