mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Fix shallow fusion and add CI tests for it (#676)
* Fix shallow fusion and add CI tests for it
* Fix -1 index in embedding introduced in the zipformer PR
This commit is contained in:
parent
7e82f87126
commit
cedf9aa24f
@ -174,6 +174,36 @@ done
|
||||
|
||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||
|
||||
if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
|
||||
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
|
||||
git clone $lm_repo_url
|
||||
lm_repo=$(basename $lm_repo_url)
|
||||
pushd $lm_repo
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-88.pt
|
||||
popd
|
||||
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--lang-dir $repo/data/lang_bpe_500 \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
|
||||
--beam 4 \
|
||||
--rnn-lm-scale 0.3 \
|
||||
--rnn-lm-exp-dir $lm_repo/exp \
|
||||
--rnn-lm-epoch 88 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
fi
|
||||
|
||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
||||
mkdir -p lstm_transducer_stateless2/exp
|
||||
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
|
||||
|
@ -18,7 +18,7 @@ on:
|
||||
|
||||
jobs:
|
||||
run_librispeech_lstm_transducer_stateless2_2022_09_03:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
@ -128,9 +128,20 @@ jobs:
|
||||
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
- name: Display decoding results for lstm_transducer_stateless2
|
||||
if: github.event.label.name == 'shallow-fusion'
|
||||
shell: bash
|
||||
run: |
|
||||
cd egs/librispeech/ASR
|
||||
tree lstm_transducer_stateless2/exp
|
||||
cd lstm_transducer_stateless2/exp
|
||||
echo "===modified_beam_search_rnnlm_shallow_fusion==="
|
||||
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
- name: Upload decoding results for lstm_transducer_stateless2
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule'
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
|
||||
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
|
||||
|
@ -2083,7 +2083,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
||||
log_prob=hyp_log_prob,
|
||||
state=state,
|
||||
lm_score=lm_score,
|
||||
timestampe=new_timestamp,
|
||||
timestamp=new_timestamp,
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
|
@ -101,7 +101,15 @@ class Decoder(nn.Module):
|
||||
need_pad = bool(need_pad)
|
||||
|
||||
y = y.to(torch.int64)
|
||||
embedding_out = self.embedding(y)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
if torch.jit.is_tracing():
|
||||
# This is for exporting to PNNX via ONNX
|
||||
embedding_out = self.embedding(y)
|
||||
else:
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(
|
||||
-1
|
||||
)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad:
|
||||
|
Loading…
x
Reference in New Issue
Block a user