Replace deprecated pytorch methods (#1814)

* Replace deprecated pytorch methods

- torch.cuda.amp.GradScaler(...) => torch.amp.GradScaler("cuda", ...)
- torch.cuda.amp.autocast(...) => torch.amp.autocast("cuda", ...)

* Replace `with autocast(...)` with `with autocast("cuda", ...)`
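Illustrative migration sketch (not part of the diff below): a minimal CUDA mixed-precision step written against the new `torch.amp` entry points, assuming PyTorch >= 2.3 where `torch.amp.GradScaler` and `torch.amp.autocast` take the device type as their first argument; the toy model, optimizer, and `use_fp16` flag are made up for the example.

```python
import torch
from torch.amp import GradScaler, autocast  # previously: torch.cuda.amp.GradScaler / torch.cuda.amp.autocast

use_fp16 = True  # illustrative stand-in for params.use_fp16 in the recipes

# Old, now-deprecated spellings:
#   scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
#   with torch.cuda.amp.autocast(enabled=use_fp16): ...
# New spellings take the device type ("cuda") as the first argument:
scaler = GradScaler("cuda", enabled=use_fp16)

model = torch.nn.Linear(8, 8).cuda()          # toy model; CUDA is required for this sketch
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 8, device="cuda")
with autocast("cuda", enabled=use_fp16):      # forward pass runs in mixed precision
    loss = model(x).sum()

scaler.scale(loss).backward()                 # scale the loss before backward
scaler.step(optimizer)                        # unscale gradients and step
scaler.update()
optimizer.zero_grad()
```

The scaler's checkpoint format is unaffected, which is why the diffs below leave `scaler.load_state_dict(checkpoints["grad_scaler"])` untouched.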


Co-authored-by: Li Peng <lipeng@unisound.ai>
Li Peng 2024-12-16 10:24:16 +08:00 committed by GitHub
parent d475de5600
commit 3e4da5f781
147 changed files with 520 additions and 518 deletions

View File

@@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -638,7 +638,7 @@ def train_one_epoch(
 params.batch_idx_train += 1
 batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -843,7 +843,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16)
+scaler = GradScaler("cuda", enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -912,7 +912,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -60,7 +60,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -688,7 +688,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -888,7 +888,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16)
+scaler = GradScaler("cuda", enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -989,7 +989,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -184,7 +184,7 @@ class Transducer(nn.Module):
 lm = simple_lm_proj(decoder_out)
 am = simple_am_proj(encoder_out)
-with torch.cuda.amp.autocast(enabled=False):
+with torch.amp.autocast("cuda", enabled=False):
 simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
 lm=lm.float(),
 am=am.float(),
@@ -219,7 +219,7 @@ class Transducer(nn.Module):
 # prior to do_rnnt_pruning (this is an optimization for speed).
 logits = joiner(am_pruned, lm_pruned, project_input=False)
-with torch.cuda.amp.autocast(enabled=False):
+with torch.amp.autocast("cuda", enabled=False):
 pruned_loss = k2.rnnt_loss_pruned(
 logits=logits.float(),
 symbols=y_padded,

View File

@@ -79,7 +79,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -797,7 +797,7 @@ def train_one_epoch(
 aishell = is_aishell(batch["supervisions"]["cut"][0])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1096,7 +1096,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16)
+scaler = GradScaler("cuda", enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -74,7 +74,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -812,7 +812,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1107,7 +1107,7 @@ def run(rank, world_size, args):
 # params=params,
 # )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -70,7 +70,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -809,7 +809,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1107,7 +1107,7 @@ def run(rank, world_size, args):
 # params=params,
 # )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -802,7 +802,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1102,7 +1102,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer_for_ncnn_export_only import Zipformer
@@ -813,7 +813,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1105,7 +1105,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1205,7 +1205,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -812,7 +812,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1104,7 +1104,7 @@ def run(rank, world_size, args):
 # params=params,
 # )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -62,7 +62,7 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.functional import pad as pad_tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -514,7 +514,7 @@ def compute_validation_loss(
 tot_loss = MetricsTracker()
 for batch_idx, batch in enumerate(valid_dl):
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 tokenizer=tokenizer,
@@ -608,7 +608,7 @@ def train_one_epoch(
 )
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 tokenizer=tokenizer,
@@ -812,7 +812,7 @@ def run(rank, world_size, args):
 train_dl = aishell.train_dataloaders(aishell.train_cuts())
 valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@@ -71,7 +71,7 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -910,7 +910,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1201,7 +1201,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1302,7 +1302,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -61,7 +61,7 @@ from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from train import (
@@ -495,7 +495,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -795,7 +795,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -895,7 +895,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -75,7 +75,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -734,7 +734,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -963,7 +963,7 @@ def run(rank, world_size, args):
 warmup=0.0 if params.start_epoch == 1 else 1.0,
 )
-scaler = GradScaler(enabled=params.use_fp16)
+scaler = GradScaler("cuda", enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1062,7 +1062,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -68,7 +68,7 @@ from local.text_normalize import text_normalize
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -727,7 +727,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 # print(batch["supervisions"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -963,7 +963,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16)
+scaler = GradScaler("cuda", enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1034,7 +1034,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -638,7 +638,7 @@ def train_one_epoch(
 params.batch_idx_train += 1
 batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -843,7 +843,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16)
+scaler = GradScaler("cuda", enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -912,7 +912,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -782,7 +782,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
 # params=params,
 # )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1127,7 +1127,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -773,7 +773,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1034,7 +1034,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1134,7 +1134,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -61,7 +61,7 @@ from model import SURT
 from optim import Eden, ScaledAdam
 from scaling import ScaledLinear, ScaledLSTM
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -1067,7 +1067,7 @@ def train_one_epoch(
 batch_size = batch["inputs"].shape[0]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1314,7 +1314,7 @@ def run(rank, world_size, args):
 )
 valid_dl = ami.valid_dataloaders(dev_cuts)
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@@ -61,7 +61,7 @@ from model import SURT
 from optim import Eden, ScaledAdam
 from scaling import ScaledLinear, ScaledLSTM
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -1058,7 +1058,7 @@ def train_one_epoch(
 batch_size = batch["inputs"].shape[0]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1305,7 +1305,7 @@ def run(rank, world_size, args):
 )
 valid_dl = ami.valid_dataloaders(dev_cuts)
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@@ -53,7 +53,7 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -799,7 +799,7 @@ def train_one_epoch(
 num_samples += batch_size
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1057,7 +1057,7 @@ def run(rank, world_size, args):
 valid_cuts = audioset.audioset_eval_cuts()
 valid_dl = audioset.valid_dataloaders(valid_cuts)
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1148,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -825,7 +825,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1120,7 +1120,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1220,7 +1220,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer_for_ncnn_export_only import Zipformer
@@ -818,7 +818,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1109,7 +1109,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1209,7 +1209,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -895,7 +895,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1193,7 +1193,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1293,7 +1293,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -840,7 +840,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1137,7 +1137,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1237,7 +1237,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -969,7 +969,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1265,7 +1265,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1365,7 +1365,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -67,7 +67,7 @@ from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from train import (
@@ -604,7 +604,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -784,7 +784,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,
@@ -979,7 +979,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@@ -67,7 +67,7 @@ from model import Transducer
 from optim import Eden, ScaledAdam
 from tokenizer import Tokenizer
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer_for_ncnn_export_only import Zipformer
@@ -839,7 +839,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1146,7 +1146,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1246,7 +1246,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -67,7 +67,7 @@ from model import Transducer
 from optim import Eden, ScaledAdam
 from tokenizer import Tokenizer
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -838,7 +838,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1145,7 +1145,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1245,7 +1245,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -675,7 +675,7 @@ def train_one_epoch(
 params.batch_idx_train += 1
 batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -873,7 +873,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16)
+scaler = GradScaler("cuda", enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -944,7 +944,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

View File

@@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -958,7 +958,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,
@@ -1217,7 +1217,7 @@ def run(rank, world_size, args):
 params=params,
 )
-scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1317,7 +1317,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch.amp.autocast("cuda", enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,
View File
@ -73,7 +73,7 @@ from lhotse.utils import fix_random_seed
from model import AsrModel from model import AsrModel
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from train import ( from train import (
@ -291,7 +291,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -570,7 +570,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
View File
@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -961,7 +961,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1220,7 +1220,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -61,7 +61,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -805,7 +805,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1096,7 +1096,7 @@ def run(rank, world_size, args):
# params=params, # params=params,
# ) # )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1196,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -70,7 +70,7 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -942,7 +942,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1233,7 +1233,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1333,7 +1333,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -140,7 +140,7 @@ class SURT(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -175,7 +175,7 @@ class SURT(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
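In the model hunks above, autocast is opened with enabled=False around the k2 transducer losses, so they run in float32 even when the surrounding forward pass is under fp16 autocast, and their inputs are cast explicitly with .float(). A generic sketch of that pattern, with a plain cross-entropy loss standing in for the k2 calls (not code from this repo):

    import torch
    import torch.nn.functional as F

    def float32_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Even if the caller is inside torch.amp.autocast("cuda", enabled=True),
        # this block leaves the autocast region and computes in full precision.
        with torch.amp.autocast("cuda", enabled=False):
            return F.cross_entropy(logits.float(), targets)

The explicit .float() matters because tensors produced under autocast may already be float16; disabling autocast by itself does not convert them back.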
View File
@ -287,7 +287,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -1065,7 +1065,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
def backward(ctx, x_grad: Tensor): def backward(ctx, x_grad: Tensor):
(x_orig,) = ctx.saved_tensors (x_orig,) = ctx.saved_tensors
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -1263,7 +1263,7 @@ class MaxEig(torch.nn.Module):
): ):
return _no_op(x) return _no_op(x)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
eps = 1.0e-20 eps = 1.0e-20
orig_x = x orig_x = x
x = x.to(torch.float32) x = x.to(torch.float32)
View File
@ -69,7 +69,7 @@ from model import SURT
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from scaling import ScaledLSTM from scaling import ScaledLSTM
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -1096,7 +1096,7 @@ def train_one_epoch(
batch_size = batch["inputs"].shape[0] batch_size = batch["inputs"].shape[0]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1343,7 +1343,7 @@ def run(rank, world_size, args):
train_dl_ov40 = libricss.train_dataloaders(train_cuts_ov40) train_dl_ov40 = libricss.train_dataloaders(train_cuts_ov40)
valid_dl = libricss.valid_dataloaders(dev_cuts) valid_dl = libricss.valid_dataloaders(dev_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
View File
@ -67,7 +67,7 @@ from model import SURT
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from scaling import ScaledLinear, ScaledLSTM from scaling import ScaledLinear, ScaledLSTM
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -985,7 +985,7 @@ def train_one_epoch(
batch_size = batch["inputs"].shape[0] batch_size = batch["inputs"].shape[0]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1237,7 +1237,7 @@ def run(rank, world_size, args):
) )
valid_dl = libricss.valid_dataloaders(dev_cuts) valid_dl = libricss.valid_dataloaders(dev_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
View File
@ -78,7 +78,7 @@ from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from text_normalization import remove_punc_to_upper from text_normalization import remove_punc_to_upper
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -958,7 +958,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1268,7 +1268,7 @@ def run(rank, world_size, args):
# params=params, # params=params,
# ) # )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1367,7 +1367,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -186,7 +186,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -221,7 +221,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
View File
@ -245,7 +245,7 @@ class PromptedTransducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -287,7 +287,7 @@ class PromptedTransducer(nn.Module):
logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False) logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
View File
@ -271,7 +271,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -685,7 +685,7 @@ class BalancerFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x = x.to(torch.float32) x = x.to(torch.float32)
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -940,7 +940,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -1280,7 +1280,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08 coeff = -0.08
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1351,7 +1351,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
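The scaling.py hunks apply the same substitution inside custom autograd Functions, where backward() re-runs part of the computation in float32 with autocast switched off and gradients re-enabled. A self-contained sketch of that structure follows; the SiLU nonlinearity is just an arbitrary example op, not code from this repo.

    import torch
    from torch import Tensor

    class SiluWithFloat32Backward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor) -> Tensor:
            ctx.save_for_backward(x)
            return x * torch.sigmoid(x)

        @staticmethod
        def backward(ctx, grad_output: Tensor) -> Tensor:
            (x,) = ctx.saved_tensors
            # backward() runs with grad mode off and possibly under autocast,
            # so re-enable grad and force float32 for the recomputation.
            with torch.enable_grad():
                with torch.amp.autocast("cuda", enabled=False):
                    x_detached = x.to(torch.float32).detach()
                    x_detached.requires_grad = True
                    y = x_detached * torch.sigmoid(x_detached)
                    y.backward(gradient=grad_output.to(torch.float32))
            return x_detached.grad.to(grad_output.dtype)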
View File
@ -89,7 +89,7 @@ from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from text_normalization import train_text_normalization, upper_only_alpha from text_normalization import train_text_normalization, upper_only_alpha
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -975,7 +975,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1271,7 +1271,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1371,7 +1371,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -103,7 +103,7 @@ from text_normalization import (
upper_only_alpha, upper_only_alpha,
) )
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -1321,7 +1321,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1647,7 +1647,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1749,7 +1749,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -1561,7 +1561,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)
@ -1844,7 +1844,7 @@ class MultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)
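The attention hunks above disable autocast (together with no_grad) before computing an entropy diagnostic over the attention weights, since taking the log of near-zero probabilities is fragile in half precision. The same pattern in isolation; the final reduction here is only illustrative, because the hunk is cut off before it:

    import torch

    def attn_weights_entropy(attn_weights: torch.Tensor) -> torch.Tensor:
        # attn_weights: (num_heads, batch_size, seq_len, seq_len), rows summing to 1.
        with torch.no_grad():
            with torch.amp.autocast("cuda", enabled=False):
                w = attn_weights.to(torch.float32)
                return -((w + 1.0e-20).log() * w).sum(dim=-1).mean(dim=(1, 2))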
View File
@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed
from model import AsrModel from model import AsrModel
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -1116,7 +1116,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1407,7 +1407,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1505,7 +1505,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -57,7 +57,7 @@ from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from ssl_datamodule import LibriLightDataModule from ssl_datamodule import LibriLightDataModule
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -936,7 +936,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0] batch_size = batch["kmeans"].shape[0]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1229,7 +1229,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -65,7 +65,7 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -676,7 +676,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -965,7 +965,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1036,7 +1036,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero # warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because # (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam. # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -76,7 +76,7 @@ from lhotse.utils import fix_random_seed
from model import CTCModel from model import CTCModel
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -743,7 +743,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1004,7 +1004,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0, warmup=0.0 if params.start_epoch == 1 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1073,7 +1073,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -772,7 +772,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1002,7 +1002,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0, warmup=0.0 if params.start_epoch == 1 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1071,7 +1071,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -774,7 +774,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1003,7 +1003,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1074,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero # warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because # (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam. # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -772,7 +772,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1001,7 +1001,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1072,7 +1072,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero # warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because # (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam. # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -156,7 +156,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -192,7 +192,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
View File
@ -66,7 +66,7 @@ from lstm import RNN
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -763,7 +763,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1023,7 +1023,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0, warmup=0.0 if params.start_epoch == 1 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1092,7 +1092,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out) lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out) am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False) logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
View File
@ -74,7 +74,7 @@ from lstm import RNN
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -848,7 +848,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0]) libri = is_libri(batch["supervisions"]["cut"][0])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1176,7 +1176,7 @@ def run(rank, world_size, args):
else: else:
logging.info("Skip scan_pessimistic_batches_for_oom") logging.info("Skip scan_pessimistic_batches_for_oom")
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1247,7 +1247,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -66,7 +66,7 @@ from lstm import RNN
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -793,7 +793,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1067,7 +1067,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0, warmup=0.0 if params.start_epoch == 1 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1136,7 +1136,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -141,7 +141,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -176,7 +176,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
View File
@ -10,7 +10,7 @@ from typing import Optional, Tuple
import torch import torch
from scaling import ScaledLinear from scaling import ScaledLinear
from torch import Tensor, nn from torch import Tensor, nn
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch.amp import GradScaler, custom_bwd, custom_fwd
from torch_scheduled_sampling import sample_combined from torch_scheduled_sampling import sample_combined
# The main exports of this file are the module KnowledgeBaseLookup and the # The main exports of this file are the module KnowledgeBaseLookup and the
@ -330,14 +330,14 @@ def _test_knowledge_base_lookup_autocast():
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
m = m.to(device) m = m.to(device)
scaler = GradScaler(enabled=True) scaler = GradScaler("cuda", enabled=True)
start = timeit.default_timer() start = timeit.default_timer()
for epoch in range(150): for epoch in range(150):
for n, (x, y) in enumerate(train_pairs): for n, (x, y) in enumerate(train_pairs):
y_out = m(x) y_out = m(x)
with torch.cuda.amp.autocast(enabled=True): with torch.amp.autocast("cuda", enabled=True):
loss = ((y_out - y) ** 2).mean() * 100.0 loss = ((y_out - y) ** 2).mean() * 100.0
if n % 10 == 0 and epoch % 10 == 0: if n % 10 == 0 and epoch % 10 == 0:
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
View File
@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -650,7 +650,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -868,7 +868,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -937,7 +937,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero # warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because # (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam. # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from noam import Noam from noam import Noam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -693,7 +693,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -939,7 +939,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1004,7 +1004,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -157,7 +157,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -193,7 +193,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
View File
@ -78,7 +78,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -759,7 +759,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1000,7 +1000,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 0 else 1.0, warmup=0.0 if params.start_epoch == 0 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1067,7 +1067,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
View File
@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out) lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out) am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False) logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
View File
@ -74,7 +74,7 @@ from librispeech import LibriSpeech
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -827,7 +827,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0]) libri = is_libri(batch["supervisions"]["cut"][0])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1126,7 +1126,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 0 else 1.0, warmup=0.0 if params.start_epoch == 0 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1195,7 +1195,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -789,7 +789,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1047,7 +1047,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0, warmup=0.0 if params.start_epoch == 1 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1116,7 +1116,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -814,7 +814,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1078,7 +1078,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0, warmup=0.0 if params.start_epoch == 1 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1147,7 +1147,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -185,7 +185,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -220,7 +220,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -781,7 +781,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1039,7 +1039,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0, warmup=0.0 if params.start_epoch == 1 else 1.0,
) )
scaler = GradScaler(enabled=params.use_fp16) scaler = GradScaler("cuda", enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1108,7 +1108,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -903,7 +903,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1219,7 +1219,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1319,7 +1319,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -150,7 +150,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
@ -289,7 +289,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -669,7 +669,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
def backward(ctx, x_grad: Tensor): def backward(ctx, x_grad: Tensor):
(x_orig,) = ctx.saved_tensors (x_orig,) = ctx.saved_tensors
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -867,7 +867,7 @@ class MaxEig(torch.nn.Module):
): ):
return _no_op(x) return _no_op(x)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
eps = 1.0e-20 eps = 1.0e-20
orig_x = x orig_x = x
x = x.to(torch.float32) x = x.to(torch.float32)
@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -809,7 +809,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1106,7 +1106,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads bsz = n // num_heads
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32) attn_output = attn_output.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
@ -150,7 +150,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -833,7 +833,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1128,7 +1128,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1228,7 +1228,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -178,7 +178,7 @@ class Transducer(nn.Module):
am = self.simple_am_proj(encoder_out_fr) am = self.simple_am_proj(encoder_out_fr)
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -213,7 +213,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -822,7 +822,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1118,7 +1118,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1217,7 +1217,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer_for_ncnn_export_only import Zipformer from zipformer_for_ncnn_export_only import Zipformer
@ -811,7 +811,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1106,7 +1106,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -810,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1124,7 +1124,7 @@ def run(rank, world_size, args):
# params=params, # params=params,
# ) # )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1224,7 +1224,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads bsz = n // num_heads
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32) attn_output = attn_output.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
@ -2708,7 +2708,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads bsz = n // num_heads
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32) attn_output = attn_output.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
@ -70,7 +70,7 @@ from librispeech import LibriSpeech
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -866,7 +866,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0]) libri = is_libri(batch["supervisions"]["cut"][0])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1218,7 +1218,7 @@ def run(rank, world_size, args):
# params=params, # params=params,
# ) # )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -172,7 +172,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -207,7 +207,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False) logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
@ -75,7 +75,7 @@ from librispeech import LibriSpeech
from model import Transducer from model import Transducer
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -866,7 +866,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0]) libri = is_libri(batch["supervisions"]["cut"][0])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1219,7 +1219,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1321,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -51,7 +51,7 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
@ -809,7 +809,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1092,7 +1092,7 @@ def run(rank, world_size, args):
# params=params, # params=params,
# ) # )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1198,7 +1198,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -78,7 +78,7 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -1049,7 +1049,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1373,7 +1373,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1474,7 +1474,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -285,7 +285,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -320,7 +320,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
@ -306,7 +306,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -759,7 +759,7 @@ class BalancerFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x = x.to(torch.float32) x = x.to(torch.float32)
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1014,7 +1014,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -1353,7 +1353,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08 coeff = -0.08
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1430,7 +1430,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -79,7 +79,7 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -1101,7 +1101,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast( with torch.amp.autocast("cuda",
enabled=params.use_autocast, dtype=params.dtype enabled=params.use_autocast, dtype=params.dtype
): ):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
@ -1438,7 +1438,7 @@ def run(rank, world_size, args):
spec_augment=spec_augment, spec_augment=spec_augment,
) )
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_autocast, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1540,7 +1540,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast( with torch.amp.autocast("cuda",
enabled=params.use_autocast, dtype=params.dtype enabled=params.use_autocast, dtype=params.dtype
): ):
loss, _ = compute_loss( loss, _ = compute_loss(
@ -1873,7 +1873,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)
@ -67,7 +67,7 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -1052,7 +1052,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1397,7 +1397,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1498,7 +1498,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1916,7 +1916,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)
@ -46,7 +46,7 @@ from lhotse.utils import fix_random_seed
from model import CTCModel from model import CTCModel
from optim import Eden, LRScheduler, ScaledAdam from optim import Eden, LRScheduler, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -726,7 +726,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -987,7 +987,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -78,7 +78,7 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -1065,7 +1065,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1406,7 +1406,7 @@ def run(rank, world_size, args):
# params=params, # params=params,
# ) # )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1507,7 +1507,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -307,7 +307,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -863,7 +863,7 @@ class BalancerFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x = x.to(torch.float32) x = x.to(torch.float32)
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1118,7 +1118,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -1457,7 +1457,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08 coeff = -0.08
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1534,7 +1534,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -76,7 +76,7 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -947,7 +947,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1252,7 +1252,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1352,7 +1352,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1905,7 +1905,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)
@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed
from model import CTCModel from model import CTCModel
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer from zipformer import Zipformer
@ -744,7 +744,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1037,7 +1037,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1138,7 +1138,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed
from model import AsrModel from model import AsrModel
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -816,7 +816,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1109,7 +1109,7 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
@ -1207,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
Some files were not shown because too many files have changed in this diff.
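For reference, the following is a minimal standalone sketch of the replacement pattern applied throughout this diff. It assumes a CUDA-capable device and a recent PyTorch release in which the torch.cuda.amp wrappers are deprecated in favour of the device-aware torch.amp API; the model, optimizer, and tensors below are illustrative placeholders, not code from this repository.

    import torch
    from torch.amp import GradScaler, autocast

    # Placeholder model and optimizer, only to make the sketch runnable.
    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # was: scaler = torch.cuda.amp.GradScaler(enabled=True)
    scaler = GradScaler("cuda", enabled=True)

    x = torch.randn(8, 16, device="cuda")
    y = torch.randn(8, 4, device="cuda")

    # was: with torch.cuda.amp.autocast(enabled=True):
    with autocast("cuda", enabled=True):
        loss = torch.nn.functional.mse_loss(model(x), y)

    # Standard mixed-precision step: scale the loss, step, then update the scaler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()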