mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
fix grad scaler
This commit is contained in:
parent
f186e1d427
commit
a1277c9ae9
4
.github/workflows/librispeech.yml
vendored
4
.github/workflows/librispeech.yml
vendored
@ -30,8 +30,8 @@ jobs:
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
|
||||
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
librispeech:
|
||||
needs: generate_build_matrix
|
||||
|
@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -82,12 +81,13 @@ from icefall.env import get_env_info
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -421,7 +421,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -630,7 +630,7 @@ def train_one_epoch(
|
||||
scheduler: LRSchedulerType,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -966,7 +966,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import CTCModel
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -93,6 +92,7 @@ from icefall.env import get_env_info
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
create_grad_scaler,
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
@ -494,7 +494,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -695,7 +695,7 @@ def train_one_epoch(
|
||||
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1005,7 +1005,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -96,6 +95,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
@ -566,7 +566,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -733,7 +733,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1008,7 +1008,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -96,6 +95,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
@ -566,7 +566,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -733,7 +733,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1007,7 +1007,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -66,7 +66,6 @@ from lstm import RNN
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -82,6 +81,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -522,7 +522,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -718,7 +718,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1024,7 +1024,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -74,7 +74,6 @@ from lstm import RNN
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -88,12 +87,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -561,7 +561,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -773,7 +773,7 @@ def train_one_epoch(
|
||||
giga_train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1177,7 +1177,7 @@ def run(rank, world_size, args):
|
||||
else:
|
||||
logging.info("Skip scan_pessimistic_batches_for_oom")
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -66,7 +66,6 @@ from lstm import RNN
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -80,12 +79,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -552,7 +552,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -748,7 +748,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1068,7 +1068,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -10,10 +10,10 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
from scaling import ScaledLinear
|
||||
from torch import Tensor, nn
|
||||
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch_scheduled_sampling import sample_combined
|
||||
|
||||
from icefall.utils import torch_autocast
|
||||
from icefall.utils import create_grad_scaler, torch_autocast
|
||||
|
||||
# The main exports of this file are the module KnowledgeBaseLookup and the
|
||||
# function create_knowledge_base.
|
||||
@ -332,7 +332,7 @@ def _test_knowledge_base_lookup_autocast():
|
||||
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
|
||||
m = m.to(device)
|
||||
|
||||
scaler = GradScaler(enabled=True)
|
||||
scaler = create_grad_scaler(enabled=True)
|
||||
|
||||
start = timeit.default_timer()
|
||||
|
||||
|
@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -77,6 +76,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
create_grad_scaler,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
@ -459,7 +459,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -614,7 +614,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -874,7 +874,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from noam import Noam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -71,6 +70,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
@ -502,7 +502,7 @@ def save_checkpoint(
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, and training stats to file.
|
||||
@ -656,7 +656,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -945,7 +945,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -89,12 +88,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -524,7 +524,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -717,7 +717,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -1001,7 +1001,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -74,7 +74,6 @@ from librispeech import LibriSpeech
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -85,12 +84,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -547,7 +547,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -756,7 +756,7 @@ def train_one_epoch(
|
||||
giga_train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -1127,7 +1127,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -94,12 +93,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -549,7 +549,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -745,7 +745,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1048,7 +1048,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -82,12 +81,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -572,7 +572,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -769,7 +769,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1079,7 +1079,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -96,10 +95,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
torch_autocast,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -520,7 +520,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -737,7 +737,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1040,7 +1040,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -679,7 +678,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -858,7 +857,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1220,7 +1219,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -84,13 +83,14 @@ from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
symlink_or_copy,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -582,7 +582,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -764,7 +764,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1107,7 +1107,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -86,6 +85,7 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -589,7 +589,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -788,7 +788,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1129,7 +1129,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -63,7 +63,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -582,7 +581,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -779,7 +778,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1119,7 +1118,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -85,6 +84,7 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
@ -603,7 +603,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -770,7 +770,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1130,7 +1130,7 @@ def run(rank, world_size, args):
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -70,7 +70,6 @@ from librispeech import LibriSpeech
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -89,6 +88,7 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
@ -621,7 +621,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -801,7 +801,7 @@ def train_one_epoch(
|
||||
giga_train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1224,7 +1224,7 @@ def run(rank, world_size, args):
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -75,7 +75,6 @@ from librispeech import LibriSpeech
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -94,6 +93,7 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
@ -614,7 +614,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -796,7 +796,7 @@ def train_one_epoch(
|
||||
giga_train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1225,7 +1225,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -51,7 +51,6 @@ from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
@ -70,12 +69,13 @@ from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import UniqLexicon
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
||||
@ -551,7 +551,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -758,7 +758,7 @@ def train_one_epoch(
|
||||
phone_lexicon: UniqLexicon,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1093,7 +1093,7 @@ def run(rank, world_size, args):
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
@ -95,6 +94,7 @@ from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
create_grad_scaler,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_parameter_groups_with_lrs,
|
||||
@ -809,7 +809,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -986,7 +986,7 @@ def train_one_epoch(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dls: torch.utils.data.DataLoader,
|
||||
valid_sets: List[str],
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1374,7 +1374,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -79,7 +79,6 @@ from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
@ -96,12 +95,13 @@ from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -830,7 +830,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -1035,7 +1035,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
spec_augment: Optional[SpecAugment] = None,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
@ -1448,7 +1448,7 @@ def run(rank, world_size, args):
|
||||
spec_augment=spec_augment,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -67,7 +67,6 @@ from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
@ -86,6 +85,7 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -806,7 +806,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -983,7 +983,7 @@ def train_one_epoch(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dls: torch.utils.data.DataLoader,
|
||||
valid_sets: List[str],
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1398,7 +1398,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -46,7 +46,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import CTCModel
|
||||
from optim import Eden, LRScheduler, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@ -68,6 +67,7 @@ from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
@ -539,7 +539,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -693,7 +693,7 @@ def train_one_epoch(
|
||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -993,7 +993,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
@ -96,6 +95,7 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -819,7 +819,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -996,7 +996,7 @@ def train_one_epoch(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dls: torch.utils.data.DataLoader,
|
||||
valid_sets: List[str],
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1407,7 +1407,7 @@ def run(rank, world_size, args):
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -76,7 +76,6 @@ from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
@ -94,6 +93,7 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -708,7 +708,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -884,7 +884,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1253,7 +1253,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import CTCModel
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -91,6 +90,7 @@ from icefall.utils import (
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
create_grad_scaler,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -515,7 +515,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -697,7 +697,7 @@ def train_one_epoch(
|
||||
mmi_graph_compiler: MmiTrainingGraphCompiler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -1038,7 +1038,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
@ -27,7 +27,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
|
||||
@ -43,7 +42,7 @@ def save_checkpoint(
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
@ -102,7 +101,7 @@ def load_checkpoint(
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
strict: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
@ -201,7 +200,7 @@ def save_checkpoint_with_global_batch_idx(
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
):
|
||||
|
@ -57,6 +57,25 @@ Pathlike = Union[str, Path]
|
||||
TORCH_VERSION = version.parse(torch.__version__)
|
||||
|
||||
|
||||
def create_grad_scaler(device="cuda", **kwargs):
|
||||
"""
|
||||
Creates a GradScaler compatible with both torch < 2.0 and >= 2.0.
|
||||
Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
|
||||
|
||||
/icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
|
||||
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
|
||||
`torch.amp.GradScaler('cuda', args...)` instead.
|
||||
"""
|
||||
if TORCH_VERSION >= version.parse("2.0.0"):
|
||||
from torch.amp import GradScaler
|
||||
|
||||
return GradScaler(device=device, **kwargs)
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", category=FutureWarning)
|
||||
return torch.cuda.amp.GradScaler(**kwargs)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_autocast(device_type="cuda", **kwargs):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user