diff --git a/egs/aishell/ASR/pruned_transducer_stateless/README.md b/egs/aishell/ASR/pruned_transducer_stateless/README.md deleted file mode 100644 index 622cb837c..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless/README.md +++ /dev/null @@ -1,21 +0,0 @@ -## Introduction - -The decoder, i.e., the prediction network, is from -https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 -(Rnn-Transducer with Stateless Prediction Network) - -You can use the following command to start the training: - -```bash -cd egs/aishell/ASR - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./transducer_stateless/train.py \ - --world-size 8 \ - --num-epochs 30 \ - --start-epoch 0 \ - --exp-dir transducer_stateless/exp \ - --max-duration 250 \ - --lr-factor 2.5 -``` diff --git a/egs/aishell/ASR/pruned_transducer_stateless/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless/beam_search.py index 3441bd20c..70052a474 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/aishell/ASR/pruned_transducer_stateless/beam_search.py @@ -128,7 +128,6 @@ class HypothesisList(object): def data(self): return self._data - # def add(self, ys: List[int], log_prob: float): def add(self, hyp: Hypothesis): """Add a Hypothesis to `self`. @@ -266,7 +265,7 @@ def beam_search( while t < T and sym_per_utt < max_sym_per_utt: # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) # fmt: on A = B B = HypothesisList() @@ -294,7 +293,9 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) # TODO(fangjun): Scale the blank posterior diff --git a/egs/aishell/ASR/pruned_transducer_stateless/decode.py b/egs/aishell/ASR/pruned_transducer_stateless/decode.py index 9a1d578c5..be77dca53 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless/decode.py @@ -127,11 +127,11 @@ def get_params() -> AttributeDict: { # parameters for conformer "feature_dim": 80, - "embedding_dim": 256, + "embedding_dim": 512, "subsampling_factor": 4, - "attention_dim": 256, + "attention_dim": 512, "nhead": 4, - "dim_feedforward": 1024, + "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, "env_info": get_env_info(), diff --git a/egs/aishell/ASR/pruned_transducer_stateless/export.py b/egs/aishell/ASR/pruned_transducer_stateless/export.py index 0d2b5a6bf..82e7ebf61 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless/export.py @@ -121,11 +121,11 @@ def get_params() -> AttributeDict: { # parameters for conformer "feature_dim": 80, - "embedding_dim": 256, + "embedding_dim": 512, "subsampling_factor": 4, - "attention_dim": 256, + "attention_dim": 512, "nhead": 4, - "dim_feedforward": 1024, + "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, "env_info": get_env_info(), diff --git a/egs/aishell/ASR/pruned_transducer_stateless/test_decoder.py b/egs/aishell/ASR/pruned_transducer_stateless/test_decoder.py deleted file mode 100755 index 0d34cd672..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless/test_decoder.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -To run this file, do: - - cd icefall/egs/aishell/ASR - python ./transducer_stateless/test_decoder.py -""" - -import torch -from decoder import Decoder - - -def test_decoder(): - vocab_size = 3 - blank_id = 0 - embedding_dim = 128 - context_size = 4 - - decoder = Decoder( - vocab_size=vocab_size, - embedding_dim=embedding_dim, - blank_id=blank_id, - context_size=context_size, - ) - N = 100 - U = 20 - x = torch.randint(low=0, high=vocab_size, size=(N, U)) - y = decoder(x) - assert y.shape == (N, U, vocab_size) - - # for inference - x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) - y = decoder(x, need_pad=False) - assert y.shape == (N, 1, vocab_size) - - -def main(): - test_decoder() - - -if __name__ == "__main__": - main()