Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-05 07:04:18 +00:00)
Minor fixes
This commit is contained in:
  parent 96a8e8900b
  commit a4896fbda6
@@ -1,21 +0,0 @@
-## Introduction
-
-The decoder, i.e., the prediction network, is from
-https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-(Rnn-Transducer with Stateless Prediction Network)
-
-You can use the following command to start the training:
-
-```bash
-cd egs/aishell/ASR
-
-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
-
-./transducer_stateless/train.py \
-  --world-size 8 \
-  --num-epochs 30 \
-  --start-epoch 0 \
-  --exp-dir transducer_stateless/exp \
-  --max-duration 250 \
-  --lr-factor 2.5
-```
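The removed README refers to the stateless prediction network from the paper above. As a rough illustration of that idea (a sketch, not the repo's actual `Decoder` class), the prediction network can be just an embedding plus a 1-D convolution over a fixed left context, with no recurrent state. The constructor arguments below mirror those used by `test_decoder.py` further down in this commit; everything else is assumed for illustration.

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Sketch of a stateless prediction network: embedding + 1-D convolution
    over the last `context_size` tokens, no recurrent state."""

    def __init__(
        self, vocab_size: int, embedding_dim: int, blank_id: int, context_size: int
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)
        self.context_size = context_size
        if context_size > 1:
            self.conv = nn.Conv1d(
                embedding_dim,
                embedding_dim,
                kernel_size=context_size,
                groups=embedding_dim,  # depthwise; a plain conv would also work
            )
        self.output = nn.Linear(embedding_dim, vocab_size)

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        # y: (N, U) token IDs.
        # Returns (N, U, vocab_size) when padding on the left,
        # else (N, U - context_size + 1, vocab_size).
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embedding_dim, U)
        if self.context_size > 1:
            if need_pad:
                emb = nn.functional.pad(emb, (self.context_size - 1, 0))
            emb = self.conv(emb)
        return self.output(emb.permute(0, 2, 1))
```

With `need_pad=False` and an input of exactly `context_size` tokens, this produces a single output step per utterance, which is the behavior the deleted `test_decoder.py` asserts.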
@@ -128,7 +128,6 @@ class HypothesisList(object):
     def data(self):
         return self._data

-    # def add(self, ys: List[int], log_prob: float):
     def add(self, hyp: Hypothesis):
         """Add a Hypothesis to `self`.
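For context, a minimal sketch of the general idea behind `HypothesisList` (not necessarily the exact implementation in this file): it behaves like a dict keyed by the token sequence, and `add` merges hypotheses that share the same tokens by combining their probabilities in log space.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Hypothesis:
    ys: List[int]       # decoded token IDs so far
    log_prob: float     # log probability of this hypothesis


class HypothesisList:
    def __init__(self) -> None:
        self._data: Dict[Tuple[int, ...], Hypothesis] = {}

    def data(self) -> Dict[Tuple[int, ...], Hypothesis]:
        return self._data

    def add(self, hyp: Hypothesis) -> None:
        """Add a Hypothesis; if one with the same token sequence already
        exists, merge them by log-add of their log probabilities."""
        key = tuple(hyp.ys)
        if key in self._data:
            old = self._data[key]
            a, b = old.log_prob, hyp.log_prob
            # numerically stable log(exp(a) + exp(b))
            old.log_prob = max(a, b) + math.log1p(math.exp(-abs(a - b)))
        else:
            self._data[key] = hyp
```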
@@ -266,7 +265,7 @@ def beam_search(
     while t < T and sym_per_utt < max_sym_per_utt:
         # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
         # fmt: on
         A = B
         B = HypothesisList()
@@ -294,7 +293,9 @@ def beam_search(
             cached_key += f"-t-{t}"
             if cached_key not in joint_cache:
-                logits = model.joiner(current_encoder_out, decoder_out)
+                logits = model.joiner(
+                    current_encoder_out, decoder_out.unsqueeze(1)
+                )

                 # TODO(fangjun): Scale the blank posterior
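The two beam-search changes are complementary: the encoder frame gains a trailing singleton axis via `unsqueeze(2)` and the decoder output gains one via `unsqueeze(1)`, which suggests the joiner now expects 4-D inputs it can broadcast against each other. Below is a self-contained illustration of that broadcasting pattern, using a simple additive joiner assumed purely for illustration (it is not claimed to be the joiner in this repo).

```python
import torch

N, T, U, C = 2, 5, 3, 512              # illustrative sizes

encoder_out = torch.randn(N, T, C)     # encoder frames
decoder_out = torch.randn(N, U, C)     # prediction-network outputs

# The unsqueezes in the diff give the joiner 4-D inputs that broadcast:
enc = encoder_out.unsqueeze(2)         # (N, T, 1, C)
dec = decoder_out.unsqueeze(1)         # (N, 1, U, C)

# An additive joiner combines every encoder frame with every decoder
# state in one broadcast:
joint = torch.tanh(enc + dec)          # (N, T, U, C)
assert joint.shape == (N, T, U, C)

# In the beam-search inner loop both the time and label axes have length 1,
# which matches `encoder_out[:, t:t+1, :].unsqueeze(2)` and
# `decoder_out.unsqueeze(1)` in the diff above.
```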
@@ -127,11 +127,11 @@ def get_params() -> AttributeDict:
         {
             # parameters for conformer
             "feature_dim": 80,
-            "embedding_dim": 256,
+            "embedding_dim": 512,
             "subsampling_factor": 4,
-            "attention_dim": 256,
+            "attention_dim": 512,
             "nhead": 4,
-            "dim_feedforward": 1024,
+            "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
@@ -121,11 +121,11 @@ def get_params() -> AttributeDict:
         {
             # parameters for conformer
             "feature_dim": 80,
-            "embedding_dim": 256,
+            "embedding_dim": 512,
             "subsampling_factor": 4,
-            "attention_dim": 256,
+            "attention_dim": 512,
             "nhead": 4,
-            "dim_feedforward": 1024,
+            "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
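The same configuration change is applied in two files: the embedding, attention, and feed-forward dimensions are all doubled while keeping the usual ratios (attention_dim divisible by nhead, dim_feedforward equal to 4 × attention_dim). As a rough, framework-level illustration only (using PyTorch's stock transformer layer as a stand-in for icefall's Conformer layer), these are the knobs being changed:

```python
import torch
import torch.nn as nn

# Illustration only: nn.TransformerEncoderLayer stands in for the Conformer
# layer configured by get_params(); the values mirror the diff above.
attention_dim = 512      # was 256
nhead = 4                # unchanged; 512 % 4 == 0, so each head grows to 128 dims
dim_feedforward = 2048   # was 1024, keeping dim_feedforward == 4 * attention_dim

layer = nn.TransformerEncoderLayer(
    d_model=attention_dim,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
)

x = torch.randn(100, 2, attention_dim)  # (sequence, batch, attention_dim)
y = layer(x)
assert y.shape == x.shape
```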
@@ -1,58 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-To run this file, do:
-
-    cd icefall/egs/aishell/ASR
-    python ./transducer_stateless/test_decoder.py
-"""
-
-import torch
-from decoder import Decoder
-
-
-def test_decoder():
-    vocab_size = 3
-    blank_id = 0
-    embedding_dim = 128
-    context_size = 4
-
-    decoder = Decoder(
-        vocab_size=vocab_size,
-        embedding_dim=embedding_dim,
-        blank_id=blank_id,
-        context_size=context_size,
-    )
-    N = 100
-    U = 20
-    x = torch.randint(low=0, high=vocab_size, size=(N, U))
-    y = decoder(x)
-    assert y.shape == (N, U, vocab_size)
-
-    # for inference
-    x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
-    y = decoder(x, need_pad=False)
-    assert y.shape == (N, 1, vocab_size)
-
-
-def main():
-    test_decoder()
-
-
-if __name__ == "__main__":
-    main()