mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 06:04:18 +00:00
Minor fixes
This commit is contained in:
parent
96a8e8900b
commit
a4896fbda6
@ -1,21 +0,0 @@
|
||||
## Introduction
|
||||
|
||||
The decoder, i.e., the prediction network, is from
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
(Rnn-Transducer with Stateless Prediction Network)
|
||||
|
||||
You can use the following command to start the training:
|
||||
|
||||
```bash
|
||||
cd egs/aishell/ASR
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
./transducer_stateless/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--max-duration 250 \
|
||||
--lr-factor 2.5
|
||||
```
|
@ -128,7 +128,6 @@ class HypothesisList(object):
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
# def add(self, ys: List[int], log_prob: float):
|
||||
def add(self, hyp: Hypothesis):
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
@ -266,7 +265,7 @@ def beam_search(
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
A = B
|
||||
B = HypothesisList()
|
||||
@ -294,7 +293,9 @@ def beam_search(
|
||||
|
||||
cached_key += f"-t-{t}"
|
||||
if cached_key not in joint_cache:
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1)
|
||||
)
|
||||
|
||||
# TODO(fangjun): Scale the blank posterior
|
||||
|
||||
|
@ -127,11 +127,11 @@ def get_params() -> AttributeDict:
|
||||
{
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"embedding_dim": 256,
|
||||
"embedding_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 256,
|
||||
"attention_dim": 512,
|
||||
"nhead": 4,
|
||||
"dim_feedforward": 1024,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"env_info": get_env_info(),
|
||||
|
@ -121,11 +121,11 @@ def get_params() -> AttributeDict:
|
||||
{
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"embedding_dim": 256,
|
||||
"embedding_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 256,
|
||||
"attention_dim": 512,
|
||||
"nhead": 4,
|
||||
"dim_feedforward": 1024,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"env_info": get_env_info(),
|
||||
|
@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/aishell/ASR
|
||||
python ./transducer_stateless/test_decoder.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from decoder import Decoder
|
||||
|
||||
|
||||
def test_decoder():
|
||||
vocab_size = 3
|
||||
blank_id = 0
|
||||
embedding_dim = 128
|
||||
context_size = 4
|
||||
|
||||
decoder = Decoder(
|
||||
vocab_size=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
blank_id=blank_id,
|
||||
context_size=context_size,
|
||||
)
|
||||
N = 100
|
||||
U = 20
|
||||
x = torch.randint(low=0, high=vocab_size, size=(N, U))
|
||||
y = decoder(x)
|
||||
assert y.shape == (N, U, vocab_size)
|
||||
|
||||
# for inference
|
||||
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
|
||||
y = decoder(x, need_pad=False)
|
||||
assert y.shape == (N, 1, vocab_size)
|
||||
|
||||
|
||||
def main():
|
||||
test_decoder()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user