Add pruned RNN-T multi-dataset setup for aishell

This commit is contained in:
Fangjun Kuang 2022-06-13 19:05:03 +08:00
parent 09514259b8
commit 6a8ecc3868
3 changed files with 5 additions and 4 deletions

View File

@ -14,6 +14,7 @@ The following table lists the differences among them.
| `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` | | `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` |
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` | | `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data | | `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Reworked Conformer | Embedding + Conv1d | with pruned RNN-T + multi-dataset training (aishell + aidatatang_200zh) |
The decoder in `transducer_stateless` is modified from the paper The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -33,7 +33,7 @@ def test_model():
params.blank_id = 0 params.blank_id = 0
params.context_size = 2 params.context_size = 2
params.unk_id = 2 params.unk_id = 2
params.num_encoder_layers = 36 params.num_encoder_layers = 24
params.dim_feedforward = 1024 params.dim_feedforward = 1024
params.nhead = 8 params.nhead = 8
params.encoder_dim = 256 params.encoder_dim = 256

View File

@ -101,14 +101,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--num-encoder-layers", "--num-encoder-layers",
type=int, type=int,
default=24, default=36,
help="Number of conformer encoder layers..", help="Number of conformer encoder layers..",
) )
parser.add_argument( parser.add_argument(
"--dim-feedforward", "--dim-feedforward",
type=int, type=int,
default=1536, default=1024,
help="Feedforward dimension of the conformer encoder layer.", help="Feedforward dimension of the conformer encoder layer.",
) )
@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-dim", "--encoder-dim",
type=int, type=int,
default=384, default=256,
help="Attention dimension in the conformer encoder layer.", help="Attention dimension in the conformer encoder layer.",
) )