Fix CI warnings

This commit is contained in:
k2-fsa 2025-06-30 21:46:18 +08:00
parent ffe2f16b1d
commit a53c323750
92 changed files with 448 additions and 223 deletions
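
Every hunk in this commit follows the same pattern: direct torch.cuda.amp.autocast(...) calls are replaced by a shared torch_autocast helper imported from icefall.utils, which is what silences the deprecation warning recent PyTorch releases emit for the old CUDA-specific API. The helper itself is not shown in these hunks; below is a minimal sketch of what such a wrapper could look like, assuming it simply dispatches to torch.amp.autocast("cuda", ...) when that API exists and falls back to the legacy call otherwise (the real implementation lives in icefall/utils.py and may differ).

import torch


def torch_autocast(device_type: str = "cuda", **kwargs):
    """Hypothetical sketch of the wrapper used throughout this commit.

    Returns an autocast context manager without triggering the FutureWarning
    that torch.cuda.amp.autocast(...) raises on PyTorch >= 2.4.
    """
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        # New-style API: torch.amp.autocast("cuda", enabled=..., dtype=...)
        return torch.amp.autocast(device_type=device_type, **kwargs)
    # Fallback for older PyTorch versions that only expose the CUDA-specific API.
    return torch.cuda.amp.autocast(**kwargs)

Call sites then read "with torch_autocast(enabled=params.use_fp16):" for mixed-precision training steps and "with torch_autocast(enabled=False):" where the k2 losses must run in float32, exactly as the diffs below show.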

View File

@ -31,8 +31,8 @@ jobs:
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
echo "::set-output name=matrix::${MATRIX}"
yesno:
needs: generate_build_matrix

View File

@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -638,7 +644,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -688,7 +694,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -184,7 +184,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -219,7 +219,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -94,7 +94,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -797,7 +803,7 @@ def train_one_epoch(
aishell = is_aishell(batch["supervisions"]["cut"][0])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -94,6 +94,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -87,6 +87,7 @@ from icefall.utils import (
setup_logger,
str2bool,
tokenize_by_CJK_char,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -802,7 +803,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -81,7 +81,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -812,7 +818,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -81,6 +81,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -514,7 +515,7 @@ def compute_validation_loss(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
@ -608,7 +609,7 @@ def train_one_epoch(
)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,

View File

@ -95,6 +95,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -910,7 +911,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -92,6 +92,7 @@ from icefall.utils import (
setup_logger,
str2bool,
tokenize_by_CJK_char,
torch_autocast,
)
@ -495,7 +496,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -90,7 +90,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -734,7 +740,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -83,7 +83,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -727,7 +733,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
# print(batch["supervisions"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -638,7 +644,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -73,7 +73,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -782,7 +788,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -773,7 +779,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -76,7 +76,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1067,7 +1073,7 @@ def train_one_epoch(
batch_size = batch["inputs"].shape[0]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

View File

@ -76,7 +76,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1058,7 +1064,7 @@ def train_one_epoch(
batch_size = batch["inputs"].shape[0]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

View File

@ -74,6 +74,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -799,7 +800,7 @@ def train_one_epoch(
num_samples += batch_size
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -88,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -825,7 +826,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -90,6 +90,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -895,7 +896,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -840,7 +846,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -969,7 +970,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -604,7 +605,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LOG_EPS = math.log(1e-10)
@ -838,7 +844,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -77,7 +77,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -675,7 +681,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -958,7 +959,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -961,7 +962,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -805,7 +811,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -92,6 +92,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -942,7 +943,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -82,6 +82,7 @@ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
encode_supervisions,
@ -676,7 +677,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -93,6 +93,7 @@ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
encode_supervisions,
@ -743,7 +744,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -772,7 +778,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -772,7 +778,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -156,7 +156,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -192,7 +192,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -85,6 +85,7 @@ from icefall.utils import (
display_and_save_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -763,7 +764,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -88,6 +88,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
display_and_save_batch,
@ -848,7 +849,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -80,6 +80,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
display_and_save_batch,
@ -793,7 +794,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -21,7 +21,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -141,7 +141,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -176,7 +176,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -13,6 +13,8 @@ from torch import Tensor, nn
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
from torch_scheduled_sampling import sample_combined
from icefall.utils import torch_autocast
# The main exports of this file are the module KnowledgeBaseLookup and the
# function create_knowledge_base.
@ -337,7 +339,7 @@ def _test_knowledge_base_lookup_autocast():
for epoch in range(150):
for n, (x, y) in enumerate(train_pairs):
y_out = m(x)
with torch.cuda.amp.autocast(enabled=True):
with torch_autocast(enabled=True):
loss = ((y_out - y) ** 2).mean() * 100.0
if n % 10 == 0 and epoch % 10 == 0:
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")

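The hunk above migrates _test_knowledge_base_lookup_autocast to the same wrapper. A minimal, hypothetical smoke check in the same spirit (it only assumes that icefall.utils.torch_autocast behaves like the sketch near the top of this commit) could be:

import warnings

import torch

from icefall.utils import torch_autocast

# Hypothetical check, not part of this commit: the wrapper should not trigger
# the FutureWarning that torch.cuda.amp.autocast emits on PyTorch >= 2.4,
# which is the class of CI warning this commit addresses.
with warnings.catch_warnings():
    warnings.simplefilter("error", FutureWarning)  # escalate the warning to an error
    with torch_autocast(enabled=False):
        _ = torch.ones(2) + 1  # any op; must complete without a FutureWarning
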
View File

@ -76,7 +76,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -650,7 +656,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -68,7 +68,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
def add_model_arguments(parser: argparse.ArgumentParser):
@ -693,7 +699,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -157,7 +157,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -193,7 +193,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -89,6 +89,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
display_and_save_batch,
@ -759,7 +760,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -85,6 +85,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
display_and_save_batch,
@ -827,7 +828,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -94,6 +94,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
display_and_save_batch,
@ -789,7 +790,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -82,6 +82,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
display_and_save_batch,
@ -814,7 +815,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -185,7 +185,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -220,7 +220,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -96,6 +96,7 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
torch_autocast,
display_and_save_batch,
setup_logger,
str2bool,
@ -781,7 +782,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -88,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -903,7 +904,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -150,7 +150,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -28,6 +28,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding
from icefall.utils import torch_autocast
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans
@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
def backward(ctx, x_grad: Tensor):
(x_orig,) = ctx.saved_tensors
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module):
):
return _no_op(x)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
eps = 1.0e-20
orig_x = x
x = x.to(torch.float32)

View File

@ -84,6 +84,7 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
filter_uneven_sized_batch,
@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -44,7 +44,7 @@ from scaling import (
from torch import Tensor, nn
from icefall.dist import get_rank
from icefall.utils import is_jit_tracing, make_pad_mask
from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast
class Zipformer(EncoderInterface):
@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = (

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -150,7 +150,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -89,6 +89,7 @@ from icefall.utils import (
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -833,7 +834,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from icefall.utils import add_sos, make_pad_mask, torch_autocast
class Transducer(nn.Module):
@ -178,7 +178,7 @@ class Transducer(nn.Module):
am = self.simple_am_proj(encoder_out_fr)
lm = self.simple_lm_proj(decoder_out)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -213,7 +213,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -85,6 +85,7 @@ from icefall.utils import (
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -822,7 +823,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -82,7 +82,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -810,7 +816,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -44,7 +44,7 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask
from icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast
def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]:
@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = (

View File

@ -86,7 +86,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@ -24,7 +24,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@ -172,7 +172,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@ -207,7 +207,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@ -91,7 +91,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -70,6 +70,7 @@ from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import UniqLexicon
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
encode_supervisions,
@@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -100,6 +100,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1049,7 +1050,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface
from lhotse.dataset import SpecAugment
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask, time_warp
from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast
class AsrModel(nn.Module):
@@ -285,7 +285,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -320,7 +320,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@@ -26,6 +26,8 @@ import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from icefall.utils import torch_autocast
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y)
@@ -308,7 +310,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans
@@ -761,7 +763,7 @@ class BalancerFunction(torch.autograd.Function):
try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True
@@ -1016,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
@@ -1355,7 +1357,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
@@ -1432,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
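
The functions in this scaling module share one idea: their custom backward() opts out of autocast and redoes the local computation in float32. A toy torch.autograd.Function illustrating that shape (it is not one of the functions above):

import torch
from icefall.utils import torch_autocast


class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return x**3

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        with torch_autocast(enabled=False):
            # Local gradient recomputed in full precision.
            return 3.0 * x.to(torch.float32) ** 2 * grad_out.to(torch.float32)


y = Cube.apply(torch.randn(3, requires_grad=True))
y.sum().backward()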

View File

@@ -96,6 +96,7 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
@@ -1101,9 +1102,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1551,9 +1550,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
loss, _ = compute_loss(
params=params,
model=model,
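
This recipe is the only one in the diff that also forwards an explicit dtype to autocast. A hedged sketch of how such flags could be wired up follows; the names use_bf16, use_autocast, and dtype here are illustrative and do not claim to match the script's actual option handling, which is not shown in this diff.

import torch
from icefall.utils import torch_autocast

# Prefer bfloat16 where the GPU supports it, otherwise fall back to float16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
use_autocast = torch.cuda.is_available()

x = torch.randn(4, 8)
with torch_autocast(enabled=use_autocast, dtype=dtype):
    y = (x @ x.t()).softmax(dim=-1)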

View File

@@ -47,6 +47,8 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import torch_autocast
class Zipformer2(EncoderInterface):
"""
@@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)
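
The change above only touches a diagnostic: the entropy of the attention weights is computed under torch.no_grad() in float32 so the log() stays stable whatever the autocast dtype. A small self-contained version of the same computation, with random weights standing in for the module's real ones:

import torch
from icefall.utils import torch_autocast

attn_weights = torch.softmax(torch.randn(4, 2, 16, 16), dim=-1)  # (heads, batch, seq, seq)
with torch.no_grad(), torch_autocast(enabled=False):
    w = attn_weights.to(torch.float32)
    attn_weights_entropy = -((w + 1.0e-20).log() * w).sum(dim=-1).mean(dim=(1, 2))
print(attn_weights_entropy)  # one value per head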

View File

@@ -89,6 +89,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1052,7 +1053,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -50,6 +50,8 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import torch_autocast
class Zipformer2(EncoderInterface):
"""
@@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)

View File

@@ -65,7 +65,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
@@ -726,7 +732,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,

View File

@@ -99,6 +99,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1065,7 +1066,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -27,6 +27,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from icefall.utils import torch_autocast
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y)
@@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans
@@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function):
try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True
@@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
@@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
@@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -947,7 +948,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -49,6 +49,8 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import torch_autocast
class Zipformer2(EncoderInterface):
"""
@@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)

View File

@@ -90,6 +90,7 @@ from icefall.utils import (
encode_supervisions,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -744,7 +745,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -86,6 +86,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
@@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class AsrModel(nn.Module):
@@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@@ -75,6 +75,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
@@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -80,6 +80,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
@@ -1115,7 +1116,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, torch_autocast
class AsrModel(nn.Module):
@@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,

View File

@@ -78,6 +78,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -944,7 +945,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -22,6 +22,7 @@ import math
import random
import warnings
from typing import List, Optional, Tuple, Union
from icefall.utils import torch_autocast
import torch
from encoder_interface import EncoderInterface
@@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)

View File

@@ -84,6 +84,7 @@ from icefall.utils import (
get_texts,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -757,7 +758,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -79,6 +79,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler
from icefall.utils import (
torch_autocast,
AttributeDict,
MetricsTracker,
encode_supervisions_otc,
@@ -758,7 +759,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,

View File

@@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
def get_parser():
@@ -401,7 +407,7 @@ def compute_validation_loss(
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
@@ -470,7 +476,7 @@ def train_one_epoch(
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,

View File

@@ -26,6 +26,7 @@ import pathlib
import random
import re
import subprocess
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -42,14 +43,40 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from packaging import version
from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter
from contextlib import contextmanager
from icefall.checkpoint import average_checkpoints
Pathlike = Union[str, Path]
TORCH_VERSION = version.parse(torch.__version__)
@contextmanager
def torch_autocast(device_type="cuda", **kwargs):
"""
To fix the following warnings:
/icefall/egs/librispeech/ASR/zipformer/model.py:323:
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=False):
"""
if TORCH_VERSION >= version.parse("2.0.0"):
# Use new unified API
with torch.amp.autocast(device_type=device_type, **kwargs):
yield
else:
# Suppress deprecation warning and use old CUDA-specific autocast
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
with torch.cuda.amp.autocast(**kwargs):
yield
# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
# Fixed: https://github.com/pytorch/pytorch/pull/49853
@@ -1551,6 +1578,7 @@ def optim_step_and_measure_param_change(
and the L2 norm of the original parameter. It is given by the formula:
.. math::
\begin{aligned}
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
\end{aligned}
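
A short usage sketch of the helper defined above (tensor values are illustrative): on PyTorch >= 2.0 it forwards to the unified torch.amp.autocast, on older versions it falls back to torch.cuda.amp.autocast with the FutureWarning suppressed, so call sites only need the one-line change repeated throughout this diff.

import torch
from icefall.utils import torch_autocast

x = torch.randn(8, 16)
with torch_autocast(enabled=True):   # mixed precision for the bulk of the work
    y = x @ x.t()
with torch_autocast(enabled=False):  # full precision for a numerically sensitive op
    z = torch.log_softmax(y.float(), dim=-1)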