Merge branch 'k2-fsa:master' into repeat-k

Yifan Yang, 2023-02-05 15:29:30 +08:00, committed by GitHub
commit 355ecea60b
5 changed files with 20 additions and 17 deletions

View File

@@ -299,11 +299,11 @@ to run the training part first.
 - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
   of each epoch. You can pass ``--epoch`` to
-  ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them.
+  ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them.

 - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved
   every ``--save-every-n`` batches. You can pass ``--iter`` to
-  ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them.
+  ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them.

 We suggest that you try both types of checkpoints and choose the one
 that produces the lowest WERs.
@@ -311,7 +311,7 @@ to run the training part first.
 .. code-block:: bash

     $ cd egs/librispeech/ASR
-    $ ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py --help
+    $ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py --help

 shows the options for decoding.
@@ -320,7 +320,7 @@ The following shows the example using ``epoch-*.pt``:
 .. code-block:: bash

     for m in greedy_search fast_beam_search modified_beam_search; do
-      ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+      ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
         --epoch 30 \
         --avg 13 \
         --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
@@ -333,7 +333,7 @@ To test CTC branch, you can use the following command:
 .. code-block:: bash

     for m in ctc-decoding 1best; do
-      ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+      ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
         --epoch 30 \
         --avg 13 \
         --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
@@ -367,7 +367,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p
 .. hint::

-   To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``,
+   To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``,
    you can run:

    .. code-block:: bash
@@ -376,7 +376,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p
       ln -s pretrained epoch-9999.pt

    And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
-   ``./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``.
+   ``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``.

 To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you
 can run:
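The hint above works because the decode script derives checkpoint filenames from ``--epoch`` and ``--avg``, so a file (or symlink) named ``epoch-9999.pt`` is picked up like any other epoch checkpoint. A rough sketch of that lookup, assuming the usual icefall exp-dir layout (the helper below is illustrative, not the actual icefall code):

    from pathlib import Path

    import torch

    def epoch_checkpoint_paths(exp_dir: str, epoch: int, avg: int):
        # With --avg 1 this is just [exp_dir/epoch-<epoch>.pt].
        return [Path(exp_dir) / f"epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]

    # e.g. --epoch 9999 --avg 1 --use-averaged-model 0 after the ``ln -s`` above
    paths = epoch_checkpoint_paths("pruned_transducer_stateless7_ctc_bs/exp", epoch=9999, avg=1)
    state_dict = torch.load(paths[0], map_location="cpu")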

View File

@@ -194,7 +194,7 @@ The decoding commands for the transducer branch of the model using blank skip ([
 for m in greedy_search modified_beam_search fast_beam_search; do
   for epoch in 30; do
     for avg in 15; do
-      ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+      ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
          --epoch $epoch \
          --avg $avg \
          --use-averaged-model 1 \

View File

@@ -21,7 +21,7 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@@ -29,7 +29,7 @@ Usage:
     --decoding-method greedy_search

 (2) beam search (not recommended)
-./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@@ -38,7 +38,7 @@ Usage:
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@@ -47,7 +47,7 @@ Usage:
     --beam-size 4

 (4) fast beam search (one best)
-./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@@ -58,7 +58,7 @@ Usage:
     --max-states 64

 (5) fast beam search (nbest)
-./pruned_transducer_stateless7_ctc/ctc_guild_decode_bs.py \
+./pruned_transducer_stateless7_ctc/ctc_guide_decode_bs.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless7_ctc/exp \
@@ -71,7 +71,7 @@ Usage:
     --nbest-scale 0.5

 (6) fast beam search (nbest oracle WER)
-./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@@ -84,7 +84,7 @@ Usage:
     --nbest-scale 0.5

 (7) fast beam search (with LG)
-./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
+./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \

View File

@@ -44,6 +44,7 @@ class FrameReducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         ctc_output: torch.Tensor,
+        blank_id: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -54,6 +55,8 @@ class FrameReducer(nn.Module):
             `x` before padding.
           ctc_output:
             The CTC output with shape [N, T, vocab_size].
+          blank_id:
+            The blank id of ctc_output.
         Returns:
           out:
             The frame reduced encoder output with shape [N, T', C].
@@ -65,7 +68,7 @@ class FrameReducer(nn.Module):
         N, T, C = x.size()

         padding_mask = make_pad_mask(x_lens)
-        non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask)
+        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)

         out_lens = non_blank_mask.sum(dim=1)
         max_len = out_lens.max()
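For context on the hunks above: the new ``blank_id`` argument replaces the hard-coded index 0 when reading the blank probability out of ``ctc_output``. Below is a simplified, self-contained sketch of the blank-skip idea (not the actual FrameReducer implementation, which uses ``make_pad_mask`` and vectorized copies): frames whose CTC blank log-probability reaches ``log(0.9)`` are treated as blank and dropped before the transducer decoder runs.

    import math

    import torch

    def reduce_frames(x, x_lens, ctc_output, blank_id=0):
        """x: [N, T, C]; x_lens: [N]; ctc_output: [N, T, vocab_size] log-probs."""
        N, T, C = x.size()
        # True at padded positions (what a make_pad_mask helper computes).
        padding_mask = torch.arange(T, device=x.device)[None, :] >= x_lens[:, None]
        # Keep frames that are unlikely to be blank and are not padding.
        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) & (~padding_mask)
        out_lens = non_blank_mask.sum(dim=1)
        out = x.new_zeros(N, int(out_lens.max()), C)
        for i in range(N):
            kept = x[i, non_blank_mask[i]]
            out[i, : kept.size(0)] = kept
        return out, out_lens

    # Toy usage: a batch of 2 utterances, 10 frames, vocabulary of 500.
    x = torch.randn(2, 10, 4)
    x_lens = torch.tensor([10, 7])
    ctc_output = torch.randn(2, 10, 500).log_softmax(dim=-1)
    out, out_lens = reduce_frames(x, x_lens, ctc_output, blank_id=0)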

View File

@@ -62,7 +62,7 @@ class LConv(nn.Module):
             kernel_size=kernel_size,
             stride=1,
             padding=(kernel_size - 1) // 2,
-            groups=channels,
+            groups=2 * channels,
             bias=bias,
         )
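The final hunk fixes ``groups`` in the convolution inside LConv. The constructor context is not shown in the diff, but the change is consistent with a depthwise convolution over ``2 * channels`` channels: ``torch.nn.Conv1d`` requires ``groups`` to divide both the input and output channel counts, and ``groups == in_channels`` gives one filter per channel. A hedged sketch of that constraint (the channel and kernel sizes below are illustrative, not taken from the recipe):

    import torch
    import torch.nn as nn

    channels, kernel_size = 256, 7

    # Depthwise convolution over a tensor carrying 2 * channels channels,
    # matching the new ``groups=2 * channels`` in the diff above.
    depthwise = nn.Conv1d(
        in_channels=2 * channels,
        out_channels=2 * channels,
        kernel_size=kernel_size,
        stride=1,
        padding=(kernel_size - 1) // 2,
        groups=2 * channels,
        bias=True,
    )

    x = torch.randn(4, 2 * channels, 100)  # (batch, channels, time)
    print(depthwise(x).shape)  # torch.Size([4, 512, 100])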