Update tests.

This commit is contained in:
Fangjun Kuang 2022-05-16 19:22:21 +08:00
parent 896993714b
commit 59478b1ef3
4 changed files with 37 additions and 6 deletions

View File

@ -103,11 +103,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
cd ../transducer_stateless
cd ../transducer_stateless2
pytest -v -s
cd ../transducer_lstm
@ -128,11 +143,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
cd ../transducer_stateless
cd ../transducer_stateless2
pytest -v -s
cd ../transducer_lstm

View File

@ -29,6 +29,7 @@ from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
unk_id = 2
embedding_dim = 128
context_size = 4
@ -36,6 +37,7 @@ def test_decoder():
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
unk_id=unk_id,
context_size=context_size,
)
N = 100

View File

@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
)
if False:
# It is commented out as DPP complains that not all parameters are
# It is commented out as DDP complains that not all parameters are
# used. Need more checks later for the reason.
#
# Caution: We assume the dataloader returns utterances with
@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
)
packed_rnn_out, _ = self.rnn(packed_x)
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
else:
rnn_out, _ = self.rnn(x)

View File

@ -97,8 +97,7 @@ class Transducer(nn.Module):
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id
sos_y = add_sos(y, sos_id=sos_id)
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)