Fix CI warnings

k2-fsa 2025-06-30 21:46:18 +08:00
parent ffe2f16b1d
commit a53c323750
92 changed files with 448 additions and 223 deletions
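Every file below swaps the deprecated `torch.cuda.amp.autocast(...)` context manager for a `torch_autocast` helper imported from `icefall.utils`, which is what silences the CI deprecation warnings. The helper's definition is not part of this commit view; as a minimal sketch (assuming it simply forwards to the newer `torch.amp.autocast` API while keeping the old keyword-only call style), it could look like this:

```python
# Hypothetical sketch of the torch_autocast helper assumed by this commit.
# The real definition lives in icefall/utils.py and is not shown in this diff.
import torch


def torch_autocast(device_type: str = "cuda", **kwargs):
    """Forward to torch.amp.autocast so call sites avoid the
    FutureWarning that recent PyTorch releases emit for
    torch.cuda.amp.autocast, while keeping the old usage, e.g.
    torch_autocast(enabled=params.use_fp16)."""
    return torch.amp.autocast(device_type=device_type, **kwargs)
```

With such a wrapper in place, `with torch.cuda.amp.autocast(enabled=params.use_fp16):` becomes `with torch_autocast(enabled=params.use_fp16):`, which is the per-file change repeated throughout the diff.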

View File

@@ -31,8 +31,8 @@ jobs:
      run: |
        # outputting for debugging purposes
        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
-       # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
-       MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
+       MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+       # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
        echo "::set-output name=matrix::${MATRIX}"
  yesno:
    needs: generate_build_matrix

View File

@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -638,7 +644,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -688,7 +694,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):

@@ -184,7 +184,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm.float(),
        am=am.float(),

@@ -219,7 +219,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits.float(),
        symbols=y_padded,

View File

@@ -94,7 +94,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -797,7 +803,7 @@ def train_one_epoch(
aishell = is_aishell(batch["supervisions"]["cut"][0])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -94,6 +94,7 @@ from icefall.utils import (
    filter_uneven_sized_batch,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -87,6 +87,7 @@ from icefall.utils import (
    setup_logger,
    str2bool,
    tokenize_by_CJK_char,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -802,7 +803,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -81,7 +81,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -812,7 +818,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -81,6 +81,7 @@ from icefall.utils import (
    filter_uneven_sized_batch,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -514,7 +515,7 @@ def compute_validation_loss(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            tokenizer=tokenizer,

@@ -608,7 +609,7 @@ def train_one_epoch(
)
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            tokenizer=tokenizer,

View File

@@ -95,6 +95,7 @@ from icefall.utils import (
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -910,7 +911,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -92,6 +92,7 @@ from icefall.utils import (
    setup_logger,
    str2bool,
    tokenize_by_CJK_char,
+    torch_autocast,
)

@@ -495,7 +496,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -90,7 +90,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -734,7 +740,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -83,7 +83,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -727,7 +733,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
# print(batch["supervisions"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -638,7 +644,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -73,7 +73,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -782,7 +788,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -773,7 +779,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -76,7 +76,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -1067,7 +1073,7 @@ def train_one_epoch(
batch_size = batch["inputs"].shape[0]
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

View File

@@ -76,7 +76,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -1058,7 +1064,7 @@ def train_one_epoch(
batch_size = batch["inputs"].shape[0]
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

View File

@@ -74,6 +74,7 @@ from icefall.utils import (
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -799,7 +800,7 @@ def train_one_epoch(
num_samples += batch_size
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -88,6 +88,7 @@ from icefall.utils import (
    filter_uneven_sized_batch,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -825,7 +826,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -90,6 +90,7 @@ from icefall.utils import (
    filter_uneven_sized_batch,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -895,7 +896,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -840,7 +846,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -969,7 +970,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -604,7 +605,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LOG_EPS = math.log(1e-10)

@@ -838,7 +844,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -77,7 +77,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -675,7 +681,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -958,7 +959,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -961,7 +962,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -805,7 +811,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -92,6 +92,7 @@ from icefall.utils import (
    get_parameter_groups_with_lrs,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -942,7 +943,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -82,6 +82,7 @@ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
+    torch_autocast,
    AttributeDict,
    MetricsTracker,
    encode_supervisions,

@@ -676,7 +677,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -93,6 +93,7 @@ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
+    torch_autocast,
    AttributeDict,
    MetricsTracker,
    encode_supervisions,

@@ -743,7 +744,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -772,7 +778,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -772,7 +778,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):

@@ -156,7 +156,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm.float(),
        am=am.float(),

@@ -192,7 +192,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits.float(),
        symbols=y_padded,

View File

@@ -85,6 +85,7 @@ from icefall.utils import (
    display_and_save_batch,
    setup_logger,
    str2bool,
+    torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -763,7 +764,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):

@@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm.float(),
        am=am.float(),

@@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits.float(),
        symbols=y_padded,

View File

@@ -88,6 +88,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
+    torch_autocast,
    AttributeDict,
    MetricsTracker,
    display_and_save_batch,

@@ -848,7 +849,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -80,6 +80,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
+    torch_autocast,
    AttributeDict,
    MetricsTracker,
    display_and_save_batch,

@@ -793,7 +794,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
-    with torch.cuda.amp.autocast(enabled=params.use_fp16):
+    with torch_autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,

@@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):

@@ -141,7 +141,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm.float(),
        am=am.float(),

@@ -176,7 +176,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
-with torch.cuda.amp.autocast(enabled=False):
+with torch_autocast(enabled=False):
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits.float(),
        symbols=y_padded,

View File

@@ -13,6 +13,8 @@ from torch import Tensor, nn
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
from torch_scheduled_sampling import sample_combined
+from icefall.utils import torch_autocast
# The main exports of this file are the module KnowledgeBaseLookup and the
# function create_knowledge_base.

@@ -337,7 +339,7 @@ def _test_knowledge_base_lookup_autocast():
for epoch in range(150):
    for n, (x, y) in enumerate(train_pairs):
        y_out = m(x)
-        with torch.cuda.amp.autocast(enabled=True):
+        with torch_autocast(enabled=True):
            loss = ((y_out - y) ** 2).mean() * 100.0
        if n % 10 == 0 and epoch % 10 == 0:
            print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")

View File

@@ -76,7 +76,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -650,7 +656,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, _ = compute_loss(
        params=params,
        model=model,

View File

@@ -68,7 +68,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
def add_model_arguments(parser: argparse.ArgumentParser):

@@ -693,7 +699,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-with torch.cuda.amp.autocast(enabled=params.use_fp16):
+with torch_autocast(enabled=params.use_fp16):
    loss, loss_info = compute_loss(
        params=params,
        model=model,

@@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
    batch = train_dl.dataset[cuts]
    try:
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ScaledLinear from scaling import ScaledLinear
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module): class Transducer(nn.Module):
@ -157,7 +157,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -193,7 +193,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,
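
In the model files, the replacement is made with enabled=False around the k2 transducer losses: the surrounding training step may run under fp16 autocast, but rnnt_loss_smoothed and rnnt_loss_pruned are deliberately computed in float32 (note the .float() casts), so autocast is switched off for that region. A toy illustration of the same pattern, with an invented loss standing in for the k2 calls:

# Illustrative only; stable_loss is a stand-in for the k2 loss calls above.
import torch

from icefall.utils import torch_autocast  # assumed helper, sketched earlier


def stable_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Leave mixed precision for this numerically sensitive computation.
    with torch_autocast(enabled=False):
        return ((pred.float() - target.float()) ** 2).mean()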

View File

@ -89,6 +89,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
display_and_save_batch, display_and_save_batch,
@ -759,7 +760,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ScaledLinear from scaling import ScaledLinear
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module): class Transducer(nn.Module):
@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out) lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out) am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False) logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -85,6 +85,7 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
display_and_save_batch, display_and_save_batch,
@ -827,7 +828,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0]) libri = is_libri(batch["supervisions"]["cut"][0])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -94,6 +94,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
display_and_save_batch, display_and_save_batch,
@ -789,7 +790,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -82,6 +82,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
display_and_save_batch, display_and_save_batch,
@ -814,7 +815,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ScaledLinear from scaling import ScaledLinear
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module): class Transducer(nn.Module):
@ -185,7 +185,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -220,7 +220,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -96,6 +96,7 @@ from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
torch_autocast,
display_and_save_batch, display_and_save_batch,
setup_logger, setup_logger,
str2bool, str2bool,
@ -781,7 +782,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -88,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch, filter_uneven_sized_batch,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -903,7 +904,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt from scaling import penalize_abs_values_gt
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module): class Transducer(nn.Module):
@ -150,7 +150,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -28,6 +28,8 @@ import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding from torch.nn import Embedding as ScaledEmbedding
from icefall.utils import torch_autocast
class ActivationBalancerFunction(torch.autograd.Function): class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod @staticmethod
@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
def backward(ctx, x_grad: Tensor): def backward(ctx, x_grad: Tensor):
(x_orig,) = ctx.saved_tensors (x_orig,) = ctx.saved_tensors
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module):
): ):
return _no_op(x) return _no_op(x)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
eps = 1.0e-20 eps = 1.0e-20
orig_x = x orig_x = x
x = x.to(torch.float32) x = x.to(torch.float32)
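
The scaling.py changes follow a second recurring pattern: custom autograd Functions disable autocast inside backward() (sometimes together with torch.enable_grad()) so the recomputed gradient math runs in float32 regardless of the precision used in the forward pass. A minimal, made-up Function with the same structure:

# Minimal illustration of the backward-in-float32 pattern; SquareFunction is
# invented for this sketch and is not part of the icefall code base.
import torch
from torch import Tensor

from icefall.utils import torch_autocast  # assumed helper


class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out: Tensor) -> Tensor:
        (x,) = ctx.saved_tensors
        with torch_autocast(enabled=False):
            # Compute in float32, then cast back to the input's dtype.
            grad = 2.0 * x.to(torch.float32) * grad_out.to(torch.float32)
        return grad.to(x.dtype)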

View File

@ -84,6 +84,7 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
filter_uneven_sized_batch, filter_uneven_sized_batch,
@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -44,7 +44,7 @@ from scaling import (
from torch import Tensor, nn from torch import Tensor, nn
from icefall.dist import get_rank from icefall.dist import get_rank
from icefall.utils import is_jit_tracing, make_pad_mask from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast
class Zipformer(EncoderInterface): class Zipformer(EncoderInterface):
@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads bsz = n // num_heads
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32) attn_output = attn_output.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
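
The zipformer attention changes touch a diagnostic-only block: attention-weight entropy is logged under torch.no_grad() with autocast disabled, so the log/multiply arithmetic happens in float32 even during fp16 training. Roughly, and with an invented function name:

# Sketch of the entropy diagnostic; attn_entropy is a made-up name.
import torch

from icefall.utils import torch_autocast  # assumed helper


def attn_entropy(attn_weights: torch.Tensor) -> torch.Tensor:
    # attn_weights: (num_heads, batch, seq_len, seq_len), rows sum to one.
    with torch.no_grad(), torch_autocast(enabled=False):
        w = attn_weights.to(torch.float32)
        return -((w + 1.0e-20).log() * w).sum(dim=-1).mean()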

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module): class Transducer(nn.Module):
@ -150,7 +150,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out) am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -89,6 +89,7 @@ from icefall.utils import (
encode_supervisions, encode_supervisions,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -833,7 +834,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask from icefall.utils import add_sos, make_pad_mask, torch_autocast
class Transducer(nn.Module): class Transducer(nn.Module):
@ -178,7 +178,7 @@ class Transducer(nn.Module):
am = self.simple_am_proj(encoder_out_fr) am = self.simple_am_proj(encoder_out_fr)
lm = self.simple_lm_proj(decoder_out) lm = self.simple_lm_proj(decoder_out)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -213,7 +213,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -85,6 +85,7 @@ from icefall.utils import (
encode_supervisions, encode_supervisions,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -822,7 +823,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -82,7 +82,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -810,7 +816,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -44,7 +44,7 @@ from scaling import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask from icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast
def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]:
@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads bsz = n // num_heads
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32) attn_output = attn_output.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (

View File

@ -86,7 +86,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0]) libri = is_libri(batch["supervisions"]["cut"][0])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -24,7 +24,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt from scaling import penalize_abs_values_gt
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module): class Transducer(nn.Module):
@ -172,7 +172,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -207,7 +207,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False) logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -91,7 +91,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0]) libri = is_libri(batch["supervisions"]["cut"][0])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -70,6 +70,7 @@ from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import UniqLexicon from icefall.lexicon import UniqLexicon
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
encode_supervisions, encode_supervisions,
@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -100,6 +100,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1049,7 +1050,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface
from lhotse.dataset import SpecAugment from lhotse.dataset import SpecAugment
from scaling import ScaledLinear from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask, time_warp from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast
class AsrModel(nn.Module): class AsrModel(nn.Module):
@ -285,7 +285,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -320,7 +320,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -26,6 +26,8 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
from icefall.utils import torch_autocast
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y) max_value = torch.max(x, y)
@ -308,7 +310,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -761,7 +763,7 @@ class BalancerFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
x = x.to(torch.float32) x = x.to(torch.float32)
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1016,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -1355,7 +1357,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08 coeff = -0.08
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1432,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True

View File

@ -96,6 +96,7 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
@ -1101,9 +1102,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast( with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
enabled=params.use_autocast, dtype=params.dtype
):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1551,9 +1550,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast( with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
enabled=params.use_autocast, dtype=params.dtype
):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
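
The hunks above are the one place where the wrapper also has to forward a dtype: the old three-line torch.cuda.amp.autocast(enabled=params.use_autocast, dtype=params.dtype) call collapses to a single torch_autocast(...) line with the same keywords. Assuming the helper sketched earlier, a small standalone usage example:

# Usage sketch; enabled= and dtype= are forwarded to autocast unchanged.
import torch

from icefall.utils import torch_autocast  # assumed helper

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 3, device=device)
w = torch.randn(3, 4, device=device)

with torch_autocast(enabled=(device == "cuda"), dtype=torch.float16):
    y = x @ w
print(y.dtype)  # torch.float16 under CUDA autocast, torch.float32 otherwise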

View File

@ -47,6 +47,8 @@ from scaling import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from icefall.utils import torch_autocast
class Zipformer2(EncoderInterface): class Zipformer2(EncoderInterface):
""" """
@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)

View File

@ -89,6 +89,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1052,7 +1053,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -50,6 +50,8 @@ from scaling import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from icefall.utils import torch_autocast
class Zipformer2(EncoderInterface): class Zipformer2(EncoderInterface):
""" """
@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)

View File

@ -65,7 +65,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
@ -726,7 +732,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -99,6 +99,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1065,7 +1066,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -27,6 +27,8 @@ import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
from icefall.utils import torch_autocast
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y) max_value = torch.max(x, y)
@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, ans_grad: Tensor): def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors (ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32) ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32) ans = ans.to(torch.float32)
x_grad = ans_grad * ans x_grad = ans_grad * ans
@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
x = x.to(torch.float32) x = x.to(torch.float32)
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try: try:
with torch.enable_grad(): with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach() x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True x_detached.requires_grad = True
@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08 coeff = -0.08
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
with torch.enable_grad(): with torch.enable_grad():
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True

View File

@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -947,7 +948,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -49,6 +49,8 @@ from scaling import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from icefall.utils import torch_autocast
class Zipformer2(EncoderInterface): class Zipformer2(EncoderInterface):
""" """
@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)

View File

@ -90,6 +90,7 @@ from icefall.utils import (
encode_supervisions, encode_supervisions,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -744,7 +745,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -86,6 +86,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -24,7 +24,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from scaling import ScaledLinear from scaling import ScaledLinear
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class AsrModel(nn.Module): class AsrModel(nn.Module):
@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -75,6 +75,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0] batch_size = batch["kmeans"].shape[0]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -80,6 +80,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0] batch_size = batch["kmeans"].shape[0]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
@ -1115,7 +1116,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -24,7 +24,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from scaling import ScaledLinear from scaling import ScaledLinear
from icefall.utils import add_sos from icefall.utils import add_sos, torch_autocast
class AsrModel(nn.Module): class AsrModel(nn.Module):
@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25: # if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04) # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed). # prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False) logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(), logits=logits.float(),
symbols=y_padded, symbols=y_padded,

View File

@ -78,6 +78,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -944,7 +945,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0] batch_size = batch["kmeans"].shape[0]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items(): for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -22,6 +22,7 @@ import math
import random import random
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from icefall.utils import torch_autocast
import torch import torch
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32) attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = ( attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights) -((attn_weights + 1.0e-20).log() * attn_weights)

View File

@ -84,6 +84,7 @@ from icefall.utils import (
get_texts, get_texts,
setup_logger, setup_logger,
str2bool, str2bool,
torch_autocast,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -757,7 +758,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero # warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because # (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam. # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -79,6 +79,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler
from icefall.utils import ( from icefall.utils import (
torch_autocast,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
encode_supervisions_otc, encode_supervisions_otc,
@ -758,7 +759,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero # warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because # (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam. # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,

View File

@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
torch_autocast,
)
def get_parser(): def get_parser():
@ -401,7 +407,7 @@ def compute_validation_loss(
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch x, y, sentence_lengths = batch
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
model=model, model=model,
x=x, x=x,
@ -470,7 +476,7 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
x, y, sentence_lengths = batch x, y, sentence_lengths = batch
batch_size = x.size(0) batch_size = x.size(0)
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
model=model, model=model,
x=x, x=x,

View File

@ -26,6 +26,7 @@ import pathlib
import random import random
import re import re
import subprocess import subprocess
import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
@ -42,14 +43,40 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from packaging import version
from pypinyin import lazy_pinyin, pinyin from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from contextlib import contextmanager
from icefall.checkpoint import average_checkpoints from icefall.checkpoint import average_checkpoints
Pathlike = Union[str, Path] Pathlike = Union[str, Path]
TORCH_VERSION = version.parse(torch.__version__)
@contextmanager
def torch_autocast(device_type="cuda", **kwargs):
"""
A backward-compatible autocast wrapper that fixes the following warning:
/icefall/egs/librispeech/ASR/zipformer/model.py:323:
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=False):
"""
if TORCH_VERSION >= version.parse("2.0.0"):
# Use the new unified torch.amp API
with torch.amp.autocast(device_type=device_type, **kwargs):
yield
else:
# Suppress deprecation warning and use old CUDA-specific autocast
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
with torch.cuda.amp.autocast(**kwargs):
yield
# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379 # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
# Fixed: https://github.com/pytorch/pytorch/pull/49853 # Fixed: https://github.com/pytorch/pytorch/pull/49853
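A minimal usage sketch of the helper above (illustrative only, not part of the diff): it assumes a CUDA-capable machine, and the toy model and tensor shapes are made up. On PyTorch >= 2.0 the context manager expands to torch.amp.autocast("cuda", ...), while on older versions it falls back to torch.cuda.amp.autocast(...) with the FutureWarning suppressed.

import torch
from icefall.utils import torch_autocast

model = torch.nn.Linear(16, 4).cuda()    # hypothetical toy model
x = torch.randn(8, 16, device="cuda")    # hypothetical input batch

# Mixed-precision forward pass, as in the train_one_epoch loops above.
with torch_autocast(enabled=True):
    y = model(x)

# Numerically sensitive pieces (e.g. the k2 loss calls above) disable
# autocast locally so they run in float32.
with torch_autocast(enabled=False):
    loss = y.float().sum()

The device_type parameter defaults to "cuda", which matches how every call site in this diff invokes the helper.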
@ -1551,6 +1578,7 @@ def optim_step_and_measure_param_change(
and the L2 norm of the original parameter. It is given by the formula: and the L2 norm of the original parameter. It is given by the formula:
.. math:: .. math::
\begin{aligned} \begin{aligned}
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2} \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
\end{aligned} \end{aligned}