mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 22:15:28 +00:00
Add Zipformer from Dan (#672)
This commit is contained in:
parent
e334e570d8
commit
7e82f87126
106
.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
vendored
Executable file
106
.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
vendored
Executable file
@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
soxi $repo/test_wavs/*.wav
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/cpu_jit.pt"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
ls -lh *.pt
|
||||
popd
|
||||
|
||||
log "Export to torchscript model"
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
./pruned_transducer_stateless7/jit_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--nn-model-filename $repo/exp/cpu_jit.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
for sym in 1 2 3; do
|
||||
log "Greedy search with --max-sym-per-frame $sym"
|
||||
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--method greedy_search \
|
||||
--max-sym-per-frame $sym \
|
||||
--checkpoint $repo/exp/pretrained.pt \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
done
|
||||
|
||||
for method in modified_beam_search beam_search fast_beam_search; do
|
||||
log "$method"
|
||||
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--method $method \
|
||||
--beam-size 4 \
|
||||
--checkpoint $repo/exp/pretrained.pt \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
done
|
||||
|
||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||
mkdir -p pruned_transducer_stateless7/exp
|
||||
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7/exp/epoch-999.pt
|
||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||
|
||||
ls -lh data
|
||||
ls -lh pruned_transducer_stateless7/exp
|
||||
|
||||
log "Decoding test-clean and test-other"
|
||||
|
||||
# use a small value for decoding with CPU
|
||||
max_duration=100
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Decoding with $method"
|
||||
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--decoding-method $method \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--max-duration $max_duration \
|
||||
--exp-dir pruned_transducer_stateless7/exp
|
||||
done
|
||||
|
||||
rm pruned_transducer_stateless7/exp/*.pt
|
||||
fi
|
||||
155
.github/workflows/run-librispeech-2022-11-11-stateless7.yml
vendored
Normal file
155
.github/workflows/run-librispeech-2022-11-11-stateless7.yml
vendored
Normal file
@ -0,0 +1,155 @@
|
||||
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: run-librispeech-2022-11-11-stateless7
|
||||
# zipformer
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
schedule:
|
||||
# minute (0-59)
|
||||
# hour (0-23)
|
||||
# day of the month (1-31)
|
||||
# month (1-12)
|
||||
# day of the week (0-6)
|
||||
# nightly build at 15:50 UTC time every day
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
jobs:
|
||||
run_librispeech_2022_11_11_zipformer:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||
id: libri-test-clean-and-test-other-data
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/download
|
||||
key: cache-libri-test-clean-and-test-other
|
||||
|
||||
- name: Download LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||
|
||||
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||
id: libri-test-clean-and-test-other-fbank
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/fbank-libri
|
||||
key: cache-libri-fbank-test-clean-and-test-other-v2
|
||||
|
||||
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||
|
||||
- name: Inference with pre-trained model
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
mkdir -p egs/librispeech/ASR/data
|
||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||
ls -lh egs/librispeech/ASR/data/*
|
||||
|
||||
sudo apt-get -qq install git-lfs tree sox
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
|
||||
|
||||
- name: Display decoding results for librispeech pruned_transducer_stateless7
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
shell: bash
|
||||
run: |
|
||||
cd egs/librispeech/ASR/
|
||||
tree ./pruned_transducer_stateless7/exp
|
||||
|
||||
cd pruned_transducer_stateless7
|
||||
echo "results for pruned_transducer_stateless7"
|
||||
echo "===greedy search==="
|
||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===fast_beam_search==="
|
||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===modified beam search==="
|
||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
- name: Upload decoding results for librispeech pruned_transducer_stateless7
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-2022-11-11
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/
|
||||
@ -22,6 +22,7 @@ The following table lists the differences among them.
|
||||
| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training |
|
||||
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
|
||||
| `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert|
|
||||
| `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan|
|
||||
| `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
|
||||
| `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
|
||||
| `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model |
|
||||
|
||||
@ -1,5 +1,63 @@
|
||||
## Results
|
||||
|
||||
### pruned_transducer_stateless7 (zipformer)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/672> for more details.
|
||||
|
||||
[pruned_transducer_stateless7](./pruned_transducer_stateless7)
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/P7vXWqK7QVu1mU9Ene1gGg/>
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding
|
||||
results at:
|
||||
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11>
|
||||
|
||||
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
|
||||
|
||||
Number of model parameters: 70369391, i.e., 70.37 M
|
||||
|
||||
| | test-clean | test-other | comment |
|
||||
|----------------------|------------|-------------|----------------------------------------|
|
||||
| greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 |
|
||||
| modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 |
|
||||
| fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 |
|
||||
|
||||
The training commands are:
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,3,6,7"
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--full-libri 1 \
|
||||
--use-fp16 1 \
|
||||
--max-duration 750 \
|
||||
--exp-dir pruned_transducer_stateless7/exp \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--master-port 12535
|
||||
```
|
||||
|
||||
The decoding commands are:
|
||||
```bash
|
||||
for m in greedy_search fast_beam_search modified_beam_search ; do
|
||||
for epoch in 30; do
|
||||
for avg in 9; do
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--max-duration 600 \
|
||||
--decoding-method $m
|
||||
done
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
|
||||
|
||||
#### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
|
||||
|
||||
0
egs/librispeech/ASR/pruned2_knowledge/__init__.py
Normal file
0
egs/librispeech/ASR/pruned2_knowledge/__init__.py
Normal file
428
egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
Normal file
428
egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
Normal file
@ -0,0 +1,428 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use 960h LibriSpeech. "
|
||||
"Otherwise, use 100h subset.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the BucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "cuts_musan.json.gz"
|
||||
)
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(
|
||||
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
||||
)
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using BucketingSampler.")
|
||||
train_sampler = BucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
bucket_method="equal_duration",
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = BucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = BucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest(
|
||||
self.args.manifest_dir / "cuts_train-clean-100.json.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest(
|
||||
self.args.manifest_dir / "cuts_train-clean-360.json.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest(
|
||||
self.args.manifest_dir / "cuts_train-other-500.json.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")
|
||||
766
egs/librispeech/ASR/pruned2_knowledge/beam_search.py
Normal file
766
egs/librispeech/ASR/pruned2_knowledge/beam_search.py
Normal file
@ -0,0 +1,766 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
model: Transducer,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: float,
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
decoding_graph:
|
||||
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder.
|
||||
encoder_out_lens:
|
||||
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||
before padding.
|
||||
beam:
|
||||
Beam value, similar to the beam used in Kaldi..
|
||||
max_states:
|
||||
Max states per stream per frame.
|
||||
max_contexts:
|
||||
Max contexts pre stream per frame.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
vocab_size = model.decoder.vocab_size
|
||||
|
||||
B, T, C = encoder_out.shape
|
||||
|
||||
config = k2.RnntDecodingConfig(
|
||||
vocab_size=vocab_size,
|
||||
decoder_history_len=context_size,
|
||||
beam=beam,
|
||||
max_contexts=max_contexts,
|
||||
max_states=max_states,
|
||||
)
|
||||
individual_streams = []
|
||||
for i in range(B):
|
||||
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
|
||||
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
for t in range(T):
|
||||
# shape is a RaggedShape of shape (B, context)
|
||||
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
||||
shape, contexts = decoding_streams.get_contexts()
|
||||
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
|
||||
contexts = contexts.to(torch.int64)
|
||||
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
||||
decoder_out = model.decoder(contexts, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# current_encoder_out is of shape
|
||||
# (shape.NumElements(), 1, joiner_dim)
|
||||
# fmt: off
|
||||
current_encoder_out = torch.index_select(
|
||||
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
|
||||
)
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out.unsqueeze(2),
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False,
|
||||
)
|
||||
logits = logits.squeeze(1).squeeze(1)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
decoding_streams.advance(log_probs)
|
||||
decoding_streams.terminate_and_flush_to_streams()
|
||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||
) -> List[int]:
|
||||
"""Greedy search for a single utterance.
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
max_sym_per_frame:
|
||||
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||
would be 100%.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
hyp = [blank_id] * context_size
|
||||
|
||||
# Maximum symbols per utterance.
|
||||
max_sym_per_utt = 1000
|
||||
|
||||
# symbols per frame
|
||||
sym_per_frame = 0
|
||||
|
||||
# symbols per utterance decoded so far
|
||||
sym_per_utt = 0
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
if sym_per_frame >= max_sym_per_frame:
|
||||
sym_per_frame = 0
|
||||
t += 1
|
||||
continue
|
||||
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||
)
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
sym_per_utt += 1
|
||||
sym_per_frame += 1
|
||||
else:
|
||||
sym_per_frame = 0
|
||||
t += 1
|
||||
hyp = hyp[context_size:] # remove blanks
|
||||
|
||||
return hyp
|
||||
|
||||
|
||||
def greedy_search_batch(
|
||||
model: Transducer, encoder_out: torch.Tensor
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs containing the decoded results.
|
||||
len(ans) equals to encoder_out.size(0).
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
device = model.device
|
||||
|
||||
batch_size = encoder_out.size(0)
|
||||
T = encoder_out.size(1)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (batch_size, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||
)
|
||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
ans = [h[context_size:] for h in hyps]
|
||||
return ans
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hypothesis:
|
||||
# The predicted tokens so far.
|
||||
# Newly predicted tokens are appended to `ys`.
|
||||
ys: List[int]
|
||||
|
||||
# The log prob of ys.
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return a string representation of self.ys"""
|
||||
return "_".join(map(str, self.ys))
|
||||
|
||||
|
||||
class HypothesisList(object):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
data:
|
||||
A dict of Hypotheses. Its key is its `value.key`.
|
||||
"""
|
||||
if data is None:
|
||||
self._data = {}
|
||||
else:
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[str, Hypothesis]:
|
||||
return self._data
|
||||
|
||||
def add(self, hyp: Hypothesis) -> None:
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
If `hyp` already exists in `self`, its probability is updated using
|
||||
`log-sum-exp` with the existed one.
|
||||
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be added.
|
||||
"""
|
||||
key = hyp.key
|
||||
if key in self:
|
||||
old_hyp = self._data[key] # shallow copy
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
else:
|
||||
self._data[key] = hyp
|
||||
|
||||
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
|
||||
"""Get the most probable hypothesis, i.e., the one with
|
||||
the largest `log_prob`.
|
||||
|
||||
Args:
|
||||
length_norm:
|
||||
If True, the `log_prob` of a hypothesis is normalized by the
|
||||
number of tokens in it.
|
||||
Returns:
|
||||
Return the hypothesis that has the largest `log_prob`.
|
||||
"""
|
||||
if length_norm:
|
||||
return max(
|
||||
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
|
||||
)
|
||||
else:
|
||||
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
|
||||
|
||||
def remove(self, hyp: Hypothesis) -> None:
|
||||
"""Remove a given hypothesis.
|
||||
|
||||
Caution:
|
||||
`self` is modified **in-place**.
|
||||
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be removed from `self`.
|
||||
Note: It must be contained in `self`. Otherwise,
|
||||
an exception is raised.
|
||||
"""
|
||||
key = hyp.key
|
||||
assert key in self, f"{key} does not exist"
|
||||
del self._data[key]
|
||||
|
||||
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||
|
||||
Caution:
|
||||
`self` is not modified. Instead, a new HypothesisList is returned.
|
||||
|
||||
Returns:
|
||||
Return a new HypothesisList containing all hypotheses from `self`
|
||||
with `log_prob` being greater than the given `threshold`.
|
||||
"""
|
||||
ans = HypothesisList()
|
||||
for _, hyp in self._data.items():
|
||||
if hyp.log_prob > threshold:
|
||||
ans.add(hyp) # shallow copy
|
||||
return ans
|
||||
|
||||
def topk(self, k: int) -> "HypothesisList":
|
||||
"""Return the top-k hypothesis."""
|
||||
hyps = list(self._data.items())
|
||||
|
||||
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
|
||||
|
||||
ans = HypothesisList(dict(hyps))
|
||||
return ans
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self._data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._data.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._data)
|
||||
|
||||
def __str__(self) -> str:
|
||||
s = []
|
||||
for key in self:
|
||||
s.append(key)
|
||||
return ", ".join(s)
|
||||
|
||||
|
||||
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||
"""Return a ragged shape with axes [utt][num_hyps].
|
||||
|
||||
Args:
|
||||
hyps:
|
||||
len(hyps) == batch_size. It contains the current hypothesis for
|
||||
each utterance in the batch.
|
||||
Returns:
|
||||
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
|
||||
the shape is on CPU.
|
||||
"""
|
||||
num_hyps = [len(h) for h in hyps]
|
||||
|
||||
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
|
||||
# to get exclusive sum later.
|
||||
num_hyps.insert(0, 0)
|
||||
|
||||
num_hyps = torch.tensor(num_hyps)
|
||||
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
|
||||
ans = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
|
||||
)
|
||||
return ans
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[List[int]]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||
for the i-th utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
|
||||
batch_size = encoder_out.size(0)
|
||||
T = encoder_out.size(1)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
for i in range(batch_size):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||
|
||||
hyps_shape = _get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
current_encoder_out,
|
||||
dim=0,
|
||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
project_input=False,
|
||||
) # (num_hyps, 1, 1, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
vocab_size = log_probs.size(-1)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
|
||||
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||
)
|
||||
ragged_log_probs = k2.RaggedTensor(
|
||||
shape=log_probs_shape, value=log_probs
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
if new_token != blank_id:
|
||||
new_ys.append(new_token)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
ans = [h.ys[context_size:] for h in best_hyps]
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
def _deprecated_modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[int]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
It decodes only one utterance at a time. We keep it only for reference.
|
||||
The function :func:`modified_beam_search` should be preferred as it
|
||||
supports batch decoding.
|
||||
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
T = encoder_out.size(1)
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
for t in range(T):
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
|
||||
# fmt: on
|
||||
A = list(B)
|
||||
B = HypothesisList()
|
||||
|
||||
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
|
||||
# ys_log_probs is of shape (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyp in A],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
# decoder_input is of shape (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_output is of shape (num_hyps, 1, 1, joiner_dim)
|
||||
|
||||
current_encoder_out = current_encoder_out.expand(
|
||||
decoder_out.size(0), 1, 1, -1
|
||||
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
project_input=False,
|
||||
)
|
||||
# logits is of shape (num_hyps, 1, 1, vocab_size)
|
||||
logits = logits.squeeze(1).squeeze(1)
|
||||
|
||||
# now logits is of shape (num_hyps, vocab_size)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
topk_log_probs, topk_indexes = log_probs.topk(beam)
|
||||
|
||||
# topk_hyp_indexes are indexes into `A`
|
||||
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||
|
||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||
topk_token_indexes = topk_token_indexes.tolist()
|
||||
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token != blank_id:
|
||||
new_ys.append(new_token)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
return ys
|
||||
|
||||
|
||||
def beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[int]:
|
||||
"""
|
||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
||||
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
sym_per_utt = 0
|
||||
|
||||
decoder_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
A = B
|
||||
B = HypothesisList()
|
||||
|
||||
joint_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
||||
# of hypotheses in A
|
||||
|
||||
while True:
|
||||
y_star = A.get_most_probable()
|
||||
A.remove(y_star)
|
||||
|
||||
cached_key = y_star.key
|
||||
|
||||
if cached_key not in decoder_cache:
|
||||
decoder_input = torch.tensor(
|
||||
[y_star.ys[-context_size:]],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
decoder_cache[cached_key] = decoder_out
|
||||
else:
|
||||
decoder_out = decoder_cache[cached_key]
|
||||
|
||||
cached_key += f"-t-{t}"
|
||||
if cached_key not in joint_cache:
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False,
|
||||
)
|
||||
|
||||
# TODO(fangjun): Scale the blank posterior
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
log_prob = log_prob.squeeze()
|
||||
# Now log_prob is (vocab_size,)
|
||||
joint_cache[cached_key] = log_prob
|
||||
else:
|
||||
log_prob = joint_cache[cached_key]
|
||||
|
||||
# First, process the blank symbol
|
||||
skip_log_prob = log_prob[blank_id]
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
for i, v in zip(indices.tolist(), values.tolist()):
|
||||
if i == blank_id:
|
||||
continue
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
# than the most probable in A
|
||||
A_most_probable = A.get_most_probable()
|
||||
|
||||
kept_B = B.filter(A_most_probable.log_prob)
|
||||
|
||||
if len(kept_B) >= beam:
|
||||
B = kept_B.topk(beam)
|
||||
break
|
||||
|
||||
t += 1
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
return ys
|
||||
1071
egs/librispeech/ASR/pruned2_knowledge/conformer.py
Normal file
1071
egs/librispeech/ASR/pruned2_knowledge/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
547
egs/librispeech/ASR/pruned2_knowledge/decode.py
Executable file
547
egs/librispeech/ASR/pruned2_knowledge/decode.py
Executable file
@ -0,0 +1,547 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned2_knowledge/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned2_knowledge/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./pruned2_knowledge/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned2_knowledge/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned2_knowledge/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned2_knowledge/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned2_knowledge/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned2_knowledge/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg-last-n",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch and --avg are ignored and it
|
||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
||||
where xxx is the number of processed batches while
|
||||
saving that checkpoint.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned2_knowledge/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = model.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 100
|
||||
else:
|
||||
log_interval = 2
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.avg_last_n > 0:
|
||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
103
egs/librispeech/ASR/pruned2_knowledge/decoder.py
Normal file
103
egs/librispeech/ASR/pruned2_knowledge/decoder.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scaling import ScaledConv1d, ScaledEmbedding
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = ScaledEmbedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
if context_size > 1:
|
||||
self.conv = ScaledConv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
embedding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
||||
238
egs/librispeech/ASR/pruned2_knowledge/decoder2.py
Normal file
238
egs/librispeech/ASR/pruned2_knowledge/decoder2.py
Normal file
@ -0,0 +1,238 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
from subsampling import ScaledConv1d
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embedding_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
embedding_dim:
|
||||
Dimension of the input embedding.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding = ScaledEmbedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
if context_size > 1:
|
||||
self.conv = ScaledConv1d(
|
||||
in_channels=embedding_dim,
|
||||
out_channels=embedding_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=embedding_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
embedding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
return embedding_out
|
||||
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Module):
|
||||
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||
|
||||
This module is often used to store word embeddings and retrieve them using indices.
|
||||
The input to the module is a list of indices, and the output is the corresponding
|
||||
word embeddings.
|
||||
|
||||
Args:
|
||||
num_embeddings (int): size of the dictionary of embeddings
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
|
||||
(initialized to zeros) whenever it encounters the index.
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
|
||||
is renormalized to have norm :attr:`max_norm`.
|
||||
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
|
||||
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
|
||||
the words in the mini-batch. Default ``False``.
|
||||
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
|
||||
See Notes for more details regarding sparse gradients.
|
||||
|
||||
Attributes:
|
||||
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
|
||||
initialized from :math:`\mathcal{N}(0, 1)`
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
|
||||
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
|
||||
|
||||
.. note::
|
||||
Keep in mind that only a limited number of optimizers support
|
||||
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
|
||||
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
|
||||
|
||||
.. note::
|
||||
With :attr:`padding_idx` set, the embedding vector at
|
||||
:attr:`padding_idx` is initialized to all zeros. However, note that this
|
||||
vector can be modified afterwards, e.g., using a customized
|
||||
initialization method, and thus changing the vector used to pad the
|
||||
output. The gradient for this vector from :class:`~torch.nn.Embedding`
|
||||
is always zero.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # an Embedding module containing 10 tensors of size 3
|
||||
>>> embedding = nn.Embedding(10, 3)
|
||||
>>> # a batch of 2 samples of 4 indices each
|
||||
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
|
||||
>>> embedding(input)
|
||||
tensor([[[-0.0251, -1.6902, 0.7172],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 1.4970, 1.3448, -0.9685],
|
||||
[-0.3677, -2.7265, -0.1685]],
|
||||
|
||||
[[ 1.4970, 1.3448, -0.9685],
|
||||
[ 0.4362, -0.4004, 0.9400],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 0.9124, -2.3616, 1.1151]]])
|
||||
|
||||
|
||||
>>> # example with padding_idx
|
||||
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
|
||||
>>> input = torch.LongTensor([[0,2,0,5]])
|
||||
>>> embedding(input)
|
||||
tensor([[[ 0.0000, 0.0000, 0.0000],
|
||||
[ 0.1535, -2.0309, 0.9315],
|
||||
[ 0.0000, 0.0000, 0.0000],
|
||||
[-0.1655, 0.9897, 0.0635]]])
|
||||
"""
|
||||
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
|
||||
'scale_grad_by_freq', 'sparse']
|
||||
|
||||
num_embeddings: int
|
||||
embedding_dim: int
|
||||
padding_idx: int
|
||||
scale_grad_by_freq: bool
|
||||
weight: Tensor
|
||||
sparse: bool
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
scale_speed: float = 5.0) -> None:
|
||||
super(ScaledEmbedding, self).__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if padding_idx is not None:
|
||||
if padding_idx > 0:
|
||||
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
elif padding_idx < 0:
|
||||
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
padding_idx = self.num_embeddings + padding_idx
|
||||
self.padding_idx = padding_idx
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
|
||||
self.scale_speed = scale_speed
|
||||
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
|
||||
self.sparse = sparse
|
||||
|
||||
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
||||
self.reset_parameters()
|
||||
|
||||
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.normal_(self.weight, std=0.05)
|
||||
nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
|
||||
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
scale = (self.scale * self.scale_speed).exp()
|
||||
if input.numel() < self.num_embeddings:
|
||||
return F.embedding(
|
||||
input, self.weight, self.padding_idx,
|
||||
None, 2.0, # None, 2.0 relate to normalization
|
||||
self.scale_grad_by_freq, self.sparse) * scale
|
||||
else:
|
||||
return F.embedding(
|
||||
input, self.weight * scale, self.padding_idx,
|
||||
None, 2.0, # None, 2.0 relates to normalization
|
||||
self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
|
||||
if self.padding_idx is not None:
|
||||
s += ', padding_idx={padding_idx}'
|
||||
if self.scale_grad_by_freq is not False:
|
||||
s += ', scale_grad_by_freq={scale_grad_by_freq}'
|
||||
if self.sparse is not False:
|
||||
s += ', sparse=True'
|
||||
return s.format(**self.__dict__)
|
||||
43
egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py
Normal file
43
egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EncoderInterface(nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (batch_size, input_seq_len, num_features)
|
||||
containing the input features.
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames
|
||||
in `x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing two tensors:
|
||||
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
||||
containing unnormalized probabilities, i.e., the output of a
|
||||
linear layer.
|
||||
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
||||
the number of frames in `encoder_out` before padding.
|
||||
"""
|
||||
raise NotImplementedError("Please implement it in a subclass")
|
||||
182
egs/librispeech/ASR/pruned2_knowledge/export.py
Executable file
182
egs/librispeech/ASR/pruned2_knowledge/export.py
Executable file
@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./pruned2_knowledge/export.py \
|
||||
--exp-dir ./pruned2_knowledge/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
To use the generated file with `pruned2_knowledge/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./pruned2_knowledge/decode.py \
|
||||
--exp-dir ./pruned2_knowledge/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 100 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned2_knowledge/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
67
egs/librispeech/ASR/pruned2_knowledge/joiner.py
Normal file
67
egs/librispeech/ASR/pruned2_knowledge/joiner.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
|
||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
|
||||
self.output_linear = ScaledLinear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
||||
decoder_out
|
||||
)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
|
||||
193
egs/librispeech/ASR/pruned2_knowledge/model.py
Normal file
193
egs/librispeech/ASR/pruned2_knowledge/model.py
Normal file
@ -0,0 +1,193 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim, vocab_size, initial_speed=0.5
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
warmup: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
warmup:
|
||||
A value warmup >= 0 that determines which modules are active, values
|
||||
warmup > 1 "are fully warmed up" and all modules will be active.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
331
egs/librispeech/ASR/pruned2_knowledge/optim.py
Normal file
331
egs/librispeech/ASR/pruned2_knowledge/optim.py
Normal file
@ -0,0 +1,331 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class Eve(Optimizer):
|
||||
r"""
|
||||
Implements Eve algorithm. This is a modified version of AdamW with a special
|
||||
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
||||
rms of the parameters approach a particular target_rms (default: 0.1). This is
|
||||
for use with networks with 'scaled' versions of modules (see scaling.py), which
|
||||
will be close to invariant to the absolute scale on the parameter matrix.
|
||||
|
||||
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
||||
Eve is unpublished so far.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
|
||||
this value means that the weight would decay significantly after
|
||||
about 3k minibatches. Is not multiplied by learning rate, but
|
||||
is conditional on RMS-value of parameter being > target_rms.
|
||||
target_rms (float, optional): target root-mean-square value of
|
||||
parameters, if they fall below this we will stop applying weight decay.
|
||||
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _Decoupled Weight Decay Regularization:
|
||||
https://arxiv.org/abs/1711.05101
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-3,
|
||||
target_rms=0.1,
|
||||
):
|
||||
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 0: {}".format(betas[0])
|
||||
)
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 1: {}".format(betas[1])
|
||||
)
|
||||
if not 0 <= weight_decay <= 0.1:
|
||||
raise ValueError(
|
||||
"Invalid weight_decay value: {}".format(weight_decay)
|
||||
)
|
||||
if not 0 < target_rms <= 10.0:
|
||||
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
target_rms=target_rms,
|
||||
)
|
||||
super(Eve, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(Eve, self).__setstate__(state)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
# Perform optimization step
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"AdamW does not support sparse gradients"
|
||||
)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state["step"] += 1
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
|
||||
group["eps"]
|
||||
)
|
||||
|
||||
step_size = group["lr"] / bias_correction1
|
||||
target_rms = group["target_rms"]
|
||||
weight_decay = group["weight_decay"]
|
||||
|
||||
if p.numel() > 1:
|
||||
# avoid applying this weight-decay on "scaling factors"
|
||||
# (which are scalar).
|
||||
is_above_target_rms = p.norm() > (
|
||||
target_rms * (p.numel() ** 0.5)
|
||||
)
|
||||
p.mul_(1 - (weight_decay * is_above_target_rms))
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class LRScheduler(object):
|
||||
"""
|
||||
Base-class for learning rate schedulers where the learning-rate depends on both the
|
||||
batch and the epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer: Optimizer, verbose: bool = False):
|
||||
# Attach optimizer
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError(
|
||||
"{} is not an Optimizer".format(type(optimizer).__name__)
|
||||
)
|
||||
self.optimizer = optimizer
|
||||
self.verbose = verbose
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
|
||||
self.base_lrs = [
|
||||
group["initial_lr"] for group in optimizer.param_groups
|
||||
]
|
||||
|
||||
self.epoch = 0
|
||||
self.batch = 0
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the scheduler as a :class:`dict`.
|
||||
|
||||
It contains an entry for every variable in self.__dict__ which
|
||||
is not the optimizer.
|
||||
"""
|
||||
return {
|
||||
"base_lrs": self.base_lrs,
|
||||
"epoch": self.epoch,
|
||||
"batch": self.batch,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the schedulers state.
|
||||
|
||||
Args:
|
||||
state_dict (dict): scheduler state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_last_lr(self) -> List[float]:
|
||||
"""Return last computed learning rate by current scheduler. Will be a list of float."""
|
||||
return self._last_lr
|
||||
|
||||
def get_lr(self):
|
||||
# Compute list of learning rates from self.epoch and self.batch and
|
||||
# self.base_lrs; this must be overloaded by the user.
|
||||
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
|
||||
raise NotImplementedError
|
||||
|
||||
def step_batch(self, batch: Optional[int] = None) -> None:
|
||||
# Step the batch index, or just set it. If `batch` is specified, it
|
||||
# must be the batch index from the start of training, i.e. summed over
|
||||
# all epochs.
|
||||
# You can call this in any order; if you don't provide 'batch', it should
|
||||
# of course be called once per batch.
|
||||
if batch is not None:
|
||||
self.batch = batch
|
||||
else:
|
||||
self.batch = self.batch + 1
|
||||
self._set_lrs()
|
||||
|
||||
def step_epoch(self, epoch: Optional[int] = None):
|
||||
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
|
||||
# you should call this at the start of the epoch; if you don't provide the 'epoch'
|
||||
# arg, you should call it at the end of the epoch.
|
||||
if epoch is not None:
|
||||
self.epoch = epoch
|
||||
else:
|
||||
self.epoch = self.epoch + 1
|
||||
self._set_lrs()
|
||||
|
||||
def _set_lrs(self):
|
||||
values = self.get_lr()
|
||||
assert len(values) == len(self.optimizer.param_groups)
|
||||
|
||||
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
|
||||
param_group, lr = data
|
||||
param_group["lr"] = lr
|
||||
self.print_lr(self.verbose, i, lr)
|
||||
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
||||
|
||||
def print_lr(self, is_verbose, group, lr):
|
||||
"""Display the current learning rate."""
|
||||
if is_verbose:
|
||||
print(
|
||||
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
||||
f" of group {group} to {lr:.4e}."
|
||||
)
|
||||
|
||||
|
||||
class Eden(LRScheduler):
|
||||
"""
|
||||
Eden scheduler.
|
||||
lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
|
||||
(((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
|
||||
|
||||
E.g. suggest initial-lr = 0.003 (passed to optimizer).
|
||||
|
||||
Args:
|
||||
optimizer: the optimizer to change the learning rates on
|
||||
lr_batches: the number of batches after which we start significantly
|
||||
decreasing the learning rate, suggest 5000.
|
||||
lr_epochs: the number of epochs after which we start significantly
|
||||
decreasing the learning rate, suggest 6 if you plan to do e.g.
|
||||
20 to 40 epochs, but may need smaller number if dataset is huge
|
||||
and you will do few epochs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
lr_batches: Union[int, float],
|
||||
lr_epochs: Union[int, float],
|
||||
verbose: bool = False,
|
||||
):
|
||||
super(Eden, self).__init__(optimizer, verbose)
|
||||
self.lr_batches = lr_batches
|
||||
self.lr_epochs = lr_epochs
|
||||
|
||||
def get_lr(self):
|
||||
factor = (
|
||||
(self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
|
||||
) ** -0.25 * (
|
||||
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
|
||||
** -0.25
|
||||
)
|
||||
return [x * factor for x in self.base_lrs]
|
||||
|
||||
|
||||
def _test_eden():
|
||||
m = torch.nn.Linear(100, 100)
|
||||
optim = Eve(m.parameters(), lr=0.003)
|
||||
|
||||
scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)
|
||||
|
||||
for epoch in range(10):
|
||||
scheduler.step_epoch(epoch) # sets epoch to `epoch`
|
||||
|
||||
for step in range(20):
|
||||
x = torch.randn(200, 100).detach()
|
||||
x.requires_grad = True
|
||||
y = m(x)
|
||||
dy = torch.randn(200, 100).detach()
|
||||
f = (y * dy).sum()
|
||||
f.backward()
|
||||
|
||||
optim.step()
|
||||
scheduler.step_batch()
|
||||
optim.zero_grad()
|
||||
print("last lr = ", scheduler.get_last_lr())
|
||||
print("state dict = ", scheduler.state_dict())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_eden()
|
||||
332
egs/librispeech/ASR/pruned2_knowledge/sampling.py
Normal file
332
egs/librispeech/ASR/pruned2_knowledge/sampling.py
Normal file
@ -0,0 +1,332 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py,
|
||||
# its git history is there.
|
||||
|
||||
import timeit
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd
|
||||
from typing import Tuple, Optional
|
||||
from scaling import ScaledLinear
|
||||
import random
|
||||
from torch_scheduled_sampling import sample_combined
|
||||
|
||||
# The main exports of this file are the module KnowledgeBaseLookup and the
|
||||
# function create_knowledge_base.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter:
|
||||
std = 0.1
|
||||
a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of
|
||||
# 0.1 from uniform distribution
|
||||
ans = nn.Parameter(torch.ones(M ** N, D))
|
||||
nn.init.uniform_(ans, -a, a)
|
||||
return ans
|
||||
|
||||
def join_indexes(indexes: Tensor, M: int) -> Tensor:
|
||||
"""
|
||||
Combines N-tuples of indexes into single indexes that can be used for
|
||||
lookup in the knowledge base. Args:
|
||||
indexes: tensor of torch.int64 of shape (*, K, N), with elements in
|
||||
{0..M-1}
|
||||
M: the size of the original softmaxes, is upper bound on elements
|
||||
in indexes
|
||||
Returns:
|
||||
joined_indexes: of shape (*, K), joined_indexes[...,k] equals
|
||||
joined_indexes[...,0,k] + joined_indexes[...,1,k]*(M**1) ... + joined_indexes[...,1,k]*(M**(N-1))]
|
||||
"""
|
||||
N = indexes.shape[-1]
|
||||
n_powers = M ** torch.arange(N, device=indexes.device) # [ 1, M, ..., M**(N-1) ]
|
||||
return (indexes * n_powers).sum(dim=-1)
|
||||
|
||||
|
||||
# Note, we don't use this, we
|
||||
def weighted_matrix_lookup(weights: Tensor,
|
||||
indexes: Tensor,
|
||||
knowledge_base: Tensor) -> Tensor:
|
||||
"""
|
||||
Weighted combination of specified rows of a matrix.
|
||||
weights: Tensor of shape (*, K), can contain any value but probably in [0..1].
|
||||
indexes: Tensor of shape (*, K), with elements in [0..C-1]
|
||||
knowledge_base: Tensor of shape (C-1, D), whose rows we'll be looking up
|
||||
Returns:
|
||||
tensor of shape (*, D), containing weighted sums of rows of
|
||||
`knowledge_base`
|
||||
"""
|
||||
if True:
|
||||
return WeightedMatrixLookupFunction.apply(weights, indexes, knowledge_base)
|
||||
else:
|
||||
# simpler but less memory-efficient implementation
|
||||
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
|
||||
D = knowledge_base.shape[-1]
|
||||
weights = weights.unsqueeze(-2) # (*, 1, K)
|
||||
lookup = lookup.reshape(*indexes.shape, D) # (*, K, D)
|
||||
ans = torch.matmul(weights, lookup) # ans: (*, 1, D)
|
||||
ans = ans.squeeze(-2)
|
||||
assert list(ans.shape) == list(weights.shape[:-2]) + [D]
|
||||
return ans
|
||||
|
||||
|
||||
class WeightedMatrixLookupFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor:
|
||||
"""
|
||||
Weighted combination of specified rows of a matrix.
|
||||
weights: Tensor of shape (*, K), can contain any value but probably in [0..1].
|
||||
indexes: Tensor of shape (*, K), with elements in [0..C-1]
|
||||
knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up
|
||||
Returns:
|
||||
tensor of shape (*, D), containing weighted sums of rows of
|
||||
`knowledge_base`
|
||||
"""
|
||||
if random.random() < 0.001:
|
||||
print("dtype[1] = ", weights.dtype)
|
||||
ctx.save_for_backward(weights.detach(), indexes.detach(),
|
||||
knowledge_base.detach())
|
||||
with torch.no_grad():
|
||||
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
|
||||
D = knowledge_base.shape[-1]
|
||||
weights = weights.unsqueeze(-2) # (*, 1, K)
|
||||
lookup = lookup.reshape(*indexes.shape, D) # (*, K, D)
|
||||
ans = torch.matmul(weights, lookup) # ans: (*, 1, D)
|
||||
ans = ans.squeeze(-2) #(*, D)
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]:
|
||||
# ans_grad: (*, D)
|
||||
weights, indexes, knowledge_base = ctx.saved_tensors
|
||||
knowledge_base.requires_grad = True
|
||||
dtype = ans_grad.dtype
|
||||
ans_grad = ans_grad.to(weights.dtype)
|
||||
assert weights.requires_grad == False
|
||||
D = knowledge_base.shape[-1]
|
||||
with torch.enable_grad():
|
||||
# we'll use torch's autograd to differentiate this operation, which
|
||||
# is nontrivial [and anyway we need `lookup` to compute weight grad.
|
||||
# We don't save `lookup` because it's large, that is the reason
|
||||
# we override Torch autograd.
|
||||
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
|
||||
lookup = lookup.reshape(*indexes.shape, D) # (*, K, D)
|
||||
weights = weights.unsqueeze(-1) # (*, K, 1)
|
||||
# forward pass: was:
|
||||
## ans = torch.matmul(weights, lookup)
|
||||
## ans: (*, 1, D)
|
||||
## ans = ans.squeeze(-2) # ans, ans_grad: (*, D)
|
||||
weights_grad = torch.matmul(lookup, # (*, K, D)
|
||||
ans_grad.unsqueeze(-1)) # (*, D, 1)
|
||||
weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K)
|
||||
lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D)
|
||||
lookup.backward(gradient=lookup_grad)
|
||||
return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype)
|
||||
|
||||
|
||||
class PenalizeNegentropyFunction(torch.autograd.Function):
|
||||
"""
|
||||
Function that does nothing in forward pass, but in backprop, it is as
|
||||
if you had added: `- tot_entropy * alpha` to the loss function, where
|
||||
tot_entropy is the the entropy of the average of the input distributions,
|
||||
times the number of input distributions. (We multiply by this because
|
||||
our overall loss function is proportional to the number of frames).
|
||||
|
||||
This will tend to make the entropy want to become as large as possible,
|
||||
making (-tot_entropy * alpha) as negative as possible.
|
||||
|
||||
Args:
|
||||
logprobs: Tensor of shape (*, num_classes), should be the result of
|
||||
calling some_tensor.log_softmax(dim=-1)
|
||||
Returns:
|
||||
logprobs
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, logprobs: Tensor, alpha: float):
|
||||
ctx.save_for_backward(logprobs.detach())
|
||||
ctx.alpha = alpha
|
||||
return logprobs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]:
|
||||
logprobs, = ctx.saved_tensors
|
||||
with torch.enable_grad():
|
||||
logprobs.requires_grad = True
|
||||
# `negentropy` is the negative entropy of the average distribution.
|
||||
# distributions. It will be <= 0.
|
||||
l = logprobs.reshape(-1, logprobs.shape[-1])
|
||||
scale = ctx.alpha * l.shape[0]
|
||||
avg_dist = l.exp().mean(dim=0)
|
||||
negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()
|
||||
if random.random() < 0.0005:
|
||||
negentropy_individual = (l * l.exp()).sum(dim=-1).mean()
|
||||
print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item())
|
||||
loss = negentropy * scale
|
||||
loss.backward()
|
||||
return logprobs_grad + logprobs.grad, None
|
||||
|
||||
|
||||
class KnowledgeBaseLookup(nn.Module):
|
||||
"""
|
||||
Create knowledge-base lookup module. (The knowledge-base parameter, which is
|
||||
large, is shared between these modules).
|
||||
Args:
|
||||
M: int, softmax size, e.g. in [32..128]
|
||||
N: int, number of softmaxes, in [2..3]
|
||||
D: int, embedding dimension in knowledge base, e.g. 256
|
||||
K: number of samples (affects speed/accuracy tradeoff), e.g. 16.
|
||||
embedding_dim: the dimension to project from and to, e.g. the
|
||||
d_model of the conformer.
|
||||
"""
|
||||
def __init__(self, M: int, N: int, D: int,
|
||||
K: int, embedding_dim: int,
|
||||
knowledge_base: nn.Parameter,
|
||||
negentropy_penalty: float = 0.001):
|
||||
super(KnowledgeBaseLookup, self).__init__()
|
||||
self.knowledge_base = knowledge_base # shared!
|
||||
self.in_proj = ScaledLinear(embedding_dim, M * N,
|
||||
initial_scale=1.0)
|
||||
# initial_scale = 4.0 because the knowlege_base activations are
|
||||
# quite small -- if we use our optimizer they'll have stddev <= 0.1.
|
||||
self.out_proj = ScaledLinear(D, embedding_dim,
|
||||
initial_scale = 4.0)
|
||||
self.M = M
|
||||
self.N = N
|
||||
self.K = K
|
||||
self.negentropy_penalty = negentropy_penalty
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Forward function that does knowledge-base lookup.
|
||||
Args:
|
||||
x: input, of shape (*, E) where E is embedding_dim
|
||||
as passed to constructor
|
||||
y: output of knowledge-base lookup, of shape (*, E)
|
||||
|
||||
# TODO: later we can try multiplying by a projection of x or something like that.
|
||||
"""
|
||||
x = self.in_proj(x) # now (*, M*N)
|
||||
x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M)
|
||||
x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M)
|
||||
x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty)
|
||||
|
||||
_, indexes, weights = sample_combined(x, self.K, input_is_log=True)
|
||||
x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D)
|
||||
x = self.out_proj(x) # now (*, self.embedding_dim)
|
||||
return x
|
||||
|
||||
|
||||
def _test_knowledge_base_lookup():
|
||||
K = 16
|
||||
N = 2
|
||||
M = 128
|
||||
D = 256
|
||||
E = 255
|
||||
|
||||
knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
|
||||
m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
|
||||
|
||||
B = 30
|
||||
T = 40
|
||||
x = torch.randn(B, T, E)
|
||||
x.requires_grad = True
|
||||
y = m(x)
|
||||
assert y.shape == x.shape
|
||||
y.sum().backward() # make sure backward doesn't crash..
|
||||
print("y = ", y)
|
||||
print("x.grad = ", x.grad)
|
||||
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
|
||||
|
||||
dtype = torch.float32
|
||||
device = torch.device('cuda')
|
||||
train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
|
||||
from optim import Eve
|
||||
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
|
||||
m = m.to(device).to(dtype)
|
||||
|
||||
|
||||
start = timeit.default_timer()
|
||||
|
||||
# Epoch 0, batch 0, loss 1.0109944343566895
|
||||
# Epoch 10, batch 0, loss 1.0146660804748535
|
||||
# Epoch 20, batch 0, loss 1.0119813680648804
|
||||
# Epoch 30, batch 0, loss 1.0105408430099487
|
||||
# Epoch 40, batch 0, loss 1.0077732801437378
|
||||
# Epoch 50, batch 0, loss 1.0050103664398193
|
||||
# Epoch 60, batch 0, loss 1.0033129453659058
|
||||
# Epoch 70, batch 0, loss 1.0014232397079468
|
||||
# Epoch 80, batch 0, loss 0.9977912306785583
|
||||
# Epoch 90, batch 0, loss 0.8274348974227905
|
||||
# Epoch 100, batch 0, loss 0.3368612825870514
|
||||
# Epoch 110, batch 0, loss 0.11323091387748718
|
||||
# Time taken: 17.591704960912466
|
||||
for epoch in range(150):
|
||||
for n, (x,y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
loss = ((y_out - y)**2).mean() * 100.0
|
||||
if n % 10 == 0 and epoch % 10 == 0:
|
||||
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
stop = timeit.default_timer()
|
||||
print('Time taken: ', stop - start)
|
||||
|
||||
def _test_knowledge_base_lookup_autocast():
|
||||
K = 16
|
||||
N = 2
|
||||
M = 128
|
||||
D = 256
|
||||
E = 255
|
||||
|
||||
knowledge_base: nn.Parameter = create_knowledge_base(M, N, D)
|
||||
m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base)
|
||||
|
||||
B = 30
|
||||
T = 40
|
||||
x = torch.randn(B, T, E)
|
||||
x.requires_grad = True
|
||||
y = m(x)
|
||||
assert y.shape == x.shape
|
||||
y.sum().backward() # make sure backward doesn't crash..
|
||||
print("y = ", y)
|
||||
print("x.grad = ", x.grad)
|
||||
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
|
||||
|
||||
device = torch.device('cuda')
|
||||
train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ]
|
||||
from optim import Eve
|
||||
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
|
||||
m = m.to(device)
|
||||
|
||||
scaler = GradScaler(enabled=True)
|
||||
|
||||
start = timeit.default_timer()
|
||||
|
||||
|
||||
for epoch in range(150):
|
||||
for n, (x,y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
loss = ((y_out - y)**2).mean() * 100.0
|
||||
if n % 10 == 0 and epoch % 10 == 0:
|
||||
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
stop = timeit.default_timer()
|
||||
print('Time taken: ', stop - start)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test_knowledge_base_lookup()
|
||||
_test_knowledge_base_lookup_autocast()
|
||||
707
egs/librispeech/ASR/pruned2_knowledge/scaling.py
Normal file
707
egs/librispeech/ASR/pruned2_knowledge/scaling.py
Normal file
@ -0,0 +1,707 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import collections
|
||||
from itertools import repeat
|
||||
from typing import Optional, Tuple
|
||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
_single = _ntuple(1)
|
||||
_pair = _ntuple(2)
|
||||
|
||||
|
||||
class ActivationBalancerFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
channel_dim: int,
|
||||
min_positive: float, # e.g. 0.05
|
||||
max_positive: float, # e.g. 0.95
|
||||
max_factor: float, # e.g. 0.01
|
||||
min_abs: float, # e.g. 0.2
|
||||
max_abs: float, # e.g. 100.0
|
||||
) -> Tensor:
|
||||
if x.requires_grad:
|
||||
if channel_dim < 0:
|
||||
channel_dim += x.ndim
|
||||
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
||||
xgt0 = x > 0
|
||||
proportion_positive = torch.mean(
|
||||
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
|
||||
)
|
||||
factor1 = (
|
||||
(min_positive - proportion_positive).relu()
|
||||
* (max_factor / min_positive)
|
||||
if min_positive != 0.0
|
||||
else 0.0
|
||||
)
|
||||
factor2 = (
|
||||
(proportion_positive - max_positive).relu()
|
||||
* (max_factor / (max_positive - 1.0))
|
||||
if max_positive != 1.0
|
||||
else 0.0
|
||||
)
|
||||
factor = factor1 + factor2
|
||||
if isinstance(factor, float):
|
||||
factor = torch.zeros_like(proportion_positive)
|
||||
|
||||
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
|
||||
below_threshold = mean_abs < min_abs
|
||||
above_threshold = mean_abs > max_abs
|
||||
|
||||
ctx.save_for_backward(
|
||||
factor, xgt0, below_threshold, above_threshold
|
||||
)
|
||||
ctx.max_factor = max_factor
|
||||
ctx.sum_dims = sum_dims
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(
|
||||
ctx, x_grad: Tensor
|
||||
) -> Tuple[Tensor, None, None, None, None, None, None]:
|
||||
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
|
||||
dtype = x_grad.dtype
|
||||
scale_factor = (
|
||||
(below_threshold.to(dtype) - above_threshold.to(dtype))
|
||||
* (xgt0.to(dtype) - 0.5)
|
||||
* (ctx.max_factor * 2.0)
|
||||
)
|
||||
|
||||
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
|
||||
return x_grad - neg_delta_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
class BasicNorm(torch.nn.Module):
|
||||
"""
|
||||
This is intended to be a simpler, and hopefully cheaper, replacement for
|
||||
LayerNorm. The observation this is based on, is that Transformer-type
|
||||
networks, especially with pre-norm, sometimes seem to set one of the
|
||||
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
||||
the LayerNorm because the output magnitude is then not strongly dependent
|
||||
on the other (useful) features. Presumably the weight and bias of the
|
||||
LayerNorm are required to allow it to do this.
|
||||
|
||||
So the idea is to introduce this large constant value as an explicit
|
||||
parameter, that takes the role of the "eps" in LayerNorm, so the network
|
||||
doesn't have to do this trick. We make the "eps" learnable.
|
||||
|
||||
Args:
|
||||
num_channels: the number of channels, e.g. 512.
|
||||
channel_dim: the axis/dimension corresponding to the channel,
|
||||
interprted as an offset from the input's ndim if negative.
|
||||
shis is NOT the num_channels; it should typically be one of
|
||||
{-2, -1, 0, 1, 2, 3}.
|
||||
eps: the initial "epsilon" that we add as ballast in:
|
||||
scale = ((input_vec**2).mean() + epsilon)**-0.5
|
||||
Note: our epsilon is actually large, but we keep the name
|
||||
to indicate the connection with conventional LayerNorm.
|
||||
learn_eps: if true, we learn epsilon; if false, we keep it
|
||||
at the initial value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
channel_dim: int = -1, # CAUTION: see documentation.
|
||||
eps: float = 0.25,
|
||||
learn_eps: bool = True,
|
||||
) -> None:
|
||||
super(BasicNorm, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
if learn_eps:
|
||||
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
||||
else:
|
||||
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
scales = (
|
||||
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
|
||||
+ self.eps.exp()
|
||||
) ** -0.5
|
||||
return x * scales
|
||||
|
||||
|
||||
class ScaledLinear(nn.Linear):
|
||||
"""
|
||||
A modified version of nn.Linear where the parameters are scaled before
|
||||
use, via:
|
||||
weight = self.weight * self.weight_scale.exp()
|
||||
bias = self.bias * self.bias_scale.exp()
|
||||
|
||||
Args:
|
||||
Accepts the standard args and kwargs that nn.Linear accepts
|
||||
e.g. in_features, out_features, bias=False.
|
||||
|
||||
initial_scale: you can override this if you want to increase
|
||||
or decrease the initial magnitude of the module's output
|
||||
(affects the initialization of weight_scale and bias_scale).
|
||||
Another option, if you want to do something like this, is
|
||||
to re-initialize the parameters.
|
||||
initial_speed: this affects how fast the parameter will
|
||||
learn near the start of training; you can set it to a
|
||||
value less than one if you suspect that a module
|
||||
is contributing to instability near the start of training.
|
||||
Nnote: regardless of the use of this option, it's best to
|
||||
use schedulers like Noam that have a warm-up period.
|
||||
Alternatively you can set it to more than 1 if you want it to
|
||||
initially train faster. Must be greater than 0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
initial_scale: float = 1.0,
|
||||
initial_speed: float = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
super(ScaledLinear, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter("bias_scale", None)
|
||||
|
||||
self._reset_parameters(
|
||||
initial_speed
|
||||
) # Overrides the reset_parameters in nn.Linear
|
||||
|
||||
def _reset_parameters(self, initial_speed: float):
|
||||
std = 0.1 / initial_speed
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return torch.nn.functional.linear(
|
||||
input, self.get_weight(), self.get_bias()
|
||||
)
|
||||
|
||||
|
||||
class ScaledConv1d(nn.Conv1d):
|
||||
# See docs for ScaledLinear
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
initial_scale: float = 1.0,
|
||||
initial_speed: float = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
super(ScaledConv1d, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter("bias_scale", None)
|
||||
self._reset_parameters(
|
||||
initial_speed
|
||||
) # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self, initial_speed: float):
|
||||
std = 0.1 / initial_speed
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
F = torch.nn.functional
|
||||
if self.padding_mode != "zeros":
|
||||
return F.conv1d(
|
||||
F.pad(
|
||||
input,
|
||||
self._reversed_padding_repeated_twice,
|
||||
mode=self.padding_mode,
|
||||
),
|
||||
self.get_weight(),
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
_single(0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
return F.conv1d(
|
||||
input,
|
||||
self.get_weight(),
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
|
||||
class ScaledConv2d(nn.Conv2d):
|
||||
# See docs for ScaledLinear
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
initial_scale: float = 1.0,
|
||||
initial_speed: float = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
super(ScaledConv2d, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter("bias_scale", None)
|
||||
self._reset_parameters(
|
||||
initial_speed
|
||||
) # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self, initial_speed: float):
|
||||
std = 0.1 / initial_speed
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||
|
||||
def _conv_forward(self, input, weight):
|
||||
F = torch.nn.functional
|
||||
if self.padding_mode != "zeros":
|
||||
return F.conv2d(
|
||||
F.pad(
|
||||
input,
|
||||
self._reversed_padding_repeated_twice,
|
||||
mode=self.padding_mode,
|
||||
),
|
||||
weight,
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
_pair(0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
return F.conv2d(
|
||||
input,
|
||||
weight,
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return self._conv_forward(input, self.get_weight())
|
||||
|
||||
|
||||
class ActivationBalancer(torch.nn.Module):
|
||||
"""
|
||||
Modifies the backpropped derivatives of a function to try to encourage, for
|
||||
each channel, that it is positive at least a proportion `threshold` of the
|
||||
time. It does this by multiplying negative derivative values by up to
|
||||
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
||||
interpolated from 1 at the threshold to those extremal values when none
|
||||
of the inputs are positive.
|
||||
|
||||
|
||||
Args:
|
||||
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
||||
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
||||
min_positive: the minimum, per channel, of the proportion of the time
|
||||
that (x > 0), below which we start to modify the derivatives.
|
||||
max_positive: the maximum, per channel, of the proportion of the time
|
||||
that (x > 0), above which we start to modify the derivatives.
|
||||
max_factor: the maximum factor by which we modify the derivatives for
|
||||
either the sign constraint or the magnitude constraint;
|
||||
e.g. with max_factor=0.02, the the derivatives would be multiplied by
|
||||
values in the range [0.98..1.02].
|
||||
min_abs: the minimum average-absolute-value per channel, which
|
||||
we allow, before we start to modify the derivatives to prevent
|
||||
this.
|
||||
max_abs: the maximum average-absolute-value per channel, which
|
||||
we allow, before we start to modify the derivatives to prevent
|
||||
this.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel_dim: int,
|
||||
min_positive: float = 0.05,
|
||||
max_positive: float = 0.95,
|
||||
max_factor: float = 0.01,
|
||||
min_abs: float = 0.2,
|
||||
max_abs: float = 100.0,
|
||||
):
|
||||
super(ActivationBalancer, self).__init__()
|
||||
self.channel_dim = channel_dim
|
||||
self.min_positive = min_positive
|
||||
self.max_positive = max_positive
|
||||
self.max_factor = max_factor
|
||||
self.min_abs = min_abs
|
||||
self.max_abs = max_abs
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
self.min_positive,
|
||||
self.max_positive,
|
||||
self.max_factor,
|
||||
self.min_abs,
|
||||
self.max_abs,
|
||||
)
|
||||
|
||||
|
||||
class DoubleSwishFunction(torch.autograd.Function):
|
||||
"""
|
||||
double_swish(x) = x * torch.sigmoid(x-1)
|
||||
This is a definition, originally motivated by its close numerical
|
||||
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
||||
|
||||
Memory-efficient derivative computation:
|
||||
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
||||
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
||||
Now, s'(x) = s(x) * (1-s(x)).
|
||||
double_swish'(x) = x * s'(x) + s(x).
|
||||
= x * s(x) * (1-s(x)) + s(x).
|
||||
= double_swish(x) * (1-s(x)) + s(x)
|
||||
... so we just need to remember s(x) but not x itself.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x: Tensor) -> Tensor:
|
||||
x = x.detach()
|
||||
s = torch.sigmoid(x - 1.0)
|
||||
y = x * s
|
||||
ctx.save_for_backward(s, y)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||
s, y = ctx.saved_tensors
|
||||
return (y * (1 - s) + s) * y_grad
|
||||
|
||||
|
||||
class DoubleSwish(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Module):
|
||||
r"""This is a modified version of nn.Embedding that introduces a learnable scale
|
||||
on the parameters. Note: due to how we initialize it, it's best used with
|
||||
schedulers like Noam that have a warmup period.
|
||||
|
||||
It is a simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||
|
||||
This module is often used to store word embeddings and retrieve them using indices.
|
||||
The input to the module is a list of indices, and the output is the corresponding
|
||||
word embeddings.
|
||||
|
||||
Args:
|
||||
num_embeddings (int): size of the dictionary of embeddings
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
|
||||
(initialized to zeros) whenever it encounters the index.
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
|
||||
is renormalized to have norm :attr:`max_norm`.
|
||||
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
|
||||
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
|
||||
the words in the mini-batch. Default ``False``.
|
||||
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
|
||||
See Notes for more details regarding sparse gradients.
|
||||
|
||||
initial_speed (float, optional): This affects how fast the parameter will
|
||||
learn near the start of training; you can set it to a value less than
|
||||
one if you suspect that a module is contributing to instability near
|
||||
the start of training. Nnote: regardless of the use of this option,
|
||||
it's best to use schedulers like Noam that have a warm-up period.
|
||||
Alternatively you can set it to more than 1 if you want it to
|
||||
initially train faster. Must be greater than 0.
|
||||
|
||||
|
||||
Attributes:
|
||||
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
|
||||
initialized from :math:`\mathcal{N}(0, 1)`
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
|
||||
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
|
||||
|
||||
.. note::
|
||||
Keep in mind that only a limited number of optimizers support
|
||||
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
|
||||
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
|
||||
|
||||
.. note::
|
||||
With :attr:`padding_idx` set, the embedding vector at
|
||||
:attr:`padding_idx` is initialized to all zeros. However, note that this
|
||||
vector can be modified afterwards, e.g., using a customized
|
||||
initialization method, and thus changing the vector used to pad the
|
||||
output. The gradient for this vector from :class:`~torch.nn.Embedding`
|
||||
is always zero.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # an Embedding module containing 10 tensors of size 3
|
||||
>>> embedding = nn.Embedding(10, 3)
|
||||
>>> # a batch of 2 samples of 4 indices each
|
||||
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
|
||||
>>> embedding(input)
|
||||
tensor([[[-0.0251, -1.6902, 0.7172],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 1.4970, 1.3448, -0.9685],
|
||||
[-0.3677, -2.7265, -0.1685]],
|
||||
|
||||
[[ 1.4970, 1.3448, -0.9685],
|
||||
[ 0.4362, -0.4004, 0.9400],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 0.9124, -2.3616, 1.1151]]])
|
||||
|
||||
|
||||
>>> # example with padding_idx
|
||||
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
|
||||
>>> input = torch.LongTensor([[0,2,0,5]])
|
||||
>>> embedding(input)
|
||||
tensor([[[ 0.0000, 0.0000, 0.0000],
|
||||
[ 0.1535, -2.0309, 0.9315],
|
||||
[ 0.0000, 0.0000, 0.0000],
|
||||
[-0.1655, 0.9897, 0.0635]]])
|
||||
|
||||
"""
|
||||
__constants__ = [
|
||||
"num_embeddings",
|
||||
"embedding_dim",
|
||||
"padding_idx",
|
||||
"scale_grad_by_freq",
|
||||
"sparse",
|
||||
]
|
||||
|
||||
num_embeddings: int
|
||||
embedding_dim: int
|
||||
padding_idx: int
|
||||
scale_grad_by_freq: bool
|
||||
weight: Tensor
|
||||
sparse: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: Optional[int] = None,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
initial_speed: float = 1.0,
|
||||
) -> None:
|
||||
super(ScaledEmbedding, self).__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if padding_idx is not None:
|
||||
if padding_idx > 0:
|
||||
assert (
|
||||
padding_idx < self.num_embeddings
|
||||
), "Padding_idx must be within num_embeddings"
|
||||
elif padding_idx < 0:
|
||||
assert (
|
||||
padding_idx >= -self.num_embeddings
|
||||
), "Padding_idx must be within num_embeddings"
|
||||
padding_idx = self.num_embeddings + padding_idx
|
||||
self.padding_idx = padding_idx
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
|
||||
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
|
||||
self.sparse = sparse
|
||||
|
||||
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
||||
self.reset_parameters(initial_speed)
|
||||
|
||||
def reset_parameters(self, initial_speed: float = 1.0) -> None:
|
||||
std = 0.1 / initial_speed
|
||||
nn.init.normal_(self.weight, std=std)
|
||||
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
|
||||
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
F = torch.nn.functional
|
||||
scale = self.scale.exp()
|
||||
if input.numel() < self.num_embeddings:
|
||||
return (
|
||||
F.embedding(
|
||||
input,
|
||||
self.weight,
|
||||
self.padding_idx,
|
||||
None,
|
||||
2.0, # None, 2.0 relate to normalization
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
return F.embedding(
|
||||
input,
|
||||
self.weight * scale,
|
||||
self.padding_idx,
|
||||
None,
|
||||
2.0, # None, 2.0 relates to normalization
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
|
||||
if self.padding_idx is not None:
|
||||
s += ", padding_idx={padding_idx}"
|
||||
if self.scale_grad_by_freq is not False:
|
||||
s += ", scale_grad_by_freq={scale_grad_by_freq}"
|
||||
if self.sparse is not False:
|
||||
s += ", sparse=True"
|
||||
return s.format(**self.__dict__)
|
||||
|
||||
|
||||
def _test_activation_balancer_sign():
|
||||
probs = torch.arange(0, 1, 0.01)
|
||||
N = 1000
|
||||
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
m = ActivationBalancer(
|
||||
channel_dim=0,
|
||||
min_positive=0.05,
|
||||
max_positive=0.95,
|
||||
max_factor=0.2,
|
||||
min_abs=0.0,
|
||||
)
|
||||
|
||||
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
||||
|
||||
y = m(x)
|
||||
y.backward(gradient=y_grad)
|
||||
print("_test_activation_balancer_sign: x = ", x)
|
||||
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
||||
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
||||
|
||||
|
||||
def _test_activation_balancer_magnitude():
|
||||
magnitudes = torch.arange(0, 1, 0.01)
|
||||
N = 1000
|
||||
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
||||
-1
|
||||
)
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
m = ActivationBalancer(
|
||||
channel_dim=0,
|
||||
min_positive=0.0,
|
||||
max_positive=1.0,
|
||||
max_factor=0.2,
|
||||
min_abs=0.2,
|
||||
max_abs=0.8,
|
||||
)
|
||||
|
||||
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
||||
|
||||
y = m(x)
|
||||
y.backward(gradient=y_grad)
|
||||
print("_test_activation_balancer_magnitude: x = ", x)
|
||||
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
||||
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
||||
|
||||
|
||||
def _test_basic_norm():
|
||||
num_channels = 128
|
||||
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
||||
|
||||
x = torch.randn(500, num_channels)
|
||||
|
||||
y = m(x)
|
||||
|
||||
assert y.shape == x.shape
|
||||
x_rms = (x ** 2).mean().sqrt()
|
||||
y_rms = (y ** 2).mean().sqrt()
|
||||
print("x rms = ", x_rms)
|
||||
print("y rms = ", y_rms)
|
||||
assert y_rms < x_rms
|
||||
assert y_rms > 0.5 * x_rms
|
||||
|
||||
|
||||
def _test_double_swish_deriv():
|
||||
x = torch.randn(10, 12, dtype=torch.double) * 0.5
|
||||
x.requires_grad = True
|
||||
m = DoubleSwish()
|
||||
torch.autograd.gradcheck(m, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_activation_balancer_sign()
|
||||
_test_activation_balancer_magnitude()
|
||||
_test_basic_norm()
|
||||
_test_double_swish_deriv()
|
||||
628
egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py
Normal file
628
egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py
Normal file
@ -0,0 +1,628 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from typing import Tuple, Optional
|
||||
|
||||
|
||||
|
||||
def _activation_balancer_loss(mean_pos: Tensor,
|
||||
mean_neg: Tensor,
|
||||
min_positive: float, # e.g. 0.05
|
||||
max_positive: float, # e.g. 0.95
|
||||
max_factor: float, # e.g. 0.01
|
||||
min_abs: float, # e.g. 0.2
|
||||
max_abs: float, # e.g. 100.0
|
||||
eps: float = 1.0e-10):
|
||||
"""
|
||||
Returns a loss-function for the ActivationBalancer module. This loss
|
||||
function is not exposed to the user but is used internally, and eventually
|
||||
its derivatives are scaled by some heuristic related to derivative magnitudes,
|
||||
and added to the backpropped deriv.
|
||||
|
||||
Args:
|
||||
mean_pos: a Tensor of arbitrary dimension, probably something like (1, num_channels, 1, 1),
|
||||
containing the mean of only the positive parts of the input features, i.e.
|
||||
of x.relu().
|
||||
mean_neg: a Tensor of arbitrary dimension, probably something like (1, num_channels, 1, 1),
|
||||
containing the mean of only the negative parts of the input features, i.e.
|
||||
of (-x).relu().
|
||||
min_positive: the minimum allowed value of mean_pos / (mean_pos + mean_neg) before we
|
||||
start penalizing.
|
||||
max_positive: the maximum allowed value of mean_pos / (mean_pos + mean_neg) before we
|
||||
start penalizing.
|
||||
"""
|
||||
loss_parts = []
|
||||
|
||||
x_mean = mean_positive - mean_negative
|
||||
x_mean_abs = (mean_positive + mean_negative + eps).detach()
|
||||
x_rel_mean= x_mean / x_mean_abs
|
||||
|
||||
if min_positive != 0.0:
|
||||
# e.g. x_mean_floor = -0.95 + 0.05 = -0.9
|
||||
x_rel_mean_floor = (-(1-min_positive) + min_positive)
|
||||
min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive))
|
||||
# this part of the loss would be 1.0 * num_channels if all these constraints were
|
||||
# 100% violated.
|
||||
loss_parts.append(min_positive_loss)
|
||||
|
||||
if max_positive != 1.0:
|
||||
# e.g. x_mean_floor = -0.05 + 0.95 = 0.8
|
||||
x_rel_mean_ceil = - (1.0-max_positive) + max_positive
|
||||
max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil))
|
||||
# this part of the loss would be 1.0 * num_channels if all these constraints were
|
||||
# 100% violated.
|
||||
loss_parts.append(max_positive_loss)
|
||||
|
||||
if min_abs != 0.0:
|
||||
min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs
|
||||
# this part of the loss would be 1.0 * num_channels if all these constraints were
|
||||
# 100% violated.
|
||||
loss_parts.append(min_abs_loss)
|
||||
|
||||
if max_abs != 0.0:
|
||||
max_abs_loss = (x_mean_abs / max_abs).log().relu()
|
||||
# this part of the loss would be [something logarithmic] * num_channels if all these constraints were
|
||||
# 100% violated.
|
||||
loss_parts.append(max_abs_loss)
|
||||
|
||||
|
||||
# the min_positive and 1 - max_positive are "ballast" added to the
|
||||
denom = mean_pos + mean_neg + (min_positive + (1 - max_positive))
|
||||
num
|
||||
|
||||
if min_positive != 0.0:
|
||||
|
||||
|
||||
|
||||
|
||||
class ActivationBalancerFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor,
|
||||
channel_dim: int,
|
||||
min_positive: float, # e.g. 0.05
|
||||
max_positive: float, # e.g. 0.95
|
||||
max_factor: float, # e.g. 0.01
|
||||
min_abs: float, # e.g. 0.2
|
||||
max_abs: float, # e.g. 100.0
|
||||
) -> Tensor:
|
||||
if x.requires_grad:
|
||||
if channel_dim < 0:
|
||||
channel_dim += x.ndim
|
||||
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
||||
xgt0 = x > 0
|
||||
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
|
||||
factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
|
||||
if min_positive != 0.0 else 0.0)
|
||||
factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
|
||||
if max_positive != 1.0 else 0.0)
|
||||
factor = factor1 + factor2
|
||||
if isinstance(factor, float):
|
||||
factor = torch.zeros_like(proportion_positive)
|
||||
|
||||
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
|
||||
below_threshold = (mean_abs < min_abs)
|
||||
above_threshold = (mean_abs > max_abs)
|
||||
|
||||
ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
|
||||
ctx.max_factor = max_factor
|
||||
ctx.sum_dims = sum_dims
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
|
||||
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
|
||||
dtype = x_grad.dtype
|
||||
scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
|
||||
(xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
|
||||
|
||||
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
|
||||
return x_grad - neg_delta_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
class BasicNorm(torch.nn.Module):
|
||||
"""
|
||||
This is intended to be a simpler, and hopefully cheaper, replacement for
|
||||
LayerNorm. The observation this is based on, is that Transformer-type
|
||||
networks, especially with pre-norm, sometimes seem to set one of the
|
||||
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
||||
the LayerNorm because the output magnitude is then not strongly dependent
|
||||
on the other (useful) features. Presumably the weight and bias of the
|
||||
LayerNorm are required to allow it to do this.
|
||||
|
||||
So the idea is to introduce this large constant value as an explicit
|
||||
parameter, that takes the role of the "eps" in LayerNorm, so the network
|
||||
doesn't have to do this trick. We make the "eps" learnable.
|
||||
|
||||
Args:
|
||||
num_channels: the number of channels, e.g. 512.
|
||||
channel_dim: the axis/dimension corresponding to the channel,
|
||||
interprted as an offset from the input's ndim if negative.
|
||||
shis is NOT the num_channels; it should typically be one of
|
||||
{-2, -1, 0, 1, 2, 3}.
|
||||
eps: the initial "epsilon" that we add as ballast in:
|
||||
scale = ((input_vec**2).mean() + epsilon)**-0.5
|
||||
Note: our epsilon is actually large, but we keep the name
|
||||
to indicate the connection with conventional LayerNorm.
|
||||
learn_eps: if true, we learn epsilon; if false, we keep it
|
||||
at the initial value.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_channels: int,
|
||||
channel_dim: int = -1, # CAUTION: see documentation.
|
||||
eps: float = 0.25,
|
||||
learn_eps: bool = True) -> None:
|
||||
super(BasicNorm, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
if learn_eps:
|
||||
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
||||
else:
|
||||
self.register_buffer('eps', torch.tensor(eps).log().detach())
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
|
||||
self.eps.exp()) ** -0.5
|
||||
return x * scales
|
||||
|
||||
|
||||
|
||||
|
||||
class ScaledLinear(nn.Linear):
|
||||
"""
|
||||
A modified version of nn.Linear where the parameters are scaled before
|
||||
use, via:
|
||||
weight = self.weight * self.weight_scale.exp()
|
||||
bias = self.bias * self.bias_scale.exp()
|
||||
|
||||
Args:
|
||||
Accepts the standard args and kwargs that nn.Linear accepts
|
||||
e.g. in_features, out_features, bias=False.
|
||||
|
||||
initial_scale: you can override this if you want to increase
|
||||
or decrease the initial magnitude of the module's output
|
||||
(affects the initialization of weight_scale and bias_scale).
|
||||
Another option, if you want to do something like this, is
|
||||
to re-initialize the parameters.
|
||||
|
||||
Note: it uses the default initialization for the weight and bias,
|
||||
inherited from nn.Linear. For modules with small fan-in, this
|
||||
may be larger than optimal.
|
||||
"""
|
||||
def __init__(self, *args,
|
||||
initial_scale: float = 1.0,
|
||||
**kwargs):
|
||||
super(ScaledLinear, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter('bias_scale', None)
|
||||
|
||||
self._reset_parameters() # Overrides the reset_parameters in nn.Linear
|
||||
|
||||
def _reset_parameters(self):
|
||||
std = 0.01
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
if self.bias is not None:
|
||||
self.bias_scale += torch.tensor(scale / std).log()
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return (None if self.bias is None else
|
||||
self.bias * self.bias_scale.exp())
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return torch.nn.functional.linear(input, self.get_weight(),
|
||||
self.get_bias())
|
||||
|
||||
|
||||
class ScaledConv1d(nn.Conv1d):
|
||||
def __init__(self, *args,
|
||||
initial_scale=1.0, **kwargs):
|
||||
super(ScaledConv1d, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter('bias_scale', None)
|
||||
self._reset_parameters() # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self):
|
||||
std = 0.01
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
if self.bias is not None:
|
||||
self.bias_scale += torch.tensor(scale / std).log()
|
||||
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return (None if self.bias is None else
|
||||
self.bias * self.bias_scale.exp())
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
F = torch.nn.functional
|
||||
if self.padding_mode != 'zeros':
|
||||
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
|
||||
self.get_weight(), self.get_bias(), self.stride,
|
||||
_single(0), self.dilation, self.groups)
|
||||
return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
|
||||
class ScaledConv2d(nn.Conv2d):
|
||||
def __init__(self, *args, initial_scale=1.0, **kwargs):
|
||||
super(ScaledConv2d, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter('bias_scale', None)
|
||||
self._reset_parameters() # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self):
|
||||
std = 0.01
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
if self.bias is not None:
|
||||
self.bias_scale += torch.tensor(scale / std).log()
|
||||
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return (None if self.bias is None else
|
||||
self.bias * self.bias_scale.exp())
|
||||
|
||||
def _conv_forward(self, input, weight):
|
||||
F = torch.nn.functional
|
||||
if self.padding_mode != 'zeros':
|
||||
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
|
||||
weight, self.get_bias(), self.stride,
|
||||
_pair(0), self.dilation, self.groups)
|
||||
return F.conv2d(input, weight, self.get_bias(), self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return self._conv_forward(input, self.get_weight())
|
||||
|
||||
|
||||
|
||||
|
||||
class ActivationBalancer(torch.nn.Module):
|
||||
"""
|
||||
Modifies the backpropped derivatives of a function to try to encourage, for
|
||||
each channel, that it is positive at least a proportion `threshold` of the
|
||||
time. It does this by multiplying negative derivative values by up to
|
||||
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
||||
interpolated from 1 at the threshold to those extremal values when none
|
||||
of the inputs are positive.
|
||||
|
||||
|
||||
Args:
|
||||
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
||||
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
||||
min_positive: the minimum, per channel, of the proportion of the time
|
||||
that (x > 0), below which we start to modify the derivatives.
|
||||
max_positive: the maximum, per channel, of the proportion of the time
|
||||
that (x > 0), below which we start to modify the derivatives.
|
||||
max_factor: the maximum factor by which we modify the derivatives for
|
||||
either the sign constraint or the magnitude constraint;
|
||||
e.g. with max_factor=0.02, the the derivatives would be multiplied by
|
||||
values in the range [0.98..1.02].
|
||||
min_abs: the minimum average-absolute-value per channel, which
|
||||
we allow, before we start to modify the derivatives to prevent
|
||||
this.
|
||||
max_abs: the maximum average-absolute-value per channel, which
|
||||
we allow, before we start to modify the derivatives to prevent
|
||||
this.
|
||||
"""
|
||||
def __init__(self, channel_dim: int,
|
||||
min_positive: float = 0.05,
|
||||
max_positive: float = 0.95,
|
||||
max_factor: float = 0.01,
|
||||
min_abs: float = 0.2,
|
||||
max_abs: float = 100.0):
|
||||
super(ActivationBalancer, self).__init__()
|
||||
self.channel_dim = channel_dim
|
||||
self.min_positive = min_positive
|
||||
self.max_positive = max_positive
|
||||
self.max_factor = max_factor
|
||||
self.min_abs = min_abs
|
||||
self.max_abs = max_abs
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return ActivationBalancerFunction.apply(x, self.channel_dim,
|
||||
self.min_positive, self.max_positive,
|
||||
self.max_factor, self.min_abs,
|
||||
self.max_abs)
|
||||
|
||||
|
||||
class DoubleSwishFunction(torch.autograd.Function):
|
||||
"""
|
||||
double_swish(x) = x * torch.sigmoid(x-1)
|
||||
This is a definition, originally motivated by its close numerical
|
||||
similarity to swish(swish(x), where swish(x) = x * sigmoid(x).
|
||||
|
||||
Memory-efficient derivative computation:
|
||||
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
||||
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
||||
Now, s'(x) = s(x) * (1-s(x)).
|
||||
double_swish'(x) = x * s'(x) + s(x).
|
||||
= x * s(x) * (1-s(x)) + s(x).
|
||||
= double_swish(x) * (1-s(x)) + s(x)
|
||||
... so we just need to remember s(x) but not x itself.
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor) -> Tensor:
|
||||
x = x.detach()
|
||||
s = torch.sigmoid(x - 1.0)
|
||||
y = x * s
|
||||
ctx.save_for_backward(s, y)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||
s, y = ctx.saved_tensors
|
||||
return (y * (1-s) + s) * y_grad
|
||||
|
||||
class DoubleSwish(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Module):
|
||||
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||
|
||||
This module is often used to store word embeddings and retrieve them using indices.
|
||||
The input to the module is a list of indices, and the output is the corresponding
|
||||
word embeddings.
|
||||
|
||||
Args:
|
||||
num_embeddings (int): size of the dictionary of embeddings
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
|
||||
(initialized to zeros) whenever it encounters the index.
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
|
||||
is renormalized to have norm :attr:`max_norm`.
|
||||
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
|
||||
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
|
||||
the words in the mini-batch. Default ``False``.
|
||||
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
|
||||
See Notes for more details regarding sparse gradients.
|
||||
|
||||
Attributes:
|
||||
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
|
||||
initialized from :math:`\mathcal{N}(0, 1)`
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
|
||||
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
|
||||
|
||||
.. note::
|
||||
Keep in mind that only a limited number of optimizers support
|
||||
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
|
||||
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
|
||||
|
||||
.. note::
|
||||
With :attr:`padding_idx` set, the embedding vector at
|
||||
:attr:`padding_idx` is initialized to all zeros. However, note that this
|
||||
vector can be modified afterwards, e.g., using a customized
|
||||
initialization method, and thus changing the vector used to pad the
|
||||
output. The gradient for this vector from :class:`~torch.nn.Embedding`
|
||||
is always zero.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # an Embedding module containing 10 tensors of size 3
|
||||
>>> embedding = nn.Embedding(10, 3)
|
||||
>>> # a batch of 2 samples of 4 indices each
|
||||
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
|
||||
>>> embedding(input)
|
||||
tensor([[[-0.0251, -1.6902, 0.7172],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 1.4970, 1.3448, -0.9685],
|
||||
[-0.3677, -2.7265, -0.1685]],
|
||||
|
||||
[[ 1.4970, 1.3448, -0.9685],
|
||||
[ 0.4362, -0.4004, 0.9400],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 0.9124, -2.3616, 1.1151]]])
|
||||
|
||||
|
||||
>>> # example with padding_idx
|
||||
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
|
||||
>>> input = torch.LongTensor([[0,2,0,5]])
|
||||
>>> embedding(input)
|
||||
tensor([[[ 0.0000, 0.0000, 0.0000],
|
||||
[ 0.1535, -2.0309, 0.9315],
|
||||
[ 0.0000, 0.0000, 0.0000],
|
||||
[-0.1655, 0.9897, 0.0635]]])
|
||||
"""
|
||||
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
|
||||
'scale_grad_by_freq', 'sparse']
|
||||
|
||||
num_embeddings: int
|
||||
embedding_dim: int
|
||||
padding_idx: int
|
||||
scale_grad_by_freq: bool
|
||||
weight: Tensor
|
||||
sparse: bool
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False) -> None:
|
||||
super(ScaledEmbedding, self).__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if padding_idx is not None:
|
||||
if padding_idx > 0:
|
||||
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
elif padding_idx < 0:
|
||||
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||
padding_idx = self.num_embeddings + padding_idx
|
||||
self.padding_idx = padding_idx
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
|
||||
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
|
||||
self.sparse = sparse
|
||||
|
||||
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
||||
self.reset_parameters()
|
||||
|
||||
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
std = 0.01
|
||||
nn.init.normal_(self.weight, std=std)
|
||||
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
|
||||
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
F = torch.nn.functional
|
||||
scale = self.scale.exp()
|
||||
if input.numel() < self.num_embeddings:
|
||||
return F.embedding(
|
||||
input, self.weight, self.padding_idx,
|
||||
None, 2.0, # None, 2.0 relate to normalization
|
||||
self.scale_grad_by_freq, self.sparse) * scale
|
||||
else:
|
||||
return F.embedding(
|
||||
input, self.weight * scale, self.padding_idx,
|
||||
None, 2.0, # None, 2.0 relates to normalization
|
||||
self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = '{num_embeddings}, {embedding_dim}, scale={scale}'
|
||||
if self.padding_idx is not None:
|
||||
s += ', padding_idx={padding_idx}'
|
||||
if self.scale_grad_by_freq is not False:
|
||||
s += ', scale_grad_by_freq={scale_grad_by_freq}'
|
||||
if self.sparse is not False:
|
||||
s += ', sparse=True'
|
||||
return s.format(**self.__dict__)
|
||||
|
||||
|
||||
def _test_activation_balancer_sign():
|
||||
channel_dim = 0
|
||||
probs = torch.arange(0, 1, 0.01)
|
||||
N = 1000
|
||||
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
|
||||
max_factor=0.2, min_abs=0.0)
|
||||
|
||||
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
||||
|
||||
y = m(x)
|
||||
y.backward(gradient=y_grad)
|
||||
print("_test_activation_balancer_sign: x = ", x)
|
||||
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
||||
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
||||
|
||||
def _test_activation_balancer_magnitude():
|
||||
channel_dim = 0
|
||||
magnitudes = torch.arange(0, 1, 0.01)
|
||||
N = 1000
|
||||
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
m = ActivationBalancer(channel_dim=0,
|
||||
min_positive=0.0, max_positive=1.0,
|
||||
max_factor=0.2,
|
||||
min_abs=0.2, max_abs=0.8)
|
||||
|
||||
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
||||
|
||||
y = m(x)
|
||||
y.backward(gradient=y_grad)
|
||||
print("_test_activation_balancer_magnitude: x = ", x)
|
||||
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
||||
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
||||
|
||||
|
||||
def _test_basic_norm():
|
||||
num_channels = 128
|
||||
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
||||
|
||||
x = torch.randn(500, num_channels)
|
||||
|
||||
y = m(x)
|
||||
|
||||
assert y.shape == x.shape
|
||||
x_rms = (x**2).mean().sqrt()
|
||||
y_rms = (y**2).mean().sqrt()
|
||||
print("x rms = ", x_rms)
|
||||
print("y rms = ", y_rms)
|
||||
assert y_rms < x_rms
|
||||
assert y_rms > 0.5 * x_rms
|
||||
|
||||
|
||||
def _test_double_swish_deriv():
|
||||
x = torch.randn(10, 12, dtype=torch.double) * 0.5
|
||||
x.requires_grad = True
|
||||
m = DoubleSwish()
|
||||
torch.autograd.gradcheck(m, x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test_activation_balancer_sign()
|
||||
_test_activation_balancer_magnitude()
|
||||
_test_basic_norm()
|
||||
_test_double_swish_deriv()
|
||||
997
egs/librispeech/ASR/pruned2_knowledge/train.py
Executable file
997
egs/librispeech/ASR/pruned2_knowledge/train.py
Executable file
@ -0,0 +1,997 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned2_knowledge/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned2_knowledge/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned2_knowledge/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--use_fp16 1 \
|
||||
--exp-dir pruned2_knowledge/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-batch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --start-epoch is ignored and
|
||||
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned2_knowledge/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need to be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-batches",
|
||||
type=float,
|
||||
default=5000,
|
||||
help="""Number of steps that affects how rapidly the learning rate decreases.
|
||||
We suggest not to change this.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-epochs",
|
||||
type=float,
|
||||
default=6,
|
||||
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prune-range",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||
"we are using to compute the loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="The scale to smooth the loss with lm "
|
||||
"(output of prediction network) part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simple-loss-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="To get pruning ranges, we will calculate a simple version"
|
||||
"loss(joiner is just addition), this simple loss also uses for"
|
||||
"training (as a regularization item). We will scale the simple loss"
|
||||
"with this parameter before adding to the final loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--print-diagnostics",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Accumulate stats on activations, print them and exit.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keep-last-k",
|
||||
type=int,
|
||||
default=20,
|
||||
help="""Only keep this number of checkpoints on disk.
|
||||
For instance, if it is 3, there are only 3 checkpoints
|
||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||
It does not affect checkpoints with name `epoch-xxx.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training related parameters that are not passed from the commandline
|
||||
are saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
||||
Explanation of options saved in `params`:
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_valid_loss: Best validation loss so far. It is used to select
|
||||
the model that has the lowest validation loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_train_epoch: It is the epoch that has the best training loss.
|
||||
|
||||
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||
|
||||
- batch_idx_train: Used to writing statistics to tensorboard. It
|
||||
contains number of batches trained so far across
|
||||
epochs.
|
||||
|
||||
- log_interval: Print training loss if batch_idx % log_interval` is 0
|
||||
|
||||
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||
|
||||
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- encoder_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||
|
||||
- warm_step: The warm_step for Noam optimizer.
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"encoder_dim": 512,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 18,
|
||||
# parameters for decoder
|
||||
"decoder_dim": 512,
|
||||
# parameters for joiner
|
||||
"joiner_dim": 512,
|
||||
# parameters for Noam
|
||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
d_model=params.encoder_dim,
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
decoder_dim=params.decoder_dim,
|
||||
blank_id=params.blank_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
encoder_dim=params.encoder_dim,
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
)
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
encoder_dim=params.encoder_dim,
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_batch is positive, it will load the checkpoint from
|
||||
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
||||
params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`.
|
||||
|
||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The scheduler that we are using.
|
||||
Returns:
|
||||
Return a dict containing previously saved training info.
|
||||
"""
|
||||
if params.start_batch > 0:
|
||||
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
||||
elif params.start_epoch > 0:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
else:
|
||||
return None
|
||||
|
||||
assert filename.is_file(), f"{filename} does not exist!"
|
||||
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
if params.start_batch > 0:
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer used in the training.
|
||||
sampler:
|
||||
The sampler for the training dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
warmup: float = 1.0,
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
"""
|
||||
Compute CTC loss given the model and its inputs.
|
||||
|
||||
Args:
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
model:
|
||||
The model for training. It is an instance of Conformer in our case.
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
is_training:
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = model.device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
y = sp.encode(texts, out_type=int)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
warmup=warmup,
|
||||
)
|
||||
# after the main warmup step, we keep pruned_loss_scale small
|
||||
# for the same amount of time (model_warm_step), to avoid
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
)
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=False,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(loss.device)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: LRSchedulerType,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler, we call step() every step.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
rank:
|
||||
The rank of the node in DDP training. If no DDP is used, it should
|
||||
be set to 0.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"batch {batch_idx}, loss[{loss_info}], "
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
f"lr: {cur_lr:.2e}"
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
if params.full_libri is False:
|
||||
params.valid_interval = 1600
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank])
|
||||
model.device = device
|
||||
|
||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||
|
||||
if checkpoints and "optimizer" in checkpoints:
|
||||
logging.info("Loading optimizer state dict")
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
if (
|
||||
checkpoints
|
||||
and "scheduler" in checkpoints
|
||||
and checkpoints["scheduler"] is not None
|
||||
):
|
||||
logging.info("Loading scheduler state dict")
|
||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
#
|
||||
# Caution: There is a reason to select 20.0 here. Please see
|
||||
# ../local/display_manifest_statistics.py
|
||||
#
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
return 1.0 <= c.duration <= 20.0
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||
# We only load the sampler's state dict when it loads a checkpoint
|
||||
# saved in the middle of an epoch
|
||||
sampler_state_dict = checkpoints["sampler"]
|
||||
else:
|
||||
sampler_state_dict = None
|
||||
|
||||
train_dl = librispeech.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
scheduler.step_epoch(epoch)
|
||||
fix_random_seed(params.seed + epoch)
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sp=sp,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
scaler=scaler,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.print_diagnostics:
|
||||
diagnostic.print_diagnostics()
|
||||
break
|
||||
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
logging.info(
|
||||
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
||||
)
|
||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
"Your GPU ran out of memory with the current "
|
||||
"max_duration setting. We recommend decreasing "
|
||||
"max_duration and trying again.\n"
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -537,7 +537,7 @@ def greedy_search(
|
||||
device = next(model.parameters()).device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
@ -646,7 +646,7 @@ def greedy_search_batch(
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
|
||||
|
||||
# timestamp[n][i] is the frame index after subsampling
|
||||
# on which hyp[n][i] is decoded
|
||||
|
||||
@ -1621,6 +1621,8 @@ class Conv2dSubsampling(nn.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
feature_dim = 50
|
||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||
batch_size = 5
|
||||
|
||||
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/asr_datamodule.py
|
||||
1
egs/librispeech/ASR/pruned_transducer_stateless7/beam_search.py
Symbolic link
1
egs/librispeech/ASR/pruned_transducer_stateless7/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
||||
854
egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
Executable file
854
egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
Executable file
@ -0,0 +1,854 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.decode_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device
|
||||
)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
104
egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py
Normal file
104
egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim//4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
||||
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/encoder_interface.py
|
||||
324
egs/librispeech/ASR/pruned_transducer_stateless7/export.py
Executable file
324
egs/librispeech/ASR/pruned_transducer_stateless7/export.py
Executable file
@ -0,0 +1,324 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless7/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.script()")
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
274
egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
Executable file
274
egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
Executable file
@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7/jit_pretrained.py \
|
||||
--nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model cpu_jit.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.nn_model_filename)
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
67
egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py
Normal file
67
egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
|
||||
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim
|
||||
assert encoder_out.ndim in (2, 4)
|
||||
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
||||
decoder_out
|
||||
)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
|
||||
195
egs/librispeech/ASR/pruned_transducer_stateless7/model.py
Normal file
195
egs/librispeech/ASR/pruned_transducer_stateless7/model.py
Normal file
@ -0,0 +1,195 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import random
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from scaling import penalize_abs_values_gt
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = nn.Linear(
|
||||
encoder_dim, vocab_size,
|
||||
)
|
||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
#if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
#if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
971
egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
Normal file
971
egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
Normal file
@ -0,0 +1,971 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Union, Tuple, List
|
||||
from lhotse.utils import fix_random_seed
|
||||
import torch
|
||||
from scaling import ActivationBalancer
|
||||
import random
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
import logging
|
||||
import contextlib
|
||||
|
||||
|
||||
|
||||
class BatchedOptimizer(Optimizer):
|
||||
"""
|
||||
This class adds to class Optimizer the capability to optimize parameters in batches:
|
||||
it will stack the parameters and their grads for you so the optimizer can work
|
||||
on tensors with an extra leading dimension. This is intended for speed with GPUs,
|
||||
as it reduces the number of kernels launched in the optimizer.
|
||||
|
||||
Args:
|
||||
params:
|
||||
"""
|
||||
def __init__(self, params, defaults):
|
||||
super(BatchedOptimizer, self).__init__(params, defaults)
|
||||
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def batched_params(self, param_group):
|
||||
"""
|
||||
This function returns (technically, yields) a list of
|
||||
of tuples (p, state), where
|
||||
p is a `fake` parameter that is stacked (over axis 0) from real parameters
|
||||
that share the same shape, and its gradient is also stacked;
|
||||
`state` is the state corresponding to this batch of parameters
|
||||
(it will be physically located in the "state" for one of the real
|
||||
parameters, the last one that has any particular shape and dtype).
|
||||
|
||||
This function is decorated as a context manager so that it can
|
||||
write parameters back to their "real" locations.
|
||||
|
||||
The idea is, instead of doing:
|
||||
<code>
|
||||
for p in group["params"]:
|
||||
state = self.state[p]
|
||||
...
|
||||
</code>
|
||||
you can do:
|
||||
<code>
|
||||
with self.batched_params(group["params"]) as batches:
|
||||
for p, state in batches:
|
||||
...
|
||||
</code>
|
||||
|
||||
Args:
|
||||
group: a parameter group, which is a list of parameters; should be
|
||||
one of self.groups.
|
||||
"""
|
||||
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
|
||||
for p in param_group:
|
||||
key = (str(p.dtype), *p.shape)
|
||||
batches[key].append(p)
|
||||
|
||||
stacked_params_dict = dict()
|
||||
|
||||
# turn batches into a list, in deterministic order.
|
||||
batches = [ batches[key] for key in sorted(batches.keys()) ]
|
||||
# pairs will contain pairs of (stacked_param, state), one for each batch
|
||||
# in `batches`.
|
||||
pairs = []
|
||||
|
||||
for batch in batches:
|
||||
p = batch[0]
|
||||
# we arbitrarily store the state in the
|
||||
# state corresponding to the 1st parameter in the
|
||||
# group. class Optimizer will take care of saving/loading state.
|
||||
state = self.state[p]
|
||||
p_stacked = torch.stack(batch)
|
||||
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
|
||||
p_stacked.grad = grad
|
||||
stacked_params_dict[key] = p_stacked
|
||||
pairs.append((p_stacked, state))
|
||||
|
||||
yield pairs # <-- calling code will do the actual optimization here!
|
||||
|
||||
for ((stacked_params, _state), batch) in zip(pairs, batches):
|
||||
for i, p in enumerate(batch): # batch is list of Parameter
|
||||
p.copy_(stacked_params[i])
|
||||
|
||||
|
||||
|
||||
class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
||||
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
||||
in log space, subject to upper and lower limits (as if we had factored each parameter as
|
||||
param = underlying_param * log_scale.exp())
|
||||
|
||||
|
||||
Args:
|
||||
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
||||
lr: The learning rate. We will typically use a learning rate schedule that starts
|
||||
at 0.03 and decreases over time, i.e. much higher than other common
|
||||
optimizers.
|
||||
clipping_scale: (e.g. 2.0)
|
||||
A scale for gradient-clipping: if specified, the normalized gradients
|
||||
over the whole model will be clipped to have 2-norm equal to
|
||||
`clipping_scale` times the median 2-norm over the most recent period
|
||||
of `clipping_update_period` minibatches. By "normalized gradients",
|
||||
we mean after multiplying by the rms parameter value for this tensor
|
||||
[for non-scalars]; this is appropriate because our update is scaled
|
||||
by this quantity.
|
||||
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
|
||||
Must satisfy 0 < beta <= beta2 < 1.
|
||||
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
|
||||
scale of each parameter tensor and scalar parameters of the mode..
|
||||
If each parameter were decomposed
|
||||
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
|
||||
would be a the scaling factor on the learning rate of p_scale.
|
||||
eps: A general-purpose epsilon to prevent division by zero
|
||||
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
||||
parameter tensor to be >= this value)
|
||||
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
||||
parameter tensor to be <= this value)
|
||||
scalar_max: Maximum absolute value for scalar parameters (applicable if your
|
||||
model has any parameters with numel() == 1).
|
||||
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
||||
of the parameter tensor. This is provided to save a little time
|
||||
in the update.
|
||||
clipping_update_period: if clipping_scale is specified, this is the period
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=3e-02,
|
||||
clipping_scale=None,
|
||||
betas=(0.9, 0.98),
|
||||
scalar_lr_scale=0.1,
|
||||
eps=1.0e-08,
|
||||
param_min_rms=1.0e-05,
|
||||
param_max_rms=3.0,
|
||||
scalar_max=10.0,
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
):
|
||||
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
clipping_scale=clipping_scale,
|
||||
betas=betas,
|
||||
scalar_lr_scale=scalar_lr_scale,
|
||||
eps=eps,
|
||||
param_min_rms=param_min_rms,
|
||||
param_max_rms=param_max_rms,
|
||||
scalar_max=scalar_max,
|
||||
size_update_period=size_update_period,
|
||||
clipping_update_period=clipping_update_period,
|
||||
)
|
||||
|
||||
super(ScaledAdam, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(ScaledAdam, self).__setstate__(state)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
batch = True
|
||||
for group in self.param_groups:
|
||||
|
||||
with self.batched_params(group["params"]) as batches:
|
||||
|
||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||
# a stacking dim, it is not a real dim.
|
||||
|
||||
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
||||
clipping_scale = 1
|
||||
else:
|
||||
clipping_scale = self._get_clipping_scale(group, batches)
|
||||
|
||||
for p, state in batches:
|
||||
# Perform optimization step.
|
||||
# grad is not going to be None, we handled that when creating the batches.
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
self._init_state(group, p, state)
|
||||
|
||||
self._step_one_batch(group, p, state, clipping_scale)
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
def _init_state(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
||||
is actually the batch dimension, corresponding to batched-together
|
||||
parameters of a given shape.
|
||||
|
||||
|
||||
Args:
|
||||
group: Dict to look up configuration values.
|
||||
p: The parameter that we are initializing the state for
|
||||
state: Dict from string to whatever state we are initializing
|
||||
"""
|
||||
size_update_period = group["size_update_period"]
|
||||
|
||||
state["step"] = 0
|
||||
|
||||
kwargs = {'device':p.device, 'dtype':p.dtype}
|
||||
|
||||
# 'delta' implements conventional momentum. There are
|
||||
# several different kinds of update going on, so rather than
|
||||
# compute "exp_avg" like in Adam, we store and decay a
|
||||
# parameter-change "delta", which combines all forms of
|
||||
# update. this is equivalent to how it's done in Adam,
|
||||
# except for the first few steps.
|
||||
state["delta"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
numel = p.numel()
|
||||
|
||||
|
||||
if numel > 1:
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
|
||||
keepdim=True).sqrt()
|
||||
state["param_rms"] = param_rms
|
||||
|
||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
|
||||
**kwargs)
|
||||
|
||||
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
def _get_clipping_scale(self,
|
||||
group: dict,
|
||||
pairs: List[Tuple[Tensor, dict]]) -> float:
|
||||
"""
|
||||
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
||||
by this amount before applying the rest of the update.
|
||||
|
||||
Args:
|
||||
group: the parameter group, an item in self.param_groups
|
||||
pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
|
||||
(1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
|
||||
"""
|
||||
assert len(pairs) >= 1
|
||||
clipping_scale = group["clipping_scale"]
|
||||
(first_p, first_state) = pairs[0]
|
||||
step = first_state["step"]
|
||||
if clipping_scale is None or step == 0:
|
||||
# no clipping. return early on step == 0 because the other
|
||||
# parameters' state won't have been initialized yet.
|
||||
return 1.0
|
||||
clipping_update_period = group["clipping_update_period"]
|
||||
|
||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||
for (p, state) in pairs:
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
||||
else:
|
||||
tot_sumsq += ((grad * state["param_rms"])**2).sum()
|
||||
|
||||
tot_norm = tot_sumsq.sqrt()
|
||||
if not "model_norms" in first_state:
|
||||
first_state["model_norms"] = torch.zeros(clipping_update_period,
|
||||
device=p.device)
|
||||
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||
|
||||
if step % clipping_update_period == 0:
|
||||
# Print some stats.
|
||||
# We don't reach here if step == 0 because we would have returned
|
||||
# above.
|
||||
sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
|
||||
quartiles = []
|
||||
for n in range(0, 5):
|
||||
index = min(clipping_update_period - 1,
|
||||
(clipping_update_period // 4) * n)
|
||||
quartiles.append(sorted_norms[index].item())
|
||||
|
||||
median = quartiles[2]
|
||||
threshold = clipping_scale * median
|
||||
first_state["model_norm_threshold"] = threshold
|
||||
percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
|
||||
if "num_clipped" in first_state else 0.0)
|
||||
first_state["num_clipped"] = 0
|
||||
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
|
||||
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
|
||||
|
||||
if step < clipping_update_period:
|
||||
return 1.0 # We have not yet estimated a norm to clip to.
|
||||
else:
|
||||
try:
|
||||
model_norm_threshold = first_state["model_norm_threshold"]
|
||||
except:
|
||||
logging.info("Warning: model_norm_threshold not in state: possibly "
|
||||
"you changed config when restarting, adding clipping_scale option?")
|
||||
return 1.0
|
||||
ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||
if ans < 1.0:
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
||||
return ans
|
||||
|
||||
|
||||
def _step_one_batch(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict,
|
||||
clipping_scale: float):
|
||||
"""
|
||||
Do the step for one parameter, which is actually going to be a batch of
|
||||
`real` parameters, with dim 0 as the batch dim.
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
p: parameter to update (actually multiple parameters stacked together
|
||||
as a batch)
|
||||
state: state-dict for p, to look up the optimizer state
|
||||
"""
|
||||
lr = group["lr"]
|
||||
size_update_period = group["size_update_period"]
|
||||
beta1 = group["betas"][0]
|
||||
|
||||
grad = p.grad
|
||||
if clipping_scale != 1.0:
|
||||
grad = grad * clipping_scale
|
||||
step = state["step"]
|
||||
delta = state["delta"]
|
||||
|
||||
delta.mul_(beta1)
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
if numel > 1:
|
||||
# Update the size/scale of p, and set param_rms
|
||||
scale_grads = state["scale_grads"]
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||
dim=list(range(1, p.ndim)), keepdim=True)
|
||||
if step % size_update_period == size_update_period - 1:
|
||||
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
|
||||
keepdim=True).sqrt())
|
||||
if step > 0:
|
||||
# self._size_update() learns the overall scale on the
|
||||
# parameter, by shrinking or expanding it.
|
||||
self._size_update(group, scale_grads, p, state)
|
||||
|
||||
|
||||
if numel == 1:
|
||||
# For parameters with 1 element we just use regular Adam.
|
||||
# Updates delta.
|
||||
self._step_scalar(group, p, state)
|
||||
else:
|
||||
self._step(group, p, state)
|
||||
|
||||
state["step"] = step + 1
|
||||
|
||||
|
||||
def _size_update(self,
|
||||
group: dict,
|
||||
scale_grads: Tensor,
|
||||
p: Tensor,
|
||||
state: dict) -> None:
|
||||
"""
|
||||
Called only where p.numel() > 1, this updates the scale of the parameter.
|
||||
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
||||
gradient descent on underlying param and on scale, this function does the update
|
||||
on `scale`.
|
||||
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
||||
grads w.r.t. the scales.
|
||||
p: The parameter to update
|
||||
state: The state-dict of p
|
||||
"""
|
||||
|
||||
param_rms = state["param_rms"]
|
||||
beta1, beta2 = group["betas"]
|
||||
size_lr = group["lr"] * group["scalar_lr_scale"]
|
||||
param_min_rms = group["param_min_rms"]
|
||||
param_max_rms = group["param_max_rms"]
|
||||
eps = group["eps"]
|
||||
step = state["step"]
|
||||
batch_size = p.shape[0]
|
||||
|
||||
size_update_period = scale_grads.shape[0]
|
||||
# correct beta2 for the size update period: we will have
|
||||
# faster decay at this level.
|
||||
beta2_corr = beta2 ** size_update_period
|
||||
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
|
||||
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
bias_correction2 = 1 - beta2_corr ** size_step
|
||||
# we don't bother with bias_correction1; this will help prevent divergence
|
||||
# at the start of training.
|
||||
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
|
||||
|
||||
is_too_small = (param_rms < param_min_rms)
|
||||
is_too_large = (param_rms > param_max_rms)
|
||||
|
||||
# when the param gets too small, just don't shrink it any further.
|
||||
scale_step.masked_fill_(is_too_small, 0.0)
|
||||
# when it gets too large, stop it from getting any larger.
|
||||
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
|
||||
delta = state["delta"]
|
||||
# the factor of (1-beta1) relates to momentum.
|
||||
delta.add_(p * scale_step, alpha=(1-beta1))
|
||||
|
||||
|
||||
def _step(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
This function does the core update of self.step(), in the case where the members of
|
||||
the batch have more than 1 element.
|
||||
|
||||
Args:
|
||||
group: A dict which will be used to look up configuration values
|
||||
p: The parameter to be updated
|
||||
grad: The grad of p
|
||||
state: The state-dict corresponding to parameter p
|
||||
|
||||
This function modifies p.
|
||||
"""
|
||||
grad = p.grad
|
||||
lr = group["lr"]
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
param_min_rms = group["param_min_rms"]
|
||||
step = state["step"]
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=(1-beta2))
|
||||
|
||||
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
|
||||
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
||||
if bias_correction2 < 0.99:
|
||||
# note: not in-place.
|
||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||
|
||||
denom = exp_avg_sq.sqrt()
|
||||
denom += eps
|
||||
grad = grad / denom
|
||||
|
||||
alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms)
|
||||
|
||||
delta = state["delta"]
|
||||
delta.add_(grad * alpha)
|
||||
p.add_(delta)
|
||||
|
||||
|
||||
def _step_scalar(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
A simplified form of the core update for scalar tensors, where we cannot get a good
|
||||
estimate of the parameter rms.
|
||||
"""
|
||||
beta1, beta2 = group["betas"]
|
||||
scalar_max = group["scalar_max"]
|
||||
eps = group["eps"]
|
||||
lr = group["lr"] * group["scalar_lr_scale"]
|
||||
grad = p.grad
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=1-beta2)
|
||||
|
||||
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
||||
# slower update at the start will help stability anyway.
|
||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
||||
|
||||
delta = state["delta"]
|
||||
delta.add_(grad / denom, alpha=-lr*(1-beta1))
|
||||
p.clamp_(min=-scalar_max, max=scalar_max)
|
||||
p.add_(delta)
|
||||
|
||||
|
||||
|
||||
class LRScheduler(object):
|
||||
"""
|
||||
Base-class for learning rate schedulers where the learning-rate depends on both the
|
||||
batch and the epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer: Optimizer, verbose: bool = False):
|
||||
# Attach optimizer
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError(
|
||||
"{} is not an Optimizer".format(type(optimizer).__name__)
|
||||
)
|
||||
self.optimizer = optimizer
|
||||
self.verbose = verbose
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("base_lr", group["lr"])
|
||||
|
||||
self.base_lrs = [
|
||||
group["base_lr"] for group in optimizer.param_groups
|
||||
]
|
||||
|
||||
self.epoch = 0
|
||||
self.batch = 0
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the scheduler as a :class:`dict`.
|
||||
|
||||
It contains an entry for every variable in self.__dict__ which
|
||||
is not the optimizer.
|
||||
"""
|
||||
return {
|
||||
"base_lrs": self.base_lrs,
|
||||
"epoch": self.epoch,
|
||||
"batch": self.batch,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the schedulers state.
|
||||
|
||||
Args:
|
||||
state_dict (dict): scheduler state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_last_lr(self) -> List[float]:
|
||||
"""Return last computed learning rate by current scheduler. Will be a list of float."""
|
||||
return self._last_lr
|
||||
|
||||
def get_lr(self):
|
||||
# Compute list of learning rates from self.epoch and self.batch and
|
||||
# self.base_lrs; this must be overloaded by the user.
|
||||
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
|
||||
raise NotImplementedError
|
||||
|
||||
def step_batch(self, batch: Optional[int] = None) -> None:
|
||||
# Step the batch index, or just set it. If `batch` is specified, it
|
||||
# must be the batch index from the start of training, i.e. summed over
|
||||
# all epochs.
|
||||
# You can call this in any order; if you don't provide 'batch', it should
|
||||
# of course be called once per batch.
|
||||
if batch is not None:
|
||||
self.batch = batch
|
||||
else:
|
||||
self.batch = self.batch + 1
|
||||
self._set_lrs()
|
||||
|
||||
def step_epoch(self, epoch: Optional[int] = None):
|
||||
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
|
||||
# you should call this at the start of the epoch; if you don't provide the 'epoch'
|
||||
# arg, you should call it at the end of the epoch.
|
||||
if epoch is not None:
|
||||
self.epoch = epoch
|
||||
else:
|
||||
self.epoch = self.epoch + 1
|
||||
self._set_lrs()
|
||||
|
||||
def _set_lrs(self):
|
||||
values = self.get_lr()
|
||||
assert len(values) == len(self.optimizer.param_groups)
|
||||
|
||||
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
|
||||
param_group, lr = data
|
||||
param_group["lr"] = lr
|
||||
self.print_lr(self.verbose, i, lr)
|
||||
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
||||
|
||||
def print_lr(self, is_verbose, group, lr):
|
||||
"""Display the current learning rate."""
|
||||
if is_verbose:
|
||||
logging.info(
|
||||
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
||||
f" of group {group} to {lr:.4e}."
|
||||
)
|
||||
|
||||
|
||||
class Eden(LRScheduler):
|
||||
"""
|
||||
Eden scheduler.
|
||||
The basic formula (before warmup) is:
|
||||
lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
|
||||
(((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
|
||||
where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
|
||||
and then stays constant at 1.
|
||||
|
||||
|
||||
E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
|
||||
|
||||
Args:
|
||||
optimizer: the optimizer to change the learning rates on
|
||||
lr_batches: the number of batches after which we start significantly
|
||||
decreasing the learning rate, suggest 5000.
|
||||
lr_epochs: the number of epochs after which we start significantly
|
||||
decreasing the learning rate, suggest 6 if you plan to do e.g.
|
||||
20 to 40 epochs, but may need smaller number if dataset is huge
|
||||
and you will do few epochs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
lr_batches: Union[int, float],
|
||||
lr_epochs: Union[int, float],
|
||||
warmup_batches: Union[int, float] = 500.0,
|
||||
verbose: bool = False,
|
||||
):
|
||||
super(Eden, self).__init__(optimizer, verbose)
|
||||
self.lr_batches = lr_batches
|
||||
self.lr_epochs = lr_epochs
|
||||
self.warmup_batches = warmup_batches
|
||||
|
||||
def get_lr(self):
|
||||
factor = (
|
||||
(self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
|
||||
) ** -0.25 * (
|
||||
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
|
||||
** -0.25
|
||||
)
|
||||
warmup_factor = (1.0 if self.batch >= self.warmup_batches
|
||||
else 0.5 + 0.5 * (self.batch / self.warmup_batches))
|
||||
|
||||
return [x * factor * warmup_factor for x in self.base_lrs]
|
||||
|
||||
|
||||
def _test_eden():
|
||||
m = torch.nn.Linear(100, 100)
|
||||
optim = ScaledAdam(m.parameters(), lr=0.03)
|
||||
|
||||
scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
|
||||
|
||||
for epoch in range(10):
|
||||
scheduler.step_epoch(epoch) # sets epoch to `epoch`
|
||||
|
||||
for step in range(20):
|
||||
x = torch.randn(200, 100).detach()
|
||||
x.requires_grad = True
|
||||
y = m(x)
|
||||
dy = torch.randn(200, 100).detach()
|
||||
f = (y * dy).sum()
|
||||
f.backward()
|
||||
|
||||
optim.step()
|
||||
scheduler.step_batch()
|
||||
optim.zero_grad()
|
||||
|
||||
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
||||
logging.info(f"state dict = {scheduler.state_dict()}")
|
||||
|
||||
|
||||
# This is included mostly as a baseline for ScaledAdam.
|
||||
class Eve(Optimizer):
|
||||
"""
|
||||
Implements Eve algorithm. This is a modified version of AdamW with a special
|
||||
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
||||
rms of the parameters approach a particular target_rms (default: 0.1). This is
|
||||
for use with networks with 'scaled' versions of modules (see scaling.py), which
|
||||
will be close to invariant to the absolute scale on the parameter matrix.
|
||||
|
||||
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
||||
Eve is unpublished so far.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
|
||||
this value means that the weight would decay significantly after
|
||||
about 3k minibatches. Is not multiplied by learning rate, but
|
||||
is conditional on RMS-value of parameter being > target_rms.
|
||||
target_rms (float, optional): target root-mean-square value of
|
||||
parameters, if they fall below this we will stop applying weight decay.
|
||||
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _Decoupled Weight Decay Regularization:
|
||||
https://arxiv.org/abs/1711.05101
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-3,
|
||||
target_rms=0.1,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 0: {}".format(betas[0])
|
||||
)
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 1: {}".format(betas[1])
|
||||
)
|
||||
if not 0 <= weight_decay <= 0.1:
|
||||
raise ValueError(
|
||||
"Invalid weight_decay value: {}".format(weight_decay)
|
||||
)
|
||||
if not 0 < target_rms <= 10.0:
|
||||
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
target_rms=target_rms,
|
||||
)
|
||||
super(Eve, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(Eve, self).__setstate__(state)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
# Perform optimization step
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"AdamW does not support sparse gradients"
|
||||
)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state["step"] += 1
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
|
||||
group["eps"]
|
||||
)
|
||||
|
||||
step_size = group["lr"] / bias_correction1
|
||||
target_rms = group["target_rms"]
|
||||
weight_decay = group["weight_decay"]
|
||||
|
||||
if p.numel() > 1:
|
||||
# avoid applying this weight-decay on "scaling factors"
|
||||
# (which are scalar).
|
||||
is_above_target_rms = p.norm() > (
|
||||
target_rms * (p.numel() ** 0.5)
|
||||
)
|
||||
p.mul_(1 - (weight_decay * is_above_target_rms))
|
||||
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
if random.random() < 0.0005:
|
||||
step = (exp_avg/denom) * step_size
|
||||
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def _test_scaled_adam(hidden_dim: int):
|
||||
import timeit
|
||||
from scaling import ScaledLinear
|
||||
E = 100
|
||||
B = 4
|
||||
T = 2
|
||||
logging.info("in test_eve_cain")
|
||||
#device = torch.device('cuda')
|
||||
device = torch.device('cpu')
|
||||
dtype = torch.float32
|
||||
|
||||
fix_random_seed(42)
|
||||
# these input_magnitudes and output_magnitudes are to test that
|
||||
# Abel is working as we expect and is able to adjust scales of
|
||||
# different dims differently.
|
||||
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
|
||||
for iter in [1, 0]:
|
||||
fix_random_seed(42)
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
|
||||
m = torch.nn.Sequential(Linear(E, hidden_dim),
|
||||
torch.nn.PReLU(),
|
||||
Linear(hidden_dim, hidden_dim),
|
||||
torch.nn.PReLU(),
|
||||
Linear(hidden_dim, E),
|
||||
).to(device)
|
||||
|
||||
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
|
||||
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
|
||||
|
||||
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
|
||||
elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
|
||||
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
|
||||
|
||||
|
||||
start = timeit.default_timer()
|
||||
avg_loss = 0.0
|
||||
for epoch in range(180):
|
||||
scheduler.step_epoch()
|
||||
#if epoch == 100 and iter in [2,3]:
|
||||
# optim.reset_speedup() # check it doesn't crash.
|
||||
|
||||
#if epoch == 130:
|
||||
# opts = diagnostics.TensorDiagnosticOptions(
|
||||
# 2 ** 22
|
||||
# ) # allow 4 megabytes per sub-module
|
||||
# diagnostic = diagnostics.attach_diagnostics(m, opts)
|
||||
|
||||
|
||||
for n, (x,y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
loss = ((y_out - y)**2).mean() * 100.0
|
||||
if epoch == 0 and n == 0:
|
||||
avg_loss = loss.item()
|
||||
else:
|
||||
avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
|
||||
if n == 0 and epoch % 5 == 0:
|
||||
#norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
||||
#norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
||||
#norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
||||
#norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
|
||||
#scale1 = '%.2e' % (m[0].weight_scale.exp().item())
|
||||
#scale1b = '%.2e' % (m[0].bias_scale.exp().item())
|
||||
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
||||
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
loss.log().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
scheduler.step_batch()
|
||||
|
||||
#diagnostic.print_diagnostics()
|
||||
|
||||
stop = timeit.default_timer()
|
||||
logging.info(f"Iter={iter}, Time taken: {stop - start}")
|
||||
|
||||
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
||||
#logging.info("state dict = ", scheduler.state_dict())
|
||||
#logging.info("optim state_dict = ", optim.state_dict())
|
||||
logging.info(f"input_magnitudes = {input_magnitudes}")
|
||||
logging.info(f"output_magnitudes = {output_magnitudes}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
import subprocess
|
||||
s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True)
|
||||
logging.info(s)
|
||||
import sys
|
||||
if len(sys.argv) > 1:
|
||||
hidden_dim = int(sys.argv[1])
|
||||
else:
|
||||
hidden_dim = 200
|
||||
|
||||
_test_scaled_adam(hidden_dim)
|
||||
_test_eden()
|
||||
363
egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
Executable file
363
egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
Executable file
@ -0,0 +1,363 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1161
egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
Normal file
1161
egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,118 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file replaces various modules in a model.
|
||||
Specifically, ActivationBalancer is replaced with an identity operator;
|
||||
Whiten is also replaced with an identity operator;
|
||||
BasicNorm is replaced by a module with `exp` removed.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
Whiten,
|
||||
)
|
||||
|
||||
|
||||
class NonScaledNorm(nn.Module):
|
||||
"""See BasicNorm for doc"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
eps_exp: float,
|
||||
channel_dim: int = -1, # CAUTION: see documentation.
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
self.eps_exp = eps_exp
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
scales = (
|
||||
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
|
||||
).pow(-0.5)
|
||||
return x * scales
|
||||
|
||||
|
||||
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
|
||||
assert isinstance(basic_norm, BasicNorm), type(BasicNorm)
|
||||
norm = NonScaledNorm(
|
||||
num_channels=basic_norm.num_channels,
|
||||
eps_exp=basic_norm.eps.data.exp().item(),
|
||||
channel_dim=basic_norm.channel_dim,
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||
# get_submodule was added to nn.Module at v1.9.0
|
||||
def get_submodule(model, target):
|
||||
if target == "":
|
||||
return model
|
||||
atoms: List[str] = target.split(".")
|
||||
mod: torch.nn.Module = model
|
||||
for item in atoms:
|
||||
if not hasattr(mod, item):
|
||||
raise AttributeError(
|
||||
mod._get_name() + " has no " "attribute `" + item + "`"
|
||||
)
|
||||
mod = getattr(mod, item)
|
||||
if not isinstance(mod, torch.nn.Module):
|
||||
raise AttributeError("`" + item + "` is not " "an nn.Module")
|
||||
return mod
|
||||
|
||||
|
||||
def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model:
|
||||
The model to be converted.
|
||||
inplace:
|
||||
If True, the input model is modified inplace.
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
d = {}
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, BasicNorm):
|
||||
d[name] = convert_basic_norm(m)
|
||||
elif isinstance(m, (ActivationBalancer, Whiten)):
|
||||
d[name] = nn.Identity()
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
parent, child = k.rsplit(".", maxsplit=1)
|
||||
setattr(get_submodule(model, parent), child, v)
|
||||
else:
|
||||
setattr(model, k, v)
|
||||
|
||||
return model
|
||||
56
egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
Executable file
56
egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
Executable file
@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless4/test_model.py
|
||||
"""
|
||||
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model_1():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
# params.feedforward_dims = "1024,1024,1536,1536,1024"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
def main():
|
||||
test_model_1()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1217
egs/librispeech/ASR/pruned_transducer_stateless7/train.py
Executable file
1217
egs/librispeech/ASR/pruned_transducer_stateless7/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1858
egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
Normal file
1858
egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -86,7 +86,7 @@ def save_checkpoint(
|
||||
}
|
||||
|
||||
if model_avg is not None:
|
||||
checkpoint["model_avg"] = model_avg.state_dict()
|
||||
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
|
||||
|
||||
if params:
|
||||
for k, v in params.items():
|
||||
@ -466,8 +466,10 @@ def average_state_dict(
|
||||
|
||||
uniqued_names = list(uniqued.values())
|
||||
for k in uniqued_names:
|
||||
state_dict_1[k] *= weight_1
|
||||
state_dict_1[k] += (
|
||||
v = state_dict_1[k]
|
||||
if torch.is_floating_point(v):
|
||||
v *= weight_1
|
||||
v += (
|
||||
state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
|
||||
)
|
||||
state_dict_1[k] *= scaling_factor
|
||||
v *= scaling_factor
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@ -82,10 +82,17 @@ def get_tensor_stats(
|
||||
elif stats_type == "positive":
|
||||
x = (x > 0).to(dtype=torch.float)
|
||||
else:
|
||||
assert stats_type == "value"
|
||||
assert stats_type in [ "value", "max", "min" ]
|
||||
|
||||
sum_dims = [d for d in range(x.ndim) if d != dim]
|
||||
if len(sum_dims) > 0:
|
||||
if stats_type == "max":
|
||||
for dim in reversed(sum_dims):
|
||||
x = torch.max(x, dim=dim)[0]
|
||||
elif stats_type == "min":
|
||||
for dim in reversed(sum_dims):
|
||||
x = torch.min(x, dim=dim)[0]
|
||||
else:
|
||||
x = torch.sum(x, dim=sum_dims)
|
||||
x = x.flatten()
|
||||
return x, count
|
||||
@ -105,17 +112,19 @@ class TensorDiagnostic(object):
|
||||
opts:
|
||||
Options object.
|
||||
name:
|
||||
The tensor name.
|
||||
The name associated with this diagnostics object, will probably be {module_name}.X
|
||||
where X is "output" or "grad", or {parameter_name}.Y where Y is param_value or param_grad.
|
||||
"""
|
||||
|
||||
def __init__(self, opts: TensorDiagnosticOptions, name: str):
|
||||
self.name = name
|
||||
self.opts = opts
|
||||
self.name = name
|
||||
self.class_name = None # will assign in accumulate()
|
||||
|
||||
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
|
||||
|
||||
# the keys into self.stats[dim] are strings, whose values can be
|
||||
# "abs", "value", "positive", "rms", "value".
|
||||
# "abs", "max", "min" ,"value", "positive", "rms", "value".
|
||||
# The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
|
||||
# containing a tensor and its associated count (which is the sum of the other dims
|
||||
# that we aggregated over, e.g. the number of frames and/or batch elements and/or
|
||||
@ -124,8 +133,13 @@ class TensorDiagnostic(object):
|
||||
# only adding a new element to the list if there was a different dim.
|
||||
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
|
||||
|
||||
def accumulate(self, x):
|
||||
"""Accumulate tensors."""
|
||||
|
||||
def accumulate(self, x, class_name: Optional[str] = None):
|
||||
"""
|
||||
Accumulate tensors.
|
||||
"""
|
||||
if class_name is not None:
|
||||
self.class_name = class_name
|
||||
if isinstance(x, Tuple):
|
||||
x = x[0]
|
||||
if not isinstance(x, Tensor):
|
||||
@ -142,11 +156,11 @@ class TensorDiagnostic(object):
|
||||
for dim in range(ndim):
|
||||
this_dim_stats = self.stats[dim]
|
||||
if ndim > 1:
|
||||
stats_types = ["abs", "positive", "value", "rms"]
|
||||
stats_types = ["abs", "max", "min", "positive", "value", "rms"]
|
||||
if x.shape[dim] <= self.opts.max_eig_dim:
|
||||
stats_types.append("eigs")
|
||||
else:
|
||||
stats_types = ["value", "abs"]
|
||||
stats_types = ["value", "abs", "max", "min"]
|
||||
|
||||
for stats_type in stats_types:
|
||||
stats, count = get_tensor_stats(x, dim, stats_type)
|
||||
@ -161,6 +175,11 @@ class TensorDiagnostic(object):
|
||||
continue
|
||||
for s in this_dim_stats[stats_type]:
|
||||
if s.tensor.shape == stats.shape:
|
||||
if stats_type == "max":
|
||||
s.tensor = torch.maximum(s.tensor, stats)
|
||||
elif stats_type == "min":
|
||||
s.tensor = torch.minimum(s.tensor, stats)
|
||||
else:
|
||||
s.tensor += stats
|
||||
s.count += count
|
||||
done = True
|
||||
@ -186,14 +205,26 @@ class TensorDiagnostic(object):
|
||||
for dim, this_dim_stats in enumerate(self.stats):
|
||||
for stats_type, stats_list in this_dim_stats.items():
|
||||
# stats_type could be "rms", "value", "abs", "eigs", "positive".
|
||||
# "value" could be a list of TensorAndCount, or None
|
||||
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
|
||||
# shape of the stats), or None
|
||||
if stats_list is None:
|
||||
assert stats_type == "eigs"
|
||||
continue
|
||||
|
||||
|
||||
def get_count(count):
|
||||
return 1 if stats_type in ["max", "min"] else count
|
||||
|
||||
if len(stats_list) == 1:
|
||||
stats = stats_list[0].tensor / get_count(stats_list[0].count)
|
||||
else:
|
||||
# a dimension that has variable size in different nnet
|
||||
# forwards, e.g. a time dimension in an ASR model.
|
||||
stats = torch.cat(
|
||||
[x.tensor / get_count(x.count) for x in stats_list], dim=0
|
||||
)
|
||||
|
||||
if stats_type == "eigs":
|
||||
assert len(stats_list) == 1
|
||||
stats = stats_list[0].tensor / stats_list[0].count
|
||||
try:
|
||||
eigs, _ = torch.symeig(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
@ -201,15 +232,9 @@ class TensorDiagnostic(object):
|
||||
print(
|
||||
"Error getting eigenvalues, trying another method."
|
||||
)
|
||||
eigs = torch.linalg.eigvals(stats)
|
||||
eigs, _ = torch.eig(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
# sqrt so it reflects data magnitude, like stddev- not variance
|
||||
elif len(stats_list) == 1:
|
||||
stats = stats_list[0].tensor / stats_list[0].count
|
||||
else:
|
||||
stats = torch.cat(
|
||||
[x.tensor / x.count for x in stats_list], dim=0
|
||||
)
|
||||
|
||||
if stats_type == "rms":
|
||||
# we stored the square; after aggregation we need to take sqrt.
|
||||
@ -236,7 +261,7 @@ class TensorDiagnostic(object):
|
||||
ans = stats.tolist()
|
||||
ans = ["%.2g" % x for x in ans]
|
||||
ans = "[" + " ".join(ans) + "]"
|
||||
if stats_type == "value":
|
||||
if stats_type in [ "value", "rms", "eigs" ]:
|
||||
# This norm is useful because it is strictly less than the largest
|
||||
# sqrt(eigenvalue) of the variance, which we print out, and shows,
|
||||
# speaking in an approximate way, how much of that largest eigenvalue
|
||||
@ -245,7 +270,7 @@ class TensorDiagnostic(object):
|
||||
ans += f", norm={norm:.2g}"
|
||||
mean = stats.mean().item()
|
||||
rms = (stats ** 2).mean().sqrt().item()
|
||||
ans += f", mean={mean:.2g}, rms={rms:.2g}"
|
||||
ans += f", mean={mean:.3g}, rms={rms:.3g}"
|
||||
|
||||
# OK, "ans" contains the actual stats, e.g.
|
||||
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
|
||||
@ -256,11 +281,13 @@ class TensorDiagnostic(object):
|
||||
if len(sizes) == 1
|
||||
else f"{min(sizes)}..{max(sizes)}"
|
||||
)
|
||||
maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
|
||||
print(
|
||||
f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
|
||||
f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ModelDiagnostic(object):
|
||||
"""This class stores diagnostics for all tensors in the torch.nn.Module.
|
||||
|
||||
@ -321,20 +348,29 @@ def attach_diagnostics(
|
||||
def forward_hook(
|
||||
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||
):
|
||||
if isinstance(_output, tuple) and len(_output) == 1:
|
||||
_output = _output[0]
|
||||
|
||||
if isinstance(_output, Tensor):
|
||||
_model_diagnostic[f"{_name}.output"].accumulate(_output)
|
||||
_model_diagnostic[f"{_name}.output"].accumulate(_output,
|
||||
class_name=type(_module).__name__)
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o)
|
||||
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
|
||||
class_name=type(_module).__name__)
|
||||
|
||||
def backward_hook(
|
||||
_module, _input, _output, _model_diagnostic=ans, _name=name
|
||||
):
|
||||
if isinstance(_output, tuple) and len(_output) == 1:
|
||||
_output = _output[0]
|
||||
if isinstance(_output, Tensor):
|
||||
_model_diagnostic[f"{_name}.grad"].accumulate(_output)
|
||||
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
|
||||
class_name=type(_module).__name__)
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o)
|
||||
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
|
||||
class_name=type(_module).__name__)
|
||||
|
||||
module.register_forward_hook(forward_hook)
|
||||
module.register_backward_hook(backward_hook)
|
||||
|
||||
102
icefall/hooks.py
Normal file
102
icefall/hooks.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright 2021-2022 Xiaomi Corporation (authors: Zengwei Yao, Daniel Povey)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import logging
|
||||
|
||||
|
||||
def register_inf_check_hooks(model: nn.Module) -> None:
|
||||
"""Registering forward hook on each module, to check
|
||||
whether its output tensors is not finite.
|
||||
|
||||
Args:
|
||||
model:
|
||||
the model to be analyzed.
|
||||
"""
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if name == "":
|
||||
name = "<top-level>"
|
||||
|
||||
# default param _name is a way to capture the current value of the variable "name".
|
||||
def forward_hook(_module, _input, _output, _name=name):
|
||||
if isinstance(_output, Tensor):
|
||||
if not torch.isfinite(_output.to(torch.float32).sum()):
|
||||
raise ValueError(
|
||||
f"The sum of {_name}.output is not finite: {_output}"
|
||||
)
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
if not isinstance(o, Tensor):
|
||||
continue
|
||||
if not torch.isfinite(o.to(torch.float32).sum()):
|
||||
raise ValueError(
|
||||
f"The sum of {_name}.output[{i}] is not finite: {_output}"
|
||||
)
|
||||
|
||||
# default param _name is a way to capture the current value of the variable "name".
|
||||
def backward_hook(_module, _input, _output, _name=name):
|
||||
if isinstance(_output, Tensor):
|
||||
if not torch.isfinite(_output.to(torch.float32).sum()):
|
||||
logging.warning(
|
||||
f"The sum of {_name}.grad is not finite" # ": {_output}"
|
||||
)
|
||||
elif isinstance(_output, tuple):
|
||||
for i, o in enumerate(_output):
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
if not isinstance(o, Tensor):
|
||||
continue
|
||||
if not torch.isfinite(o.to(torch.float32).sum()):
|
||||
logging.warning(
|
||||
f"The sum of {_name}.grad[{i}] is not finite"
|
||||
)
|
||||
|
||||
module.register_forward_hook(forward_hook)
|
||||
module.register_backward_hook(backward_hook)
|
||||
|
||||
|
||||
for name, parameter in model.named_parameters():
|
||||
|
||||
def param_backward_hook(
|
||||
grad, _name=name
|
||||
):
|
||||
if not torch.isfinite(grad.to(torch.float32).sum()):
|
||||
logging.warning(
|
||||
f"The sum of {_name}.param_grad is not finite"
|
||||
)
|
||||
|
||||
parameter.register_hook(param_backward_hook)
|
||||
|
||||
|
||||
|
||||
def _test_inf_check_hooks():
|
||||
model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
|
||||
|
||||
register_inf_check_hooks(model)
|
||||
for _ in range(10):
|
||||
T = random.randint(200, 300)
|
||||
x = torch.randn(T, 100) + float("inf") * (T % 2)
|
||||
y = model(x)
|
||||
y.sum().backward()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_inf_check_hooks()
|
||||
Loading…
x
Reference in New Issue
Block a user