Fix CI tests. (#1974)

- Introduce unified AMP helpers (create_grad_scaler, torch_autocast) to handle
  deprecations in PyTorch ≥ 2.3.0 (a hedged sketch of both helpers follows
  below)

- Replace direct uses of torch.cuda.amp.GradScaler and torch.cuda.amp.autocast
  with the new utilities across all training and inference scripts

- Update all torch.load calls to pass weights_only=False, since newer PyTorch
  versions (2.6+) default to weights_only=True, which breaks loading
  checkpoints that contain non-tensor objects such as k2 FSAs
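
The icefall implementations of these helpers are not shown in this diff; the
following is a minimal sketch of what they plausibly look like, inferred from
the call sites below and from the documented PyTorch deprecation of the
torch.cuda.amp entry points in favor of torch.amp. Treat the names and
signatures as assumptions, not as the exact icefall API.

    # Hypothetical sketch of the unified AMP helpers (icefall.utils).
    import torch

    def create_grad_scaler(device: str = "cuda", **kwargs):
        """Create a GradScaler via the torch.amp API when available."""
        if hasattr(torch.amp, "GradScaler"):
            # PyTorch >= 2.3 provides a device-generic GradScaler.
            return torch.amp.GradScaler(device, **kwargs)
        # Older releases: fall back to the CUDA-only class.
        return torch.cuda.amp.GradScaler(**kwargs)

    def torch_autocast(device_type: str = "cuda", **kwargs):
        """Return an autocast context via the non-deprecated entry point."""
        return torch.amp.autocast(device_type=device_type, **kwargs)

    # torch.load side of the fix: PyTorch 2.6 changed the default of
    # weights_only from False to True; checkpoints holding arbitrary Python
    # objects (e.g. k2 FSA dicts) must request weights_only=False explicitly.
    # The path below is illustrative only.
    checkpoint = torch.load("exp/epoch-30.pt", map_location="cpu",
                            weights_only=False)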
Fangjun Kuang 2025-07-01 13:47:55 +08:00 committed by GitHub
parent 71377d21cd
commit fba5e67d5e
176 changed files with 881 additions and 501 deletions

View File

@@ -17,7 +17,7 @@ concurrency:
 jobs:
   generate_build_matrix:
-    if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell')
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
     # see https://github.com/pytorch/pytorch/pull/50633
     runs-on: ubuntu-latest
@@ -31,8 +31,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"
   aishell:
     needs: generate_build_matrix

View File

@@ -30,8 +30,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"
   audioset:
@@ -83,7 +83,7 @@ jobs:
           ls -lh ./model-onnx/*
       - name: Upload model to huggingface
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         uses: nick-fields/retry@v3
@@ -116,7 +116,7 @@ jobs:
           rm -rf huggingface
       - name: Prepare for release
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         shell: bash
         run: |
           d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
@@ -125,7 +125,7 @@ jobs:
           ls -lh
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true

View File

@@ -31,8 +31,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"
   baker_zh:
@@ -84,43 +84,43 @@ jobs:
           ls -lh
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-2
           path: ./model-steps-2.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-3
           path: ./model-steps-3.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-4
           path: ./model-steps-4.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-5
           path: ./model-steps-5.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-6
           path: ./model-steps-6.onnx
       - name: Upload models to huggingface
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         shell: bash
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -141,7 +141,7 @@ jobs:
           popd
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true

View File

@@ -29,8 +29,9 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
         echo "::set-output name=matrix::${MATRIX}"
   librispeech:
     needs: generate_build_matrix

View File

@@ -30,8 +30,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"
   ljspeech:
@@ -83,13 +83,13 @@ jobs:
           ls -lh
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
@@ -100,37 +100,37 @@ jobs:
           tag: tts-models
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-2
           path: ./model-steps-2.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-3
           path: ./model-steps-3.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-4
           path: ./model-steps-4.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-5
           path: ./model-steps-5.onnx
       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         with:
           name: step-6
           path: ./model-steps-6.onnx
       - name: Upload models to huggingface
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         shell: bash
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -155,7 +155,7 @@ jobs:
           popd
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+        if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true

View File

@@ -30,8 +30,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"
   test:
     needs: generate_build_matrix

View File

@@ -30,8 +30,9 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+        # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
         echo "::set-output name=matrix::${MATRIX}"
   yesno:
     needs: generate_build_matrix

View File

@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -638,7 +644,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,

View File

@@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -688,7 +694,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,

View File

@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 class Transducer(nn.Module):
@@ -184,7 +184,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -219,7 +219,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,

View File

@@ -94,7 +94,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -797,7 +803,7 @@ def train_one_epoch(
         aishell = is_aishell(batch["supervisions"]["cut"][0])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,

View File

@@ -94,6 +94,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -809,7 +810,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -87,6 +87,7 @@ from icefall.utils import (
     setup_logger,
     str2bool,
     tokenize_by_CJK_char,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -802,7 +803,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -81,7 +81,13 @@ from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -812,7 +818,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -81,6 +81,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -514,7 +515,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(valid_dl):
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                tokenizer=tokenizer,
@@ -608,7 +609,7 @@ def train_one_epoch(
         )
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,

View File

@@ -95,6 +95,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -910,7 +911,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -92,6 +92,7 @@ from icefall.utils import (
     setup_logger,
     str2bool,
     tokenize_by_CJK_char,
+    torch_autocast,
 )
@@ -495,7 +496,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -90,7 +90,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -734,7 +740,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -83,7 +83,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -727,7 +733,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         # print(batch["supervisions"])
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,

View File

@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -638,7 +644,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,

View File

@@ -73,7 +73,13 @@ from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -782,7 +788,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -773,7 +779,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -76,7 +76,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1067,7 +1073,7 @@ def train_one_epoch(
         batch_size = batch["inputs"].shape[0]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,

View File

@@ -76,7 +76,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1058,7 +1064,7 @@ def train_one_epoch(
         batch_size = batch["inputs"].shape[0]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,

View File

@@ -74,6 +74,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -799,7 +800,7 @@ def train_one_epoch(
         num_samples += batch_size
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -73,6 +73,8 @@ def compute_fbank_baker_zh(num_jobs: int):
         f_min=0,
         f_max=8000,
     )
+    if not torch.cuda.is_available():
+        config.device = "cpu"
     prefix = "baker_zh"
     suffix = "jsonl.gz"

View File

@@ -88,6 +88,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -825,7 +826,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -90,6 +90,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -895,7 +896,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -840,7 +846,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -969,7 +970,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -604,7 +605,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 LOG_EPS = math.log(1e-10)
@@ -838,7 +844,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -77,7 +77,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -675,7 +681,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -958,7 +959,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -961,7 +962,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -805,7 +811,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -92,6 +92,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -942,7 +943,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -667,7 +667,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -707,7 +709,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.method in [

View File

@@ -271,7 +271,7 @@ def main():
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
    model.eval()
@@ -351,7 +351,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -362,7 +364,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         # Add epsilon self-loops to G as we will compose
         # it with the whole lattice later
         G = G.to(device)

View File

@@ -774,7 +774,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -814,7 +816,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.method in [

View File

@@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -84,9 +83,11 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -420,7 +421,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -629,7 +630,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -676,7 +677,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -965,7 +966,7 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
             # warmup = 0.0 is so that the derivs for the pruned loss stay zero
             # (i.e. are not remembered by the decaying-average in adam), because
             # we want to avoid these params being subject to shrinkage in adam.
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
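
The hunks above rely on the two helpers this commit adds to `icefall/utils.py`; their bodies are not shown in this excerpt. A minimal sketch of what they plausibly look like, assuming the goal is only to paper over the `torch.cuda.amp` deprecations in PyTorch >= 2.3.0:

    import torch
    from packaging import version

    # True on PyTorch >= 2.3, where torch.cuda.amp.GradScaler and
    # torch.cuda.amp.autocast are deprecated in favor of torch.amp.
    _TORCH_GE_2_3 = version.parse(
        torch.__version__.split("+")[0]
    ) >= version.parse("2.3.0")


    def create_grad_scaler(device="cuda", **kwargs):
        # e.g. create_grad_scaler(enabled=params.use_fp16)
        if _TORCH_GE_2_3:
            return torch.amp.GradScaler(device, **kwargs)
        return torch.cuda.amp.GradScaler(**kwargs)


    def torch_autocast(device_type="cuda", **kwargs):
        # e.g. with torch_autocast(enabled=params.use_fp16): ...
        if _TORCH_GE_2_3:
            return torch.amp.autocast(device_type=device_type, **kwargs)
        return torch.cuda.amp.autocast(**kwargs)

Call sites then read `scaler = create_grad_scaler(enabled=params.use_fp16)` and `with torch_autocast(enabled=params.use_fp16): ...`, exactly as in the hunks, and behave identically on older PyTorch.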


@@ -868,7 +868,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -907,7 +909,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
         if params.decoding_method == "whole-lattice-rescoring":


@@ -334,7 +334,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -345,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose


@@ -290,7 +290,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -386,7 +386,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -397,7 +399,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose


@@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed
 from model import CTCModel
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -95,9 +94,11 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -493,7 +494,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -694,7 +695,7 @@ def train_one_epoch(
     graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -743,7 +744,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1004,7 +1005,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -574,7 +574,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False
+            )
         )
         HLG = HLG.to(device)
         assert HLG.requires_grad is False
@@ -609,7 +611,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+            )
             G = k2.Fsa.from_dict(d).to(device)
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:


@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -93,7 +92,14 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    create_grad_scaler,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -560,7 +566,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -727,7 +733,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -772,7 +778,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1002,7 +1008,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -93,7 +92,14 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    create_grad_scaler,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -560,7 +566,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -727,7 +733,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -772,7 +778,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1001,7 +1007,7 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom(
             # warmup = 0.0 is so that the derivs for the pruned loss stay zero
             # (i.e. are not remembered by the decaying-average in adam), because
             # we want to avoid these params being subject to shrinkage in adam.
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))

     if Path(f"data/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
+        d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")


@@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
         An FSA representing LG.
     """
     lexicon = Lexicon(lang_dir)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))

     if Path(f"data/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
+        d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")


@@ -750,7 +750,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:


@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -156,7 +156,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -192,7 +192,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
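
A note on the `enabled=False` call sites in the model files: the k2 transducer losses are computed on `.float()` inputs with autocast disabled, presumably because their internal log-sum-exp arithmetic is not float16-safe, so `torch_autocast` only has to reproduce that behavior through the non-deprecated API. A sketch of the pattern (`model`, `x`, and `unstable_loss` are placeholders, not names from this commit):

    with torch_autocast(enabled=True):
        logits = model(x)  # matmuls may run in float16 under AMP
        with torch_autocast(enabled=False):
            # numerically sensitive region: force float32
            loss = unstable_loss(logits.float())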


@@ -238,7 +238,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()


@@ -66,7 +66,6 @@ from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -82,9 +81,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -521,7 +522,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -717,7 +718,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -763,7 +764,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1023,7 +1024,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -935,7 +935,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:


@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -195,7 +195,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -231,7 +231,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,


@@ -241,7 +241,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()


@@ -74,7 +74,6 @@ from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -90,9 +89,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -560,7 +561,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -772,7 +773,7 @@ def train_one_epoch(
     giga_train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -848,7 +849,7 @@ def train_one_epoch(
         libri = is_libri(batch["supervisions"]["cut"][0])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1176,7 +1177,7 @@ def run(rank, world_size, args):
     else:
         logging.info("Skip scan_pessimistic_batches_for_oom")

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -815,7 +815,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:


@@ -239,7 +239,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()


@@ -66,7 +66,6 @@ from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -82,9 +81,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -551,7 +552,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -747,7 +748,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -793,7 +794,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1067,7 +1068,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -21,7 +21,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -141,7 +141,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -176,7 +176,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,


@@ -10,9 +10,11 @@ from typing import Optional, Tuple
 import torch
 from scaling import ScaledLinear
 from torch import Tensor, nn
-from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
 from torch_scheduled_sampling import sample_combined

+from icefall.utils import create_grad_scaler, torch_autocast
+
 # The main exports of this file are the module KnowledgeBaseLookup and the
 # function create_knowledge_base.
@@ -330,14 +332,14 @@ def _test_knowledge_base_lookup_autocast():
     optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
     m = m.to(device)

-    scaler = GradScaler(enabled=True)
+    scaler = create_grad_scaler(enabled=True)

     start = timeit.default_timer()

     for epoch in range(150):
         for n, (x, y) in enumerate(train_pairs):
             y_out = m(x)
-            with torch.cuda.amp.autocast(enabled=True):
+            with torch_autocast(enabled=True):
                 loss = ((y_out - y) ** 2).mean() * 100.0
                 if n % 10 == 0 and epoch % 10 == 0:
                     print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")


@@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -76,7 +75,14 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    create_grad_scaler,
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -453,7 +459,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -608,7 +614,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -650,7 +656,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -868,7 +874,7 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom(
             # warmup = 0.0 is so that the derivs for the pruned loss stay zero
             # (i.e. are not remembered by the decaying-average in adam), because
             # we want to avoid these params being subject to shrinkage in adam.
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from noam import Noam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -68,7 +67,14 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    create_grad_scaler,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)


 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -496,7 +502,7 @@ def save_checkpoint(
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, and training stats to file.
@@ -650,7 +656,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -693,7 +699,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -939,7 +945,7 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -741,7 +741,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:


@@ -754,7 +754,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
        )
         decoding_graph.scores *= params.ngram_lm_scale
     else:


@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -157,7 +157,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -193,7 +193,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,


@@ -265,7 +265,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()


@@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -91,9 +90,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -523,7 +524,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -716,7 +717,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -759,7 +760,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1000,7 +1001,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 0 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,


@@ -921,7 +921,7 @@ def load_ngram_LM(
     if pt_file.is_file():
         logging.info(f"Loading pre-compiled {pt_file}")

-        d = torch.load(pt_file, map_location=device)
+        d = torch.load(pt_file, map_location=device, weights_only=False)
         G = k2.Fsa.from_dict(d)
         G = k2.add_epsilon_self_loops(G)
         G = k2.arc_sort(G)
@@ -1101,7 +1101,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     elif params.decoding_method in [


@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -195,7 +195,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -231,7 +231,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,


@@ -274,7 +274,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()


@@ -74,7 +74,6 @@ from librispeech import LibriSpeech
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -87,9 +86,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -546,7 +547,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -755,7 +756,7 @@ def train_one_epoch(
     giga_train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -827,7 +828,7 @@ def train_one_epoch(
         libri = is_libri(batch["supervisions"]["cut"][0])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1126,7 +1127,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 0 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

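The create_grad_scaler and torch_autocast helpers that replace the
torch.cuda.amp symbols live in icefall/utils.py and are not shown in this diff.
A plausible sketch of what they do, assuming only the documented torch.amp and
torch.cuda.amp APIs (the real implementations may differ):

    from typing import Any

    import torch

    # torch.cuda.amp.GradScaler and torch.cuda.amp.autocast are deprecated
    # from PyTorch 2.3.0 on in favor of torch.amp, which takes a device type.
    _TORCH_GE_2_3 = tuple(
        int(v) for v in torch.__version__.split("+")[0].split(".")[:2]
    ) >= (2, 3)

    def create_grad_scaler(device: str = "cuda", **kwargs: Any):
        # kwargs such as enabled=... and init_scale=... are forwarded
        # unchanged, so call sites keep their existing arguments.
        if _TORCH_GE_2_3:
            return torch.amp.GradScaler(device, **kwargs)
        return torch.cuda.amp.GradScaler(**kwargs)

    def torch_autocast(device_type: str = "cuda", **kwargs: Any):
        if _TORCH_GE_2_3:
            return torch.amp.autocast(device_type=device_type, **kwargs)
        return torch.cuda.amp.autocast(**kwargs)

This keeps every call site a one-line change and silences the deprecation
warnings on new PyTorch without dropping support for older releases.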
View File

@@ -913,7 +913,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:

View File

@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -96,9 +95,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -548,7 +549,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -744,7 +745,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -789,7 +790,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1047,7 +1048,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

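Note how the GradScaler type hints become string literals: with the
from torch.cuda.amp import GradScaler line deleted, a bare GradScaler
annotation would raise NameError when the module is imported, while the quoted
form is never evaluated at runtime. If one also wants static checkers to
resolve the name, the usual pattern looks like this (the diff itself does not
show such an import, so this part is an assumption):

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Evaluated only by type checkers; nothing is imported at runtime.
        from torch.amp import GradScaler

    def save_checkpoint(scaler: Optional["GradScaler"] = None) -> None:
        ...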
View File

@@ -972,7 +972,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:

View File

@@ -238,7 +238,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -84,9 +83,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -571,7 +572,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -768,7 +769,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -814,7 +815,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1078,7 +1079,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -220,7 +220,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,

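In the model files, torch_autocast(enabled=False) together with the explicit
.float() casts guarantees that the k2 loss kernels see float32 inputs even
when the surrounding training step runs under fp16 autocast. The nesting
behaves as sketched below; logsumexp merely stands in for the k2 losses and
all names are illustrative:

    import torch

    def float32_loss(logits: torch.Tensor) -> torch.Tensor:
        # Stand-in for k2.rnnt_loss_smoothed / k2.rnnt_loss_pruned.
        with torch.amp.autocast(device_type="cuda", enabled=False):
            return torch.logsumexp(logits.float(), dim=-1).mean()

    if torch.cuda.is_available():
        x = torch.randn(4, 10, device="cuda")
        with torch.amp.autocast(device_type="cuda", enabled=True):
            h = x @ torch.randn(10, 10, device="cuda")  # half precision
            loss = float32_loss(h)                      # float32
        assert loss.dtype == torch.float32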
View File

@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -96,9 +95,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -519,7 +520,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -736,7 +737,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -781,7 +782,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1039,7 +1040,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

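Each training script also checkpoints the scaler itself, so fp16 loss scaling
resumes where it left off instead of re-warming from the initial scale. The
round-trip, reduced to its essentials (the path and dict key are illustrative;
icefall's save_checkpoint handles this internally, and torch.amp.GradScaler
assumes PyTorch >= 2.3):

    import torch

    scaler = torch.amp.GradScaler("cuda", enabled=True)

    # What save_checkpoint stores ...
    torch.save({"grad_scaler": scaler.state_dict()}, "/tmp/ckpt.pt")

    # ... and what the resume path restores.
    checkpoints = torch.load("/tmp/ckpt.pt", map_location="cpu", weights_only=False)
    if "grad_scaler" in checkpoints:
        scaler.load_state_dict(checkpoints["grad_scaler"])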
View File

@@ -348,7 +348,9 @@ class CodebookIndexExtractor:
             num_codebooks=self.params.num_codebooks,
             codebook_size=256,
         )
-        quantizer.load_state_dict(torch.load(self.quantizer_file_path))
+        quantizer.load_state_dict(
+            torch.load(self.quantizer_file_path, weights_only=False)
+        )
         quantizer.to(self.params.device)
         return quantizer

View File

@@ -289,7 +289,7 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -910,7 +910,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:

View File

@@ -813,7 +813,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:

View File

@@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -85,9 +84,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -635,7 +636,7 @@ def load_model_params(
     """
     logging.info(f"Loading checkpoint from {ckpt}")

-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)

     # if module list is empty, load the whole model from ckpt
     if not init_modules:
@@ -678,7 +679,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -857,7 +858,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -903,7 +904,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1219,7 +1220,7 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -150,7 +150,7 @@ class Transducer(nn.Module):
         # if self.training and random.random() < 0.25:
         #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,

View File

@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -28,6 +28,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding

+from icefall.utils import torch_autocast
+

 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
@@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         (ans,) = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             ans_grad = ans_grad.to(torch.float32)
             ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
@@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
     def backward(ctx, x_grad: Tensor):
         (x_orig,) = ctx.saved_tensors
         with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
+            with torch_autocast(enabled=False):
                 x_detached = x_orig.to(torch.float32).detach()
                 x_detached.requires_grad = True
@@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module):
         ):
             return _no_op(x)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             eps = 1.0e-20
             orig_x = x
             x = x.to(torch.float32)

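The scaling.py changes apply the same rule inside custom autograd functions:
the backward pass promotes saved tensors to float32 with autocast disabled,
because gradient formulas such as the softmax one are numerically fragile in
half precision. The pattern in isolation; StableSoftmax is a reduced,
hypothetical cousin of SoftmaxFunction:

    import torch
    from torch import Tensor

    class StableSoftmax(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, dim: int) -> Tensor:
            ans = x.softmax(dim=dim)
            ctx.save_for_backward(ans)
            ctx.dim = dim
            return ans

        @staticmethod
        def backward(ctx, ans_grad: Tensor):
            (ans,) = ctx.saved_tensors
            with torch.amp.autocast(device_type="cuda", enabled=False):
                # Softmax Jacobian product, computed entirely in float32.
                ans_grad = ans_grad.to(torch.float32)
                ans = ans.to(torch.float32)
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None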
View File

@@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -86,10 +85,12 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
     symlink_or_copy,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -581,7 +582,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -763,7 +764,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -809,7 +810,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1106,7 +1107,7 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -44,7 +44,7 @@ from scaling import (
 from torch import Tensor, nn

 from icefall.dist import get_rank
-from icefall.utils import is_jit_tracing, make_pad_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast


 class Zipformer(EncoderInterface):
@@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
                 bsz = n // num_heads

                 with torch.no_grad():
-                    with torch.cuda.amp.autocast(enabled=False):
+                    with torch_autocast(enabled=False):
                         attn_weights = attn_weights.to(torch.float32)
                         attn_output = attn_output.to(torch.float32)
                         attn_weights_entropy = (

View File

@@ -633,7 +633,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -672,7 +674,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.decoding_method == "whole-lattice-rescoring":

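The decoding scripts need weights_only=False for the same reason as the
checkpoints: HLG.pt, LG.pt and G_4_gram.pt are written with
torch.save(fsa.as_dict(), ...), and those dicts can carry pickled non-tensor
attributes that strict weights-only loading rejects. Typical usage, assuming
an HLG.pt built by the recipe's preparation stage (the path is illustrative):

    import k2
    import torch

    HLG = k2.Fsa.from_dict(
        torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu", weights_only=False)
    )
    assert HLG.requires_grad is False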
View File

@@ -786,7 +786,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:

View File

@@ -347,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -150,7 +150,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,

View File

@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

View File

@@ -286,7 +286,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -365,7 +365,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

View File

@@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -588,7 +589,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -787,7 +788,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -833,7 +834,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1128,7 +1129,7 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,

View File

@@ -624,7 +624,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
@@ -663,7 +665,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.decoding_method == "whole-lattice-rescoring":

View File

@@ -808,7 +808,7 @@ def main():
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
+                torch.load(lg_filename, map_location=device, weights_only=False)
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:

Some files were not shown because too many files have changed in this diff.