Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Fix CI tests. (#1974)
- Introduce unified AMP helpers (create_grad_scaler, torch_autocast) to handle deprecations in PyTorch ≥2.3.0
- Replace direct uses of torch.cuda.amp.GradScaler and torch.cuda.amp.autocast with the new utilities across all training and inference scripts
- Update all torch.load calls to include weights_only=False for compatibility with newer PyTorch versions
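As background for the first two bullets, here is a minimal sketch of what such version-agnostic AMP helpers can look like. It is an illustration under stated assumptions (hasattr-based feature detection and keyword forwarding), not the actual icefall implementation, which lives in icefall/utils.py and may differ in detail:

# Hedged sketch of the unified AMP helpers; PyTorch >= 2.3.0 deprecates
# torch.cuda.amp.autocast and torch.cuda.amp.GradScaler in favor of the
# device-generic torch.amp API.
import torch


def torch_autocast(device_type: str = "cuda", **kwargs):
    """Return an autocast context manager without deprecation warnings."""
    if hasattr(torch.amp, "autocast"):  # unified AMP API available
        return torch.amp.autocast(device_type=device_type, **kwargs)
    return torch.cuda.amp.autocast(**kwargs)  # fallback for older PyTorch


def create_grad_scaler(device: str = "cuda", **kwargs):
    """Create a GradScaler via the non-deprecated constructor when present."""
    if hasattr(torch.amp, "GradScaler"):  # PyTorch >= 2.3
        return torch.amp.GradScaler(device, **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)  # fallback for older PyTorch

In the diffs below, every with torch.cuda.amp.autocast(enabled=...) becomes with torch_autocast(enabled=...), which a helper like the one above supports by forwarding keyword arguments. The third bullet keeps checkpoint loading working on newer PyTorch, where the weights_only default of torch.load flips to True (PyTorch 2.6) and rejects pickled non-tensor objects; a minimal example, with a hypothetical checkpoint path:

# weights_only=False restores the old unpickling behavior; the path is
# hypothetical.
checkpoint = torch.load("exp/epoch-30.pt", map_location="cpu", weights_only=False)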
Parent: 71377d21cd · Commit: fba5e67d5e
.github/workflows/aishell.yml (vendored, 6 changed lines)

@@ -17,7 +17,7 @@ concurrency:

 jobs:
   generate_build_matrix:
-    if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell')
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'

     # see https://github.com/pytorch/pytorch/pull/50633
     runs-on: ubuntu-latest
@@ -31,8 +31,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"
   aishell:
     needs: generate_build_matrix
.github/workflows/audioset.yml (vendored, 10 changed lines)

@@ -30,8 +30,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"

   audioset:
@@ -83,7 +83,7 @@ jobs:
         ls -lh ./model-onnx/*

     - name: Upload model to huggingface
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
       uses: nick-fields/retry@v3
@@ -116,7 +116,7 @@ jobs:
         rm -rf huggingface

     - name: Prepare for release
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
       shell: bash
       run: |
         d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
@@ -125,7 +125,7 @@ jobs:
         ls -lh

     - name: Release exported onnx models
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
       uses: svenstaro/upload-release-action@v2
       with:
         file_glob: true
.github/workflows/baker_zh.yml (vendored, 20 changed lines)

@@ -31,8 +31,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"

   baker_zh:
@@ -84,43 +84,43 @@ jobs:
         ls -lh

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
         path: ./*.wav

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-2
         path: ./model-steps-2.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-3
         path: ./model-steps-3.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-4
         path: ./model-steps-4.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-5
         path: ./model-steps-5.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-6
         path: ./model-steps-6.onnx

     - name: Upload models to huggingface
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
       shell: bash
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -141,7 +141,7 @@ jobs:
         popd

     - name: Release exported onnx models
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
       uses: svenstaro/upload-release-action@v2
       with:
         file_glob: true
.github/workflows/librispeech.yml (vendored, 5 changed lines)

@@ -29,8 +29,9 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.6.0")
         echo "::set-output name=matrix::${MATRIX}"
   librispeech:
     needs: generate_build_matrix
.github/workflows/ljspeech.yml (vendored, 22 changed lines)

@@ -30,8 +30,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"

   ljspeech:
@@ -83,13 +83,13 @@ jobs:
         ls -lh

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
         path: ./*.wav

     - name: Release exported onnx models
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
       uses: svenstaro/upload-release-action@v2
       with:
         file_glob: true
@@ -100,37 +100,37 @@ jobs:
         tag: tts-models

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-2
         path: ./model-steps-2.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-3
         path: ./model-steps-3.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-4
         path: ./model-steps-4.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-5
         path: ./model-steps-5.onnx

     - uses: actions/upload-artifact@v4
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       with:
         name: step-6
         path: ./model-steps-6.onnx

     - name: Upload models to huggingface
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       shell: bash
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -155,7 +155,7 @@ jobs:
         popd

     - name: Release exported onnx models
-      if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+      if: matrix.python-version == '3.10' && matrix.torch-version == '2.3.0'
       uses: svenstaro/upload-release-action@v2
       with:
         file_glob: true
.github/workflows/test.yml (vendored, 4 changed lines)

@@ -30,8 +30,8 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
         echo "::set-output name=matrix::${MATRIX}"
   test:
     needs: generate_build_matrix
.github/workflows/yesno.yml (vendored, 5 changed lines)

@@ -30,8 +30,9 @@ jobs:
       id: set-matrix
       run: |
         # outputting for debugging purposes
-        python ./.github/scripts/docker/generate_build_matrix.py
-        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+        python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10"
+        MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10")
+        # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0")
         echo "::set-output name=matrix::${MATRIX}"
   yesno:
     needs: generate_build_matrix
@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -638,7 +644,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,

@@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -688,7 +694,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -184,7 +184,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -219,7 +219,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@@ -94,7 +94,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -797,7 +803,7 @@ def train_one_epoch(
         aishell = is_aishell(batch["supervisions"]["cut"][0])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -94,6 +94,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -809,7 +810,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -87,6 +87,7 @@ from icefall.utils import (
     setup_logger,
     str2bool,
     tokenize_by_CJK_char,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -802,7 +803,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -81,7 +81,13 @@ from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -812,7 +818,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -81,6 +81,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -514,7 +515,7 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -608,7 +609,7 @@ def train_one_epoch(
         )

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,
@@ -95,6 +95,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -910,7 +911,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1302,7 +1303,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -92,6 +92,7 @@ from icefall.utils import (
     setup_logger,
     str2bool,
     tokenize_by_CJK_char,
+    torch_autocast,
 )


@@ -495,7 +496,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -90,7 +90,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -734,7 +740,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -83,7 +83,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -727,7 +733,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
         # print(batch["supervisions"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
@@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -638,7 +644,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
@@ -73,7 +73,13 @@ from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -782,7 +788,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -773,7 +779,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -76,7 +76,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -1067,7 +1073,7 @@ def train_one_epoch(
         batch_size = batch["inputs"].shape[0]

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -76,7 +76,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -1058,7 +1064,7 @@ def train_one_epoch(
         batch_size = batch["inputs"].shape[0]

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -74,6 +74,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -799,7 +800,7 @@ def train_one_epoch(
             num_samples += batch_size

             try:
-                with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                with torch_autocast(enabled=params.use_fp16):
                     loss, loss_info = compute_loss(
                         params=params,
                         model=model,
@@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -73,6 +73,8 @@ def compute_fbank_baker_zh(num_jobs: int):
         f_min=0,
         f_max=8000,
     )
+    if not torch.cuda.is_available():
+        config.device = "cpu"

     prefix = "baker_zh"
     suffix = "jsonl.gz"
@@ -88,6 +88,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -825,7 +826,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -90,6 +90,7 @@ from icefall.utils import (
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -895,7 +896,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -840,7 +846,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -969,7 +970,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -604,7 +605,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 LOG_EPS = math.log(1e-10)
@@ -838,7 +844,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -77,7 +77,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -675,7 +681,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, _ = compute_loss(
                 params=params,
                 model=model,
@ -97,6 +97,7 @@ from icefall.utils import (
|
|||||||
get_parameter_groups_with_lrs,
|
get_parameter_groups_with_lrs,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
|
torch_autocast,
|
||||||
)
|
)
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -958,7 +959,7 @@ def train_one_epoch(
|
|||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch_autocast(enabled=params.use_fp16):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
@ -1317,7 +1318,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
for criterion, cuts in batches.items():
|
for criterion, cuts in batches.items():
|
||||||
batch = train_dl.dataset[cuts]
|
batch = train_dl.dataset[cuts]
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch_autocast(enabled=params.use_fp16):
|
||||||
loss, _ = compute_loss(
|
loss, _ = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||

@@ -97,6 +97,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -961,7 +962,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
@@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -805,7 +811,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
@@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -92,6 +92,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -942,7 +943,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
@@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -667,7 +667,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -707,7 +709,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.method in [
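
Every torch.load call site in this commit gains weights_only=False. PyTorch 2.6 flipped the default of weights_only to True, which restricts unpickling to tensors and a small allowlist, so loading a k2 FSA dict or a full training checkpoint fails with an UnpicklingError unless the old behaviour is requested explicitly. A minimal illustration (the file path here is hypothetical):

    import torch

    # Under the PyTorch >= 2.6 default (weights_only=True) this raises an
    # UnpicklingError for dicts such as the ones produced by
    # k2.Fsa.as_dict(); weights_only=False restores the pre-2.6 behaviour.
    # Only pass it for files you trust, since it executes pickled code.
    d = torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu", weights_only=False)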

@@ -271,7 +271,7 @@ def main():
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -351,7 +351,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -362,7 +364,9 @@ def main():
         "attention-decoder",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         # Add epsilon self-loops to G as we will compose
         # it with the whole lattice later
         G = G.to(device)

@@ -774,7 +774,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -814,7 +816,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.method in [

@@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -84,9 +83,11 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -420,7 +421,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -629,7 +630,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -676,7 +677,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -965,7 +966,7 @@ def run(rank, world_size, args):
         params=params,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,
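
Two related changes recur in the training scripts above: GradScaler(...) construction moves behind create_grad_scaler(...), and the scaler annotations become the string literal "GradScaler" so the signatures no longer require the removed torch.cuda.amp import to resolve at import time. The helper itself is defined in icefall/utils.py rather than in this diff; a plausible sketch, assuming the same version dispatch as torch_autocast:

    import torch

    def create_grad_scaler(device: str = "cuda", **kwargs):
        # PyTorch >= 2.3 deprecates torch.cuda.amp.GradScaler(...) in
        # favor of torch.amp.GradScaler("cuda", ...).
        if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
            return torch.amp.GradScaler(device, **kwargs)
        return torch.cuda.amp.GradScaler(**kwargs)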

@@ -868,7 +868,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -907,7 +909,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.decoding_method == "whole-lattice-rescoring":

@@ -334,7 +334,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -345,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

@@ -290,7 +290,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -386,7 +386,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -397,7 +399,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose

@@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed
 from model import CTCModel
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -95,9 +94,11 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -493,7 +494,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -694,7 +695,7 @@ def train_one_epoch(
     graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -743,7 +744,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1004,7 +1005,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -574,7 +574,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False
+            )
         )
         HLG = HLG.to(device)
         assert HLG.requires_grad is False
@@ -609,7 +611,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+            )
             G = k2.Fsa.from_dict(d).to(device)
 
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:

@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -93,7 +92,14 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    create_grad_scaler,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -560,7 +566,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -727,7 +733,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -772,7 +778,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1002,7 +1008,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -93,7 +92,14 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    create_grad_scaler,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -560,7 +566,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -727,7 +733,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -772,7 +778,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1001,7 +1007,7 @@ def run(rank, world_size, args):
         params=params,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
            loss, _ = compute_loss(
                params=params,
                model=model,

@@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 
     if Path(f"data/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
+        d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")

@@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
       An FSA representing LG.
     """
     lexicon = Lexicon(lang_dir)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 
     if Path(f"data/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"data/lm/{lm}.pt")
+        d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")

@@ -750,7 +750,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 
 
 class Transducer(nn.Module):
@@ -156,7 +156,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -192,7 +192,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
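
Both call sites in this model disable autocast around the k2 loss: even when the surrounding training step runs under fp16, the .float() casts plus the enabled=False region keep the transducer loss math in fp32 for numerical stability. A self-contained sketch of the nesting pattern (toy tensors, not the icefall model):

    import torch

    use_cuda = torch.cuda.is_available()
    x = torch.randn(4, 8, device="cuda" if use_cuda else "cpu")

    with torch.amp.autocast("cuda", enabled=use_cuda):
        y = x @ x.t()  # runs in reduced precision inside the region (on CUDA)
        with torch.amp.autocast("cuda", enabled=False):
            # the inner region restores fp32 math, mirroring the k2 loss blocks
            loss = (y.float() ** 2).mean()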

@@ -238,7 +238,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

@@ -66,7 +66,6 @@ from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -82,9 +81,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -521,7 +522,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -717,7 +718,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -763,7 +764,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
@@ -1023,7 +1024,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -935,7 +935,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 
 
 class Transducer(nn.Module):
@@ -195,7 +195,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -231,7 +231,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,

@@ -241,7 +241,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

@@ -74,7 +74,6 @@ from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -90,9 +89,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -560,7 +561,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -772,7 +773,7 @@ def train_one_epoch(
     giga_train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -848,7 +849,7 @@ def train_one_epoch(
         libri = is_libri(batch["supervisions"]["cut"][0])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
@@ -1176,7 +1177,7 @@ def run(rank, world_size, args):
     else:
         logging.info("Skip scan_pessimistic_batches_for_oom")
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -815,7 +815,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:

@@ -239,7 +239,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

@@ -66,7 +66,6 @@ from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -82,9 +81,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -551,7 +552,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -747,7 +748,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -793,7 +794,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
@@ -1067,7 +1068,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,

@@ -21,7 +21,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 
 
 class Transducer(nn.Module):
@@ -141,7 +141,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -176,7 +176,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,

@@ -10,9 +10,11 @@ from typing import Optional, Tuple
 import torch
 from scaling import ScaledLinear
 from torch import Tensor, nn
-from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
 from torch_scheduled_sampling import sample_combined
 
+from icefall.utils import create_grad_scaler, torch_autocast
+
 # The main exports of this file are the module KnowledgeBaseLookup and the
 # function create_knowledge_base.
 
@@ -330,14 +332,14 @@ def _test_knowledge_base_lookup_autocast():
     optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
     m = m.to(device)
 
-    scaler = GradScaler(enabled=True)
+    scaler = create_grad_scaler(enabled=True)
 
     start = timeit.default_timer()
 
     for epoch in range(150):
         for n, (x, y) in enumerate(train_pairs):
             y_out = m(x)
-            with torch.cuda.amp.autocast(enabled=True):
+            with torch_autocast(enabled=True):
                 loss = ((y_out - y) ** 2).mean() * 100.0
                 if n % 10 == 0 and epoch % 10 == 0:
                     print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
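
Note that this file still imports custom_bwd and custom_fwd from torch.cuda.amp. Those decorators were deprecated in later PyTorch releases in favor of device-qualified variants in torch.amp, so they may eventually need the same treatment as autocast and GradScaler; a hedged sketch of a forward-compatible import (an assumption, not part of this commit):

    try:
        # Newer PyTorch (2.4+): the decorators take an explicit device_type.
        from torch.amp import custom_bwd, custom_fwd

        custom_fwd = custom_fwd(device_type="cuda")
        custom_bwd = custom_bwd(device_type="cuda")
    except ImportError:
        # Older PyTorch keeps the CUDA-only spelling.
        from torch.cuda.amp import custom_bwd, custom_fwd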
|
@@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -76,7 +75,14 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    create_grad_scaler,
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -453,7 +459,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -608,7 +614,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -650,7 +656,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -868,7 +874,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom(
             # warmup = 0.0 is so that the derivs for the pruned loss stay zero
             # (i.e. are not remembered by the decaying-average in adam), because
             # we want to avoid these params being subject to shrinkage in adam.
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
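Taken together, the two helpers leave the shape of the fp16 training step untouched; only the constructor and the context manager are renamed. A condensed, self-contained sketch of the loop these train.py hunks modify, where the toy model, optimizer, and data stand in for the recipe's real objects:

```python
import torch
from torch import nn

from icefall.utils import create_grad_scaler, torch_autocast

# Toy stand-ins; the recipes use their own model, loss, and dataloader.
model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_dl = [(torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda())] * 5

scaler = create_grad_scaler(enabled=True)

for x, y in train_dl:
    optimizer.zero_grad()
    with torch_autocast(enabled=True):
        loss = ((model(x) - y) ** 2).mean()
    # Scale the loss so fp16 gradients do not underflow; step() unscales
    # before applying the update, and update() adapts the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```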
@@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from noam import Noam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -68,7 +67,14 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    create_grad_scaler,
+    setup_logger,
+    str2bool,
+    torch_autocast,
+)


 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -496,7 +502,7 @@ def save_checkpoint(
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, and training stats to file.
@@ -650,7 +656,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -693,7 +699,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -939,7 +945,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
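One detail worth noting in these train.py hunks: the annotation rewrites (`scaler: Optional[GradScaler]` to `scaler: Optional["GradScaler"]`) are what allow the `from torch.cuda.amp import GradScaler` lines to be deleted at all. A quoted name inside `Optional[...]` becomes a `typing.ForwardRef` that is never evaluated at runtime, so the module imports cleanly without the deprecated class. The commit relies on the string form alone; if a type checker needs to resolve the name, a guarded import is one optional way to provide it (this snippet is an illustration, not part of the commit):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen by type checkers only; never imported at runtime.
    from torch.cuda.amp import GradScaler


def save_checkpoint(scaler: Optional["GradScaler"] = None) -> None:
    """The quoted annotation needs no runtime import of GradScaler."""
```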
@@ -741,7 +741,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
@@ -754,7 +754,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
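The second recurring pattern starts here: every `torch.load` gains an explicit `weights_only=False`. PyTorch 2.6 flipped the default of `weights_only` to `True`, which restricts unpickling to plain tensors and simple containers; the objects loaded in these scripts (`k2.Fsa` attribute dicts, checkpoints carrying optimizer and sampler state) need full unpickling, so the legacy behaviour is now requested explicitly. A minimal illustration with a placeholder path:

```python
import torch

# The checkpoint holds more than raw tensors (optimizer state, sampler
# state, FSA attributes), so permissive unpickling must be requested
# explicitly on newer PyTorch. Only load files you trust.
checkpoint = torch.load("exp/epoch-30.pt", map_location="cpu", weights_only=False)
```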
@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -157,7 +157,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -193,7 +193,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@@ -265,7 +265,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -91,9 +90,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -523,7 +524,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -716,7 +717,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -759,7 +760,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1000,7 +1001,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 0 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -921,7 +921,7 @@ def load_ngram_LM(

     if pt_file.is_file():
         logging.info(f"Loading pre-compiled {pt_file}")
-        d = torch.load(pt_file, map_location=device)
+        d = torch.load(pt_file, map_location=device, weights_only=False)
         G = k2.Fsa.from_dict(d)
         G = k2.add_epsilon_self_loops(G)
         G = k2.arc_sort(G)
@@ -1101,7 +1101,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     elif params.decoding_method in [
@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -195,7 +195,7 @@ class Transducer(nn.Module):
         lm = simple_lm_proj(decoder_out)
         am = simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -231,7 +231,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@@ -274,7 +274,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -74,7 +74,6 @@ from librispeech import LibriSpeech
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -87,9 +86,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -546,7 +547,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -755,7 +756,7 @@ def train_one_epoch(
     giga_train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -827,7 +828,7 @@ def train_one_epoch(

         libri = is_libri(batch["supervisions"]["cut"][0])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1126,7 +1127,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 0 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -913,7 +913,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -96,9 +95,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -548,7 +549,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -744,7 +745,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -789,7 +790,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1047,7 +1048,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -972,7 +972,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
@@ -238,7 +238,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -84,9 +83,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -571,7 +572,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -768,7 +769,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -814,7 +815,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1078,7 +1079,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -220,7 +220,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -96,9 +95,11 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     display_and_save_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -519,7 +520,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -736,7 +737,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -781,7 +782,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch_autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
@@ -1039,7 +1040,7 @@ def run(rank, world_size, args):
         warmup=0.0 if params.start_epoch == 1 else 1.0,
     )

-    scaler = GradScaler(enabled=params.use_fp16)
+    scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -348,7 +348,9 @@ class CodebookIndexExtractor:
             num_codebooks=self.params.num_codebooks,
             codebook_size=256,
         )
-        quantizer.load_state_dict(torch.load(self.quantizer_file_path))
+        quantizer.load_state_dict(
+            torch.load(self.quantizer_file_path, weights_only=False)
+        )
         quantizer.to(self.params.device)
         return quantizer
@@ -289,7 +289,7 @@ def main():
     logging.info("About to create model")
    model = get_transducer_model(params)

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -910,7 +910,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
@@ -813,7 +813,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
@@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -85,9 +84,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -635,7 +636,7 @@ def load_model_params(

     """
     logging.info(f"Loading checkpoint from {ckpt}")
-    checkpoint = torch.load(ckpt, map_location="cpu")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)

     # if module list is empty, load the whole model from ckpt
     if not init_modules:
@@ -678,7 +679,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -857,7 +858,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -903,7 +904,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1219,7 +1220,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -23,7 +23,7 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt

-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast


 class Transducer(nn.Module):
@@ -150,7 +150,7 @@ class Transducer(nn.Module):
         # if self.training and random.random() < 0.25:
         #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -28,6 +28,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding

+from icefall.utils import torch_autocast
+

 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
@@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         (ans,) = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             ans_grad = ans_grad.to(torch.float32)
             ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
@@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
     def backward(ctx, x_grad: Tensor):
         (x_orig,) = ctx.saved_tensors
         with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
+            with torch_autocast(enabled=False):
                 x_detached = x_orig.to(torch.float32).detach()
                 x_detached.requires_grad = True

@@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module):
         ):
             return _no_op(x)

-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             eps = 1.0e-20
             orig_x = x
             x = x.to(torch.float32)
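The scaling.py hunks show the remaining call-site shape: autocast is disabled inside a custom `torch.autograd.Function` so the backward math always runs in float32, whatever mixed-precision context the caller is in. A stripped-down sketch of that pattern (the toy function below is illustrative, not the `SoftmaxFunction` above):

```python
import torch
from torch import Tensor

from icefall.utils import torch_autocast


class StableSquare(torch.autograd.Function):
    """Toy Function using the same float32-backward pattern as
    SoftmaxFunction / WhiteningPenaltyFunction in scaling.py."""

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out: Tensor) -> Tensor:
        (x,) = ctx.saved_tensors
        # Do the gradient arithmetic in full precision even if the
        # caller sits inside an autocast region, then cast back.
        with torch_autocast(enabled=False):
            x32 = x.to(torch.float32)
            g32 = grad_out.to(torch.float32)
            return (2.0 * x32 * g32).to(x.dtype)
```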
@@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -86,10 +85,12 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
     symlink_or_copy,
+    torch_autocast,
 )

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -581,7 +582,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -763,7 +764,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -809,7 +810,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1106,7 +1107,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -44,7 +44,7 @@ from scaling import (
 from torch import Tensor, nn

 from icefall.dist import get_rank
-from icefall.utils import is_jit_tracing, make_pad_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast


 class Zipformer(EncoderInterface):
@@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz = n // num_heads

             with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=False):
+                with torch_autocast(enabled=False):
                     attn_weights = attn_weights.to(torch.float32)
                     attn_output = attn_output.to(torch.float32)
                     attn_weights_entropy = (
@@ -633,7 +633,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False

@@ -672,7 +674,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)

         if params.decoding_method == "whole-lattice-rescoring":
@@ -786,7 +786,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
@@ -347,7 +347,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
 
 
 class Transducer(nn.Module):
@@ -150,7 +150,7 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
                 am=am.float(),
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
         # prior to do_rnnt_pruning (this is an optimization for speed).
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch_autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
                 logits=logits.float(),
                 symbols=y_padded,
@@ -247,7 +247,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -286,7 +286,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
     model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
@@ -365,7 +365,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.HLG, map_location="cpu", weights_only=False)
+        )
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
         "whole-lattice-rescoring",
     ]:
         logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+        G = k2.Fsa.from_dict(
+            torch.load(params.G, map_location="cpu", weights_only=False)
+        )
         G = G.to(device)
         if params.method == "whole-lattice-rescoring":
             # Add epsilon self-loops to G as we will compose
|
@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
|
|||||||
from model import Transducer
|
from model import Transducer
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.cuda.amp import GradScaler
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from zipformer import Zipformer
|
from zipformer import Zipformer
|
||||||
@@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     encode_supervisions,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
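create_grad_scaler parallels torch_autocast: it papers over the move of GradScaler from torch.cuda.amp to the device-generic torch.amp namespace in PyTorch 2.3. Again a sketch under the same assumption, not the exact icefall code:

import torch

def create_grad_scaler(device: str = "cuda", **kwargs):
    # PyTorch >= 2.3.0 provides torch.amp.GradScaler(device, ...) and
    # deprecates torch.cuda.amp.GradScaler(...).
    if hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler(device, **kwargs)
    # Fall back to the old API on older PyTorch releases.
    return torch.cuda.amp.GradScaler(**kwargs)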
@@ -588,7 +589,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
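Quoting the annotation (Optional["GradScaler"] instead of Optional[GradScaler]) turns it into a forward reference that Python stores as a plain string and never evaluates at runtime, which is why the `from torch.cuda.amp import GradScaler` import could be deleted above without breaking these signatures. A trimmed, hypothetical illustration:

from typing import Optional

def save_checkpoint(scaler: Optional["GradScaler"] = None) -> None:
    # The string annotation is only metadata; nothing named GradScaler
    # has to exist at import time for this function to be defined.
    pass

save_checkpoint()  # runs even though GradScaler was never imported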
@@ -787,7 +788,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
+    scaler: "GradScaler",
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -833,7 +834,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1128,7 +1129,7 @@ def run(rank, world_size, args):
         params=params,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
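Taken together, the two helpers slot into the standard AMP update cycle that this training script follows. A toy end-to-end illustration (assumes a CUDA device and the create_grad_scaler/torch_autocast sketches above):

import torch
import torch.nn as nn

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = create_grad_scaler(enabled=True, init_scale=1.0)

x = torch.randn(8, 10, device="cuda")
y = torch.randn(8, 1, device="cuda")

optimizer.zero_grad()
with torch_autocast(enabled=True):
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
scaler.step(optimizer)         # unscales grads; skips step on inf/nan
scaler.update()                # adapt the scale factor for the next step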
@@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch_autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
@@ -624,7 +624,9 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+            torch.load(
+                f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+            )
         )
         assert HLG.requires_grad is False
 
@@ -663,7 +665,9 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            d = torch.load(
+                params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+            )
             G = k2.Fsa.from_dict(d)
 
         if params.decoding_method == "whole-lattice-rescoring":
@@ -808,7 +808,7 @@ def main():
         lg_filename = params.lang_dir / "LG.pt"
         logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(lg_filename, map_location=device)
+            torch.load(lg_filename, map_location=device, weights_only=False)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
Some files were not shown because too many files have changed in this diff.