Merge remote-tracking branch 'k2-fsa/master'

This commit is contained in:
yaozengwei 2022-05-12 17:45:45 +08:00
commit bcef517a84
23 changed files with 496 additions and 120 deletions

View File

@ -0,0 +1,17 @@
#!/usr/bin/env bash
# This script computes fbank features for the test-clean and test-other datasets.
# The computed features are saved to ~/tmp/fbank-libri and are
# cached for later runs
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
mkdir ~/tmp/fbank-libri
cd egs/librispeech/ASR
mkdir -p data
cd data
[ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank
cd ..
./local/compute_fbank_librispeech.py
ls -lh data/fbank/

View File

@ -0,0 +1,23 @@
#!/usr/bin/env bash
# This script downloads the test-clean and test-other datasets
# of LibriSpeech and unzip them to the folder ~/tmp/download,
# which is cached by GitHub actions for later runs.
#
# You will find directories ~/tmp/download/LibriSpeech after running
# this script.
mkdir ~/tmp/download
cd egs/librispeech/ASR
ln -s ~/tmp/download .
cd download
wget -q --no-check-certificate https://www.openslr.org/resources/12/test-clean.tar.gz
tar xf test-clean.tar.gz
rm test-clean.tar.gz
wget -q --no-check-certificate https://www.openslr.org/resources/12/test-other.tar.gz
tar xf test-other.tar.gz
rm test-other.tar.gz
pwd
ls -lh
ls -lh LibriSpeech

13
.github/scripts/install-kaldifeat.sh vendored Executable file
View File

@ -0,0 +1,13 @@
#!/usr/bin/env bash
# This script installs kaldifeat into the directory ~/tmp/kaldifeat
# which is cached by GitHub actions for later runs.
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat

View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
# This script assumes that test-clean and test-other are downloaded
# to egs/librispeech/ASR/download/LibriSpeech and generates manifest
# files in egs/librispeech/ASR/data/manifests
cd egs/librispeech/ASR
[ ! -e download ] && ln -s ~/tmp/download .
mkdir -p data/manifests
lhotse prepare librispeech -j 2 -p test-clean -p test-other ./download/LibriSpeech data/manifests
ls -lh data/manifests

View File

@ -45,3 +45,31 @@ for method in modified_beam_search beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
mkdir -p pruned_transducer_stateless/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=50
for method in greedy_search fast_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless/exp
done
rm pruned_transducer_stateless/exp/*.pt
fi

View File

@ -49,3 +49,31 @@ for method in modified_beam_search beam_search fast_beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
mkdir -p pruned_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=50
for method in greedy_search fast_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless2/exp
done
rm pruned_transducer_stateless2/exp/*.pt
fi

View File

@ -49,3 +49,31 @@ for method in modified_beam_search beam_search fast_beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
mkdir -p pruned_transducer_stateless3/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless3/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=50
for method in greedy_search fast_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless3/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless3/exp
done
rm pruned_transducer_stateless3/exp/*.pt
fi

View File

@ -45,3 +45,31 @@ for method in modified_beam_search beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
mkdir -p transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=50
for method in greedy_search modified_beam_search; do
log "Decoding with $method"
./transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir transducer_stateless2/exp
done
rm transducer_stateless2/exp/*.pt
fi

View File

@ -24,9 +24,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_2022_03_12:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -63,20 +72,78 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
- name: Display decoding results
if: github.event_name == 'schedule'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./pruned_transducer_stateless/exp
cd pruned_transducer_stateless
echo "results for pruned_transducer_stateless"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
path: egs/librispeech/ASR/pruned_transducer_stateless/exp/

View File

@ -24,9 +24,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_2022_04_29:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -63,18 +72,50 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
@ -83,3 +124,45 @@ jobs:
.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
- name: Display decoding results
if: github.event_name == 'schedule'
shell: bash
run: |
cd egs/librispeech/ASR
tree pruned_transducer_stateless2/exp
cd pruned_transducer_stateless2
echo "results for pruned_transducer_stateless2"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
cd ../
tree pruned_transducer_stateless3/exp
cd pruned_transducer_stateless3
echo "results for pruned_transducer_stateless3"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
- name: Upload decoding results for pruned_transducer_stateless3
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/

View File

@ -24,9 +24,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_2022_04_19:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -63,20 +72,77 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
- name: Display decoding results
if: github.event_name == 'schedule'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./transducer_stateless2/exp
cd transducer_stateless2
echo "results for transducer_stateless2"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified_beam_search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
path: egs/librispeech/ASR/transducer_stateless2/exp/

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -62,13 +62,6 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Inference with pre-trained model

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -22,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
@ -33,20 +34,19 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
cd /path/to/egs/aishell/ASR
./transducer_stateless/decode.py \
--exp-dir ./transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
@ -56,6 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
@ -91,10 +92,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
@ -194,12 +195,10 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
lexicon = Lexicon(params.lang_dir)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
@ -565,8 +566,10 @@ def modified_beam_search(
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
@ -679,8 +682,10 @@ def _deprecated_modified_beam_search(
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -540,23 +540,52 @@ def main():
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.iter == 0 and params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
)
model.to(device)
model.eval()

View File

@ -1,5 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -405,7 +405,7 @@ def average_checkpoints_with_averaged_model(
(3) avg = (model_end + model_start * (weight_start / weight_end))
* weight_end
The model index could be epoch number or checkpoint number.
The model index could be epoch number or iteration number.
Args:
filename_start:

View File

@ -95,7 +95,7 @@ def get_env_info() -> Dict[str, Any]:
"k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__,
"torch-version": torch.__version__,
"torch-version": str(torch.__version__),
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],