Fix CI warnings

parent ffe2f16b1d
commit a53c323750

Changed file: .github/workflows/yesno.yml (vendored), 4 lines changed
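Note on the change: the Python hunks below replace the deprecated torch.cuda.amp.autocast(...) context manager with a torch_autocast helper imported from icefall.utils, which is what silences the deprecation warnings reported by CI (recent PyTorch emits a FutureWarning for torch.cuda.amp.autocast and recommends torch.amp.autocast("cuda", ...)). The helper itself is not part of this diff; the following is only a minimal sketch of what such a wrapper could look like, not the actual icefall.utils implementation:

# Hypothetical sketch of a torch_autocast wrapper (assumed, not taken from this diff).
# It prefers the newer torch.amp.autocast API and falls back to the legacy
# torch.cuda.amp.autocast on PyTorch versions that do not provide torch.amp.autocast.
from contextlib import contextmanager

import torch


@contextmanager
def torch_autocast(device_type: str = "cuda", **kwargs):
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        # New-style API: the device type is passed explicitly.
        with torch.amp.autocast(device_type, **kwargs):
            yield
    else:
        # Legacy CUDA-specific API that newer PyTorch releases warn about.
        with torch.cuda.amp.autocast(**kwargs):
            yield

With a wrapper of this shape, call sites keep their old form, e.g. "with torch_autocast(enabled=params.use_fp16):", while targeting the non-deprecated API on current PyTorch.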
@@ -31,8 +31,8 @@ jobs:
       run: |
         # outputting for debugging purposes
         python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
-        # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+        # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
         echo "::set-output name=matrix::${MATRIX}"
   yesno:
     needs: generate_build_matrix

@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -638,7 +644,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,
@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -688,7 +694,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -184,7 +184,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -219,7 +219,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@ -94,7 +94,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -797,7 +803,7 @@ def train_one_epoch(
|
||||
aishell = is_aishell(batch["supervisions"]["cut"][0])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -94,6 +94,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -809,7 +810,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -87,6 +87,7 @@ from icefall.utils import (
|
||||
setup_logger,
|
||||
str2bool,
|
||||
tokenize_by_CJK_char,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -802,7 +803,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -81,7 +81,13 @@ from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -812,7 +818,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -81,6 +81,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -514,7 +515,7 @@ def compute_validation_loss(
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
@ -608,7 +609,7 @@ def train_one_epoch(
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -95,6 +95,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -910,7 +911,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -92,6 +92,7 @@ from icefall.utils import (
|
||||
setup_logger,
|
||||
str2bool,
|
||||
tokenize_by_CJK_char,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
|
||||
@ -495,7 +496,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -90,7 +90,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -734,7 +740,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -83,7 +83,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -727,7 +733,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
# print(batch["supervisions"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -638,7 +644,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -73,7 +73,13 @@ from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -782,7 +788,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -773,7 +779,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -76,7 +76,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -1067,7 +1073,7 @@ def train_one_epoch(
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -76,7 +76,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -1058,7 +1064,7 @@ def train_one_epoch(
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -74,6 +74,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -799,7 +800,7 @@ def train_one_epoch(
|
||||
num_samples += batch_size
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -88,6 +88,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -825,7 +826,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -90,6 +90,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -895,7 +896,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -840,7 +846,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -969,7 +970,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -604,7 +605,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
LOG_EPS = math.log(1e-10)
|
||||
@ -838,7 +844,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -77,7 +77,13 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -675,7 +681,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -958,7 +959,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -961,7 +962,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -805,7 +811,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -92,6 +92,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -942,7 +943,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -82,6 +82,7 @@ from icefall.env import get_env_info
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
encode_supervisions,
|
||||
@ -676,7 +677,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -93,6 +93,7 @@ from icefall.env import get_env_info
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
encode_supervisions,
|
||||
@ -743,7 +744,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -93,7 +93,13 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -772,7 +778,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -93,7 +93,13 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -772,7 +778,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -156,7 +156,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -192,7 +192,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -85,6 +85,7 @@ from icefall.utils import (
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -763,7 +764,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -195,7 +195,7 @@ class Transducer(nn.Module):
|
||||
lm = simple_lm_proj(decoder_out)
|
||||
am = simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -231,7 +231,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -88,6 +88,7 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
display_and_save_batch,
|
||||
@ -848,7 +849,7 @@ def train_one_epoch(
|
||||
libri = is_libri(batch["supervisions"]["cut"][0])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -80,6 +80,7 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
display_and_save_batch,
|
||||
@ -793,7 +794,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -141,7 +141,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -176,7 +176,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -13,6 +13,8 @@ from torch import Tensor, nn
|
||||
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
|
||||
from torch_scheduled_sampling import sample_combined
|
||||
|
||||
from icefall.utils import torch_autocast
|
||||
|
||||
# The main exports of this file are the module KnowledgeBaseLookup and the
|
||||
# function create_knowledge_base.
|
||||
|
||||
@ -337,7 +339,7 @@ def _test_knowledge_base_lookup_autocast():
|
||||
for epoch in range(150):
|
||||
for n, (x, y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
with torch_autocast(enabled=True):
|
||||
loss = ((y_out - y) ** 2).mean() * 100.0
|
||||
if n % 10 == 0 and epoch % 10 == 0:
|
||||
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
|
||||
|
@ -76,7 +76,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -650,7 +656,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -68,7 +68,13 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
@ -693,7 +699,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -157,7 +157,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -193,7 +193,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
display_and_save_batch,
|
||||
@ -759,7 +760,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -195,7 +195,7 @@ class Transducer(nn.Module):
|
||||
lm = simple_lm_proj(decoder_out)
|
||||
am = simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -231,7 +231,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -85,6 +85,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
display_and_save_batch,
|
||||
@ -827,7 +828,7 @@ def train_one_epoch(
|
||||
|
||||
libri = is_libri(batch["supervisions"]["cut"][0])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -94,6 +94,7 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
display_and_save_batch,
|
||||
@ -789,7 +790,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -82,6 +82,7 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
torch_autocast,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
display_and_save_batch,
|
||||
@ -814,7 +815,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -185,7 +185,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -220,7 +220,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -96,6 +96,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
torch_autocast,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
@ -781,7 +782,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -88,6 +88,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -903,7 +904,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import penalize_abs_values_gt
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -150,7 +150,7 @@ class Transducer(nn.Module):
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -185,7 +185,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@@ -28,6 +28,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding

+from icefall.utils import torch_autocast
+

 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
@@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         (ans,) = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             ans_grad = ans_grad.to(torch.float32)
             ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
@@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
     def backward(ctx, x_grad: Tensor):
         (x_orig,) = ctx.saved_tensors
         with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
+            with torch_autocast(enabled=False):
                 x_detached = x_orig.to(torch.float32).detach()
                 x_detached.requires_grad = True

@@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module):
         ):
             return _no_op(x)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             eps = 1.0e-20
             orig_x = x
             x = x.to(torch.float32)
@ -84,6 +84,7 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
filter_uneven_sized_batch,

@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -44,7 +44,7 @@ from scaling import (
from torch import Tensor, nn

from icefall.dist import get_rank
from icefall.utils import is_jit_tracing, make_pad_mask
from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast


class Zipformer(EncoderInterface):

@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads

with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = (

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from encoder_interface import EncoderInterface

from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast


class Transducer(nn.Module):

@ -150,7 +150,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),

@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

@ -89,6 +89,7 @@ from icefall.utils import (
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -833,7 +834,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from encoder_interface import EncoderInterface

from icefall.utils import add_sos, make_pad_mask
from icefall.utils import add_sos, make_pad_mask, torch_autocast


class Transducer(nn.Module):

@ -178,7 +178,7 @@ class Transducer(nn.Module):
am = self.simple_am_proj(encoder_out_fr)
lm = self.simple_lm_proj(decoder_out)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),

@ -213,7 +213,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

@ -85,6 +85,7 @@ from icefall.utils import (
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -822,7 +823,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -82,7 +82,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -810,7 +816,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -44,7 +44,7 @@ from scaling import (
)
from torch import Tensor, nn

from icefall.utils import make_pad_mask, subsequent_chunk_mask
from icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast


def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]:

@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads

with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = (

@ -86,7 +86,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -24,7 +24,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt

from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast


class Transducer(nn.Module):

@ -172,7 +172,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),

@ -207,7 +207,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

@ -91,7 +91,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -70,6 +70,7 @@ from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import UniqLexicon
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
encode_supervisions,

@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -100,6 +100,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -1049,7 +1050,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface
from lhotse.dataset import SpecAugment
from scaling import ScaledLinear

from icefall.utils import add_sos, make_pad_mask, time_warp
from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast


class AsrModel(nn.Module):

@ -285,7 +285,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),

@ -320,7 +320,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

@ -26,6 +26,8 @@ import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

from icefall.utils import torch_autocast


def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y)

@ -308,7 +310,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans

@ -761,7 +763,7 @@ class BalancerFunction(torch.autograd.Function):

try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True

@ -1016,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):

try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True

@ -1355,7 +1357,7 @@ class SwooshLFunction(torch.autograd.Function):

coeff = -0.08

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True

@ -1432,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function):

zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True

@ -96,6 +96,7 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,

@ -1101,9 +1102,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1551,9 +1550,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
loss, _ = compute_loss(
params=params,
model=model,

@ -47,6 +47,8 @@ from scaling import (
)
from torch import Tensor, nn

from icefall.utils import torch_autocast


class Zipformer2(EncoderInterface):
"""

@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)

@ -89,6 +89,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -1052,7 +1053,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -50,6 +50,8 @@ from scaling import (
)
from torch import Tensor, nn

from icefall.utils import torch_autocast


class Zipformer2(EncoderInterface):
"""

@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)

@ -65,7 +65,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]

@ -726,7 +732,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -99,6 +99,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -1065,7 +1066,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -27,6 +27,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

from icefall.utils import torch_autocast


def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y)

@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans

@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function):

try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True

@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):

try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True

@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function):

coeff = -0.08

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True

@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function):

zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True

@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -947,7 +948,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -49,6 +49,8 @@ from scaling import (
)
from torch import Tensor, nn

from icefall.utils import torch_autocast


class Zipformer2(EncoderInterface):
"""

@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)

@ -90,6 +90,7 @@ from icefall.utils import (
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -744,7 +745,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -86,6 +86,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,

@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from scaling import ScaledLinear

from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast


class AsrModel(nn.Module):

@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),

@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

@ -75,6 +75,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,

@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -80,6 +80,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,

@ -1115,7 +1116,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from scaling import ScaledLinear

from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast


class AsrModel(nn.Module):

@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),

@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)

with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

@ -78,6 +78,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -944,7 +945,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -22,6 +22,7 @@ import math
import random
import warnings
from typing import List, Optional, Tuple, Union
from icefall.utils import torch_autocast

import torch
from encoder_interface import EncoderInterface

@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)

@ -84,6 +84,7 @@ from icefall.utils import (
get_texts,
setup_logger,
str2bool,
torch_autocast,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@ -757,7 +758,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])

with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -79,6 +79,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
encode_supervisions_otc,

@ -758,7 +759,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])

with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)


def get_parser():

@ -401,7 +407,7 @@ def compute_validation_loss(

for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,

@ -470,7 +476,7 @@ def train_one_epoch(
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,

@ -26,6 +26,7 @@ import pathlib
import random
import re
import subprocess
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass

@ -42,14 +43,40 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from packaging import version
from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter
from contextlib import contextmanager


from icefall.checkpoint import average_checkpoints

Pathlike = Union[str, Path]

TORCH_VERSION = version.parse(torch.__version__)


@contextmanager
def torch_autocast(device_type="cuda", **kwargs):
    """
    To fix the following warnings:
    /icefall/egs/librispeech/ASR/zipformer/model.py:323:
    FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
    Please use `torch.amp.autocast('cuda', args...)` instead.
      with torch.cuda.amp.autocast(enabled=False):
    """
    if TORCH_VERSION >= version.parse("2.0.0"):
        # Use new unified API
        with torch.amp.autocast(device_type=device_type, **kwargs):
            yield
    else:
        # Suppress deprecation warning and use old CUDA-specific autocast
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            with torch.cuda.amp.autocast(**kwargs):
                yield
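
For orientation, here is a minimal, hedged sketch of how the call sites above use the new helper in place of the deprecated torch.cuda.amp.autocast; the tensors and shapes are invented for illustration and are not part of this commit:

import torch
from icefall.utils import torch_autocast

# Illustrative inputs only; any float tensors would do.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 3, device=device)
w = torch.randn(3, 4, device=device)

# On torch >= 2.0 this dispatches to torch.amp.autocast("cuda", ...);
# on older versions it falls back to torch.cuda.amp.autocast with the
# FutureWarning suppressed.
with torch_autocast(enabled=torch.cuda.is_available()):
    y = x @ w  # matmul runs in reduced precision under autocast on CUDA, fp32 on CPU
print(y.dtype)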

# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
# Fixed: https://github.com/pytorch/pytorch/pull/49853

@ -1551,6 +1578,7 @@ def optim_step_and_measure_param_change(
and the L2 norm of the original parameter. It is given by the formula:

.. math::

\begin{aligned}
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
\end{aligned}
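
As a quick sanity check of the relative-change formula quoted in the docstring above, a tiny self-contained computation (the numbers are invented for illustration and do not come from the commit):

import torch

theta = torch.tensor([3.0, 4.0])      # ||theta|| = 5
theta_new = torch.tensor([3.0, 4.5])  # parameters after one optimizer step
delta = (theta - theta_new).norm() ** 2 / theta.norm() ** 2
print(delta.item())  # 0.25 / 25 = 0.01, i.e. a small relative change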