Minor fixes

pkufool 2022-03-09 17:03:20 +08:00
parent 96a8e8900b
commit a4896fbda6
5 changed files with 10 additions and 88 deletions


````diff
@@ -1,21 +0,0 @@
-## Introduction
-
-The decoder, i.e., the prediction network, is from
-https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-(Rnn-Transducer with Stateless Prediction Network)
-
-You can use the following command to start the training:
-
-```bash
-cd egs/aishell/ASR
-
-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
-
-./transducer_stateless/train.py \
-  --world-size 8 \
-  --num-epochs 30 \
-  --start-epoch 0 \
-  --exp-dir transducer_stateless/exp \
-  --max-duration 250 \
-  --lr-factor 2.5
-```
````
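For context on the paper referenced in the removed README: a stateless prediction network replaces the usual recurrent decoder with a function of only the last `context_size` labels. The sketch below illustrates the idea with made-up names (`StatelessDecoder` is ours, not the recipe's `decoder.py`, though the constructor mirrors the test file removed later in this commit):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoder(nn.Module):
    """Sketch of a stateless prediction network: the output at position u
    depends only on the previous `context_size` labels, not on an
    unbounded recurrent state."""

    def __init__(self, vocab_size: int, embedding_dim: int,
                 blank_id: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim,
                                      padding_idx=blank_id)
        self.context_size = context_size
        # Depthwise conv mixes the last `context_size` embedded labels.
        self.conv = nn.Conv1d(embedding_dim, embedding_dim,
                              kernel_size=context_size,
                              groups=embedding_dim, bias=False)
        self.output = nn.Linear(embedding_dim, vocab_size)

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        # y: (N, U) label indices.
        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
        if need_pad:
            # Left-pad so position u sees only labels before u (training).
            emb = F.pad(emb, pad=(self.context_size - 1, 0))
        out = self.conv(emb).permute(0, 2, 1)     # (N, U, C) or (N, 1, C)
        return self.output(F.relu(out))           # project to vocab_size
```

With `need_pad=False` and exactly `context_size` input labels, the convolution yields a single output frame, which is what step-by-step decoding needs.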


```diff
@@ -128,7 +128,6 @@ class HypothesisList(object):
     def data(self):
         return self._data
 
-    # def add(self, ys: List[int], log_prob: float):
     def add(self, hyp: Hypothesis):
         """Add a Hypothesis to `self`.
 
@@ -266,7 +265,7 @@ def beam_search(
 
     while t < T and sym_per_utt < max_sym_per_utt:
         # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
         # fmt: on
         A = B
         B = HypothesisList()
@@ -294,7 +293,9 @@
             cached_key += f"-t-{t}"
 
             if cached_key not in joint_cache:
-                logits = model.joiner(current_encoder_out, decoder_out)
+                logits = model.joiner(
+                    current_encoder_out, decoder_out.unsqueeze(1)
+                )
 
                 # TODO(fangjun): Scale the blank posterior
```

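Taken together, the two beam_search changes make both joiner inputs 4-D so they broadcast against each other: the current encoder frame becomes `(N, 1, 1, C)` after `.unsqueeze(2)` and the decoder output becomes `(N, 1, 1, C)` after `.unsqueeze(1)`. A minimal additive joiner consistent with these call sites could look like the following (a sketch under that shape assumption, not necessarily the repo's `joiner.py`):

```python
import torch
import torch.nn as nn


class Joiner(nn.Module):
    """Sketch of an additive joiner that expects 4-D inputs."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor,
                decoder_out: torch.Tensor) -> torch.Tensor:
        assert encoder_out.ndim == decoder_out.ndim == 4
        # (N, T, 1, C) + (N, 1, U, C) broadcasts to (N, T, U, C).
        logit = torch.tanh(encoder_out + decoder_out)
        return self.output_linear(logit)  # (N, T, U, vocab_size)
```

Without the `unsqueeze` calls, the 3-D tensors from beam search would fail such a rank check, which appears to be what these two fixes address.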

```diff
@@ -127,11 +127,11 @@ def get_params() -> AttributeDict:
         {
             # parameters for conformer
             "feature_dim": 80,
-            "embedding_dim": 256,
+            "embedding_dim": 512,
             "subsampling_factor": 4,
-            "attention_dim": 256,
+            "attention_dim": 512,
             "nhead": 4,
-            "dim_feedforward": 1024,
+            "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
```


```diff
@@ -121,11 +121,11 @@ def get_params() -> AttributeDict:
         {
             # parameters for conformer
             "feature_dim": 80,
-            "embedding_dim": 256,
+            "embedding_dim": 512,
             "subsampling_factor": 4,
-            "attention_dim": 256,
+            "attention_dim": 512,
             "nhead": 4,
-            "dim_feedforward": 1024,
+            "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
```

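Both `get_params()` hunks apply the same scaling: `embedding_dim` and `attention_dim` go from 256 to 512, and `dim_feedforward` from 1024 to 2048. Since self-attention projection weights grow with `attention_dim` squared and the feed-forward block with `attention_dim * dim_feedforward`, this roughly quadruples the dominant per-layer weight count. A back-of-the-envelope check (ignoring biases, the conformer convolution module, and layer norms; the helper below is ours, not the repo's):

```python
def approx_layer_params(attention_dim: int, dim_feedforward: int) -> int:
    """Rough weight count for the two biggest pieces of one encoder layer."""
    attn = 4 * attention_dim * attention_dim   # Q, K, V, and output projections
    ffn = 2 * attention_dim * dim_feedforward  # two feed-forward linear layers
    return attn + ffn


old = approx_layer_params(256, 1024)  # 786,432 weights per layer
new = approx_layer_params(512, 2048)  # 3,145,728 weights per layer
print(old, new, new / old)            # ratio is exactly 4.0
```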

```diff
@@ -1,58 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-To run this file, do:
-
-    cd icefall/egs/aishell/ASR
-    python ./transducer_stateless/test_decoder.py
-"""
-
-import torch
-from decoder import Decoder
-
-
-def test_decoder():
-    vocab_size = 3
-    blank_id = 0
-    embedding_dim = 128
-    context_size = 4
-
-    decoder = Decoder(
-        vocab_size=vocab_size,
-        embedding_dim=embedding_dim,
-        blank_id=blank_id,
-        context_size=context_size,
-    )
-    N = 100
-    U = 20
-    x = torch.randint(low=0, high=vocab_size, size=(N, U))
-    y = decoder(x)
-    assert y.shape == (N, U, vocab_size)
-
-    # for inference
-    x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
-    y = decoder(x, need_pad=False)
-    assert y.shape == (N, 1, vocab_size)
-
-
-def main():
-    test_decoder()
-
-
-if __name__ == "__main__":
-    main()
```
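The removed test exercised both decoder modes: training, where a full `(N, U)` label tensor is left-padded internally (`need_pad=True` by default), and inference, where the caller passes exactly the last `context_size` labels. A hypothetical decoding-time call, reusing the test's own import (all values below are illustrative, not from the recipe):

```python
import torch
from decoder import Decoder  # same import the removed test used

decoder = Decoder(
    vocab_size=500,
    embedding_dim=512,
    blank_id=0,
    context_size=2,
)

# Inference: feed exactly the last `context_size` labels (blank-padded
# at the start of decoding); no internal left-padding is applied.
ys = torch.tensor([[0, 17]])       # (N=1, context_size=2)
out = decoder(ys, need_pad=False)  # (1, 1, vocab_size), per the removed asserts
```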