mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix CI tests. (#1974)
- Introduce unified AMP helpers (create_grad_scaler, torch_autocast) to handle deprecations in PyTorch ≥2.3.0 - Replace direct uses of torch.cuda.amp.GradScaler and torch.cuda.amp.autocast with the new utilities across all training and inference scripts - Update all torch.load calls to include weights_only=False for compatibility with newer PyTorch versions
This commit is contained in:
parent
71377d21cd
commit
fba5e67d5e
6
.github/workflows/aishell.yml
vendored
6
.github/workflows/aishell.yml
vendored
@ -17,7 +17,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
generate_build_matrix:
|
||||
if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell')
|
||||
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
|
||||
|
||||
# see https://github.com/pytorch/pytorch/pull/50633
|
||||
runs-on: ubuntu-latest
|
||||
@ -31,8 +31,8 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
aishell:
|
||||
needs: generate_build_matrix
|
||||
|
10
.github/workflows/audioset.yml
vendored
10
.github/workflows/audioset.yml
vendored
@ -30,8 +30,8 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
|
||||
audioset:
|
||||
@ -83,7 +83,7 @@ jobs:
|
||||
ls -lh ./model-onnx/*
|
||||
|
||||
- name: Upload model to huggingface
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
@ -116,7 +116,7 @@ jobs:
|
||||
rm -rf huggingface
|
||||
|
||||
- name: Prepare for release
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
shell: bash
|
||||
run: |
|
||||
d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
|
||||
@ -125,7 +125,7 @@ jobs:
|
||||
ls -lh
|
||||
|
||||
- name: Release exported onnx models
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
|
20
.github/workflows/baker_zh.yml
vendored
20
.github/workflows/baker_zh.yml
vendored
@ -31,8 +31,8 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
|
||||
baker_zh:
|
||||
@ -84,43 +84,43 @@ jobs:
|
||||
ls -lh
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
|
||||
path: ./*.wav
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-2
|
||||
path: ./model-steps-2.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-3
|
||||
path: ./model-steps-3.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-4
|
||||
path: ./model-steps-4.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-5
|
||||
path: ./model-steps-5.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-6
|
||||
path: ./model-steps-6.onnx
|
||||
|
||||
- name: Upload models to huggingface
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
shell: bash
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
@ -141,7 +141,7 @@ jobs:
|
||||
popd
|
||||
|
||||
- name: Release exported onnx models
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
|
5
.github/workflows/librispeech.yml
vendored
5
.github/workflows/librispeech.yml
vendored
@ -29,8 +29,9 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
librispeech:
|
||||
needs: generate_build_matrix
|
||||
|
22
.github/workflows/ljspeech.yml
vendored
22
.github/workflows/ljspeech.yml
vendored
@ -30,8 +30,8 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
|
||||
ljspeech:
|
||||
@ -83,13 +83,13 @@ jobs:
|
||||
ls -lh
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
|
||||
path: ./*.wav
|
||||
|
||||
- name: Release exported onnx models
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
@ -100,37 +100,37 @@ jobs:
|
||||
tag: tts-models
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-2
|
||||
path: ./model-steps-2.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-3
|
||||
path: ./model-steps-3.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-4
|
||||
path: ./model-steps-4.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-5
|
||||
path: ./model-steps-5.onnx
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
with:
|
||||
name: step-6
|
||||
path: ./model-steps-6.onnx
|
||||
|
||||
- name: Upload models to huggingface
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
shell: bash
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
@ -155,7 +155,7 @@ jobs:
|
||||
popd
|
||||
|
||||
- name: Release exported onnx models
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
|
||||
if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
|
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@ -30,8 +30,8 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
test:
|
||||
needs: generate_build_matrix
|
||||
|
5
.github/workflows/yesno.yml
vendored
5
.github/workflows/yesno.yml
vendored
@ -30,8 +30,9 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||
python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
|
||||
# MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
yesno:
|
||||
needs: generate_build_matrix
|
||||
|
@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -638,7 +644,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -688,7 +694,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -184,7 +184,7 @@ class Transducer(nn.Module):
|
||||
lm = simple_lm_proj(decoder_out)
|
||||
am = simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -219,7 +219,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -94,7 +94,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -797,7 +803,7 @@ def train_one_epoch(
|
||||
aishell = is_aishell(batch["supervisions"]["cut"][0])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -94,6 +94,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -809,7 +810,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -87,6 +87,7 @@ from icefall.utils import (
|
||||
setup_logger,
|
||||
str2bool,
|
||||
tokenize_by_CJK_char,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -802,7 +803,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -81,7 +81,13 @@ from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -812,7 +818,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -81,6 +81,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -514,7 +515,7 @@ def compute_validation_loss(
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
@ -608,7 +609,7 @@ def train_one_epoch(
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -95,6 +95,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -910,7 +911,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -92,6 +92,7 @@ from icefall.utils import (
|
||||
setup_logger,
|
||||
str2bool,
|
||||
tokenize_by_CJK_char,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
|
||||
@ -495,7 +496,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -90,7 +90,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -734,7 +740,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -83,7 +83,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -727,7 +733,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
# print(batch["supervisions"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -638,7 +644,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -73,7 +73,13 @@ from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -782,7 +788,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -773,7 +779,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -76,7 +76,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -1067,7 +1073,7 @@ def train_one_epoch(
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -76,7 +76,13 @@ from icefall.checkpoint import (
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -1058,7 +1064,7 @@ def train_one_epoch(
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -74,6 +74,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -799,7 +800,7 @@ def train_one_epoch(
|
||||
num_samples += batch_size
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -73,6 +73,8 @@ def compute_fbank_baker_zh(num_jobs: int):
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
if not torch.cuda.is_available():
|
||||
config.device = "cpu"
|
||||
|
||||
prefix = "baker_zh"
|
||||
suffix = "jsonl.gz"
|
||||
|
@ -88,6 +88,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -825,7 +826,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -90,6 +90,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -895,7 +896,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -840,7 +846,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -969,7 +970,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -604,7 +605,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
LOG_EPS = math.log(1e-10)
|
||||
@ -838,7 +844,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -77,7 +77,13 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -675,7 +681,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -958,7 +959,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -97,6 +97,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -961,7 +962,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -805,7 +811,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -92,6 +92,7 @@ from icefall.utils import (
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -942,7 +943,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -667,7 +667,9 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(
|
||||
f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
|
||||
)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -707,7 +709,9 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(
|
||||
params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
|
||||
)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.method in [
|
||||
|
@ -271,7 +271,7 @@ def main():
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -351,7 +351,9 @@ def main():
|
||||
"attention-decoder",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(params.HLG, map_location="cpu", weights_only=False)
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
@ -362,7 +364,9 @@ def main():
|
||||
"attention-decoder",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = k2.Fsa.from_dict(
|
||||
torch.load(params.G, map_location="cpu", weights_only=False)
|
||||
)
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = G.to(device)
|
||||
|
@ -774,7 +774,9 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(
|
||||
f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
|
||||
)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -814,7 +816,9 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(
|
||||
params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
|
||||
)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.method in [
|
||||
|
@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -84,9 +83,11 @@ from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -420,7 +421,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -629,7 +630,7 @@ def train_one_epoch(
|
||||
scheduler: LRSchedulerType,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -676,7 +677,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -965,7 +966,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -868,7 +868,9 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(
|
||||
f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
|
||||
)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -907,7 +909,9 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(
|
||||
params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
|
||||
)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method == "whole-lattice-rescoring":
|
||||
|
@ -334,7 +334,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(params.HLG, map_location="cpu", weights_only=False)
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
@ -345,7 +347,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = k2.Fsa.from_dict(
|
||||
torch.load(params.G, map_location="cpu", weights_only=False)
|
||||
)
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
|
@ -290,7 +290,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -386,7 +386,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(params.HLG, map_location="cpu", weights_only=False)
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
@ -397,7 +399,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = k2.Fsa.from_dict(
|
||||
torch.load(params.G, map_location="cpu", weights_only=False)
|
||||
)
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
|
@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import CTCModel
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -95,9 +94,11 @@ from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -493,7 +494,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -694,7 +695,7 @@ def train_one_epoch(
|
||||
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -743,7 +744,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1004,7 +1005,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -574,7 +574,9 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
||||
torch.load(
|
||||
f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
@ -609,7 +611,9 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
||||
d = torch.load(
|
||||
params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
|
||||
)
|
||||
G = k2.Fsa.from_dict(d).to(device)
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -93,7 +92,14 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -560,7 +566,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -727,7 +733,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -772,7 +778,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1002,7 +1008,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -93,7 +92,14 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -560,7 +566,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -727,7 +733,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -772,7 +778,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1001,7 +1007,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
|
||||
|
||||
if Path(f"data/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"data/lm/{lm}.pt")
|
||||
d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
|
@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
An FSA representing LG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
|
||||
|
||||
if Path(f"data/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"data/lm/{lm}.pt")
|
||||
d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
|
@ -750,7 +750,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -156,7 +156,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -192,7 +192,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -238,7 +238,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -66,7 +66,6 @@ from lstm import RNN
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -82,9 +81,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -521,7 +522,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -717,7 +718,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -763,7 +764,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1023,7 +1024,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -935,7 +935,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -195,7 +195,7 @@ class Transducer(nn.Module):
|
||||
lm = simple_lm_proj(decoder_out)
|
||||
am = simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -231,7 +231,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -241,7 +241,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -74,7 +74,6 @@ from lstm import RNN
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -90,9 +89,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -560,7 +561,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -772,7 +773,7 @@ def train_one_epoch(
|
||||
giga_train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -848,7 +849,7 @@ def train_one_epoch(
|
||||
libri = is_libri(batch["supervisions"]["cut"][0])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1176,7 +1177,7 @@ def run(rank, world_size, args):
|
||||
else:
|
||||
logging.info("Skip scan_pessimistic_batches_for_oom")
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -815,7 +815,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -239,7 +239,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -66,7 +66,6 @@ from lstm import RNN
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -82,9 +81,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -551,7 +552,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -747,7 +748,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -793,7 +794,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1067,7 +1068,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -141,7 +141,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -176,7 +176,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -10,9 +10,11 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
from scaling import ScaledLinear
|
||||
from torch import Tensor, nn
|
||||
from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch_scheduled_sampling import sample_combined
|
||||
|
||||
from icefall.utils import create_grad_scaler, torch_autocast
|
||||
|
||||
# The main exports of this file are the module KnowledgeBaseLookup and the
|
||||
# function create_knowledge_base.
|
||||
|
||||
@ -330,14 +332,14 @@ def _test_knowledge_base_lookup_autocast():
|
||||
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
|
||||
m = m.to(device)
|
||||
|
||||
scaler = GradScaler(enabled=True)
|
||||
scaler = create_grad_scaler(enabled=True)
|
||||
|
||||
start = timeit.default_timer()
|
||||
|
||||
for epoch in range(150):
|
||||
for n, (x, y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
with torch_autocast(enabled=True):
|
||||
loss = ((y_out - y) ** 2).mean() * 100.0
|
||||
if n % 10 == 0 and epoch % 10 == 0:
|
||||
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
|
||||
|
@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -76,7 +75,14 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
create_grad_scaler,
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -453,7 +459,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -608,7 +614,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -650,7 +656,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -868,7 +874,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from noam import Noam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -68,7 +67,14 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
@ -496,7 +502,7 @@ def save_checkpoint(
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, and training stats to file.
|
||||
@ -650,7 +656,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -693,7 +699,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -939,7 +945,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -741,7 +741,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -754,7 +754,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -157,7 +157,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -193,7 +193,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -265,7 +265,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -91,9 +90,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -523,7 +524,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -716,7 +717,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -759,7 +760,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1000,7 +1001,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -921,7 +921,7 @@ def load_ngram_LM(
|
||||
|
||||
if pt_file.is_file():
|
||||
logging.info(f"Loading pre-compiled {pt_file}")
|
||||
d = torch.load(pt_file, map_location=device)
|
||||
d = torch.load(pt_file, map_location=device, weights_only=False)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
@ -1101,7 +1101,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
elif params.decoding_method in [
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -195,7 +195,7 @@ class Transducer(nn.Module):
|
||||
lm = simple_lm_proj(decoder_out)
|
||||
am = simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -231,7 +231,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -274,7 +274,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -74,7 +74,6 @@ from librispeech import LibriSpeech
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -87,9 +86,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -546,7 +547,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -755,7 +756,7 @@ def train_one_epoch(
|
||||
giga_train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
rng: random.Random,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -827,7 +828,7 @@ def train_one_epoch(
|
||||
|
||||
libri = is_libri(batch["supervisions"]["cut"][0])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1126,7 +1127,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -913,7 +913,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -96,9 +95,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -548,7 +549,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -744,7 +745,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -789,7 +790,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1047,7 +1048,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -972,7 +972,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -238,7 +238,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -84,9 +83,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -571,7 +572,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -768,7 +769,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -814,7 +815,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1078,7 +1079,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -185,7 +185,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -220,7 +220,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -96,9 +95,11 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
display_and_save_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -519,7 +520,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -736,7 +737,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -781,7 +782,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1039,7 +1040,7 @@ def run(rank, world_size, args):
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -348,7 +348,9 @@ class CodebookIndexExtractor:
|
||||
num_codebooks=self.params.num_codebooks,
|
||||
codebook_size=256,
|
||||
)
|
||||
quantizer.load_state_dict(torch.load(self.quantizer_file_path))
|
||||
quantizer.load_state_dict(
|
||||
torch.load(self.quantizer_file_path, weights_only=False)
|
||||
)
|
||||
quantizer.to(self.params.device)
|
||||
return quantizer
|
||||
|
||||
|
@ -289,7 +289,7 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -910,7 +910,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -813,7 +813,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -85,9 +84,11 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -635,7 +636,7 @@ def load_model_params(
|
||||
|
||||
"""
|
||||
logging.info(f"Loading checkpoint from {ckpt}")
|
||||
checkpoint = torch.load(ckpt, map_location="cpu")
|
||||
checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
|
||||
|
||||
# if module list is empty, load the whole model from ckpt
|
||||
if not init_modules:
|
||||
@ -678,7 +679,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -857,7 +858,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -903,7 +904,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1219,7 +1220,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -23,7 +23,7 @@ import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import penalize_abs_values_gt
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -150,7 +150,7 @@ class Transducer(nn.Module):
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -185,7 +185,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -247,7 +247,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -28,6 +28,8 @@ import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn import Embedding as ScaledEmbedding
|
||||
|
||||
from icefall.utils import torch_autocast
|
||||
|
||||
|
||||
class ActivationBalancerFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, ans_grad: Tensor):
|
||||
(ans,) = ctx.saved_tensors
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
ans_grad = ans_grad.to(torch.float32)
|
||||
ans = ans.to(torch.float32)
|
||||
x_grad = ans_grad * ans
|
||||
@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
|
||||
def backward(ctx, x_grad: Tensor):
|
||||
(x_orig,) = ctx.saved_tensors
|
||||
with torch.enable_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
x_detached = x_orig.to(torch.float32).detach()
|
||||
x_detached.requires_grad = True
|
||||
|
||||
@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module):
|
||||
):
|
||||
return _no_op(x)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
eps = 1.0e-20
|
||||
orig_x = x
|
||||
x = x.to(torch.float32)
|
||||
|
@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -86,10 +85,12 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
symlink_or_copy,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -581,7 +582,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -763,7 +764,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -809,7 +810,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1106,7 +1107,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -44,7 +44,7 @@ from scaling import (
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.dist import get_rank
|
||||
from icefall.utils import is_jit_tracing, make_pad_mask
|
||||
from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast
|
||||
|
||||
|
||||
class Zipformer(EncoderInterface):
|
||||
@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
bsz = n // num_heads
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
attn_weights = attn_weights.to(torch.float32)
|
||||
attn_output = attn_output.to(torch.float32)
|
||||
attn_weights_entropy = (
|
||||
|
@ -633,7 +633,9 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(
|
||||
f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
|
||||
)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -672,7 +674,9 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(
|
||||
params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
|
||||
)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method == "whole-lattice-rescoring":
|
||||
|
@ -786,7 +786,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
@ -347,7 +347,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(params.HLG, map_location="cpu", weights_only=False)
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
@ -358,7 +360,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = k2.Fsa.from_dict(
|
||||
torch.load(params.G, map_location="cpu", weights_only=False)
|
||||
)
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
|
@ -22,7 +22,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, torch_autocast
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -150,7 +150,7 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
@ -185,7 +185,7 @@ class Transducer(nn.Module):
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch_autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
|
@ -247,7 +247,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -286,7 +286,7 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -365,7 +365,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(params.HLG, map_location="cpu", weights_only=False)
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
@ -376,7 +378,9 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = k2.Fsa.from_dict(
|
||||
torch.load(params.G, map_location="cpu", weights_only=False)
|
||||
)
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
|
@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
create_grad_scaler,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
torch_autocast,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -588,7 +589,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
scaler: Optional["GradScaler"] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -787,7 +788,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
scaler: "GradScaler",
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -833,7 +834,7 @@ def train_one_epoch(
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -1128,7 +1129,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
with torch_autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
|
@ -624,7 +624,9 @@ def main():
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
torch.load(
|
||||
f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
|
||||
)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
@ -663,7 +665,9 @@ def main():
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
d = torch.load(
|
||||
params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
|
||||
)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method == "whole-lattice-rescoring":
|
||||
|
@ -808,7 +808,7 @@ def main():
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
torch.load(lg_filename, map_location=device, weights_only=False)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user