Fix decode.py to remove the correct axis. (#50)

* Fix decode.py to remove the correct axis.

* Run GitHub actions manually.
This commit is contained in:
Fangjun Kuang 2021-09-17 16:49:03 +08:00 committed by GitHub
parent 9a6e0489c8
commit cc77cb3459
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 13 deletions

View File

@ -21,11 +21,11 @@ on:
branches: branches:
- master - master
pull_request: pull_request:
branches: types: [labeled]
- master
jobs: jobs:
run-yesno-recipe: run-yesno-recipe:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:
@ -33,6 +33,8 @@ jobs:
# TODO: enable macOS for CPU testing # TODO: enable macOS for CPU testing
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.8] python-version: [3.8]
torch: ["1.8.1"]
k2-version: ["1.8.dev20210917"]
fail-fast: false fail-fast: false
steps: steps:
@ -54,10 +56,8 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip python3 -m pip install -U pip
python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install torchaudio==0.7.2
python3 -m pip install git+https://github.com/lhotse-speech/lhotse python3 -m pip install git+https://github.com/lhotse-speech/lhotse
# We are in ./icefall and there is a file: requirements.txt in it # We are in ./icefall and there is a file: requirements.txt in it

View File

@ -21,18 +21,18 @@ on:
branches: branches:
- master - master
pull_request: pull_request:
branches: types: [labeled]
- master
jobs: jobs:
test: test:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:
os: [ubuntu-18.04, macos-10.15] os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9] python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"] torch: ["1.8.1"]
k2-version: ["1.7.dev20210914"] k2-version: ["1.8.dev20210917"]
fail-fast: false fail-fast: false

View File

@ -222,7 +222,7 @@ def nbest_decoding(
word_seq = k2.ragged.index(lattice.aux_labels, path) word_seq = k2.ragged.index(lattice.aux_labels, path)
else: else:
word_seq = lattice.aux_labels.index(path) word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1) word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove 0 (epsilon) and -1 from word_seq # Remove 0 (epsilon) and -1 from word_seq
word_seq = word_seq.remove_values_leq(0) word_seq = word_seq.remove_values_leq(0)
@ -432,7 +432,7 @@ def rescore_with_n_best_list(
word_seq = k2.ragged.index(lattice.aux_labels, path) word_seq = k2.ragged.index(lattice.aux_labels, path)
else: else:
word_seq = lattice.aux_labels.index(path) word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1) word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove epsilons and -1 from word_seq # Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0) word_seq = word_seq.remove_values_leq(0)
@ -669,7 +669,7 @@ def nbest_oracle(
word_seq = k2.ragged.index(lattice.aux_labels, path) word_seq = k2.ragged.index(lattice.aux_labels, path)
else: else:
word_seq = lattice.aux_labels.index(path) word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1) word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
word_seq = word_seq.remove_values_leq(0) word_seq = word_seq.remove_values_leq(0)
unique_word_seq, _, _ = word_seq.unique( unique_word_seq, _, _ = word_seq.unique(
@ -761,7 +761,7 @@ def rescore_with_attention_decoder(
word_seq = k2.ragged.index(lattice.aux_labels, path) word_seq = k2.ragged.index(lattice.aux_labels, path)
else: else:
word_seq = lattice.aux_labels.index(path) word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1) word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove epsilons and -1 from word_seq # Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0) word_seq = word_seq.remove_values_leq(0)
@ -815,7 +815,7 @@ def rescore_with_attention_decoder(
token_seq = k2.ragged.index(lattice.tokens, path) token_seq = k2.ragged.index(lattice.tokens, path)
else: else:
token_seq = lattice.tokens.index(path) token_seq = lattice.tokens.index(path)
token_seq = token_seq.remove_axis(1) token_seq = token_seq.remove_axis(token_seq.num_axes - 2)
# Remove epsilons and -1 from token_seq # Remove epsilons and -1 from token_seq
token_seq = token_seq.remove_values_leq(0) token_seq = token_seq.remove_values_leq(0)