Update tests.

This commit is contained in:
Fangjun Kuang 2022-05-16 19:22:21 +08:00
parent 896993714b
commit 59478b1ef3
4 changed files with 37 additions and 6 deletions

View File

@ -103,11 +103,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc cd egs/librispeech/ASR/conformer_ctc
pytest -v -s pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer cd ../transducer
pytest -v -s pytest -v -s
cd ../transducer_stateless cd ../transducer_stateless2
pytest -v -s pytest -v -s
cd ../transducer_lstm cd ../transducer_lstm
@ -128,11 +143,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc cd egs/librispeech/ASR/conformer_ctc
pytest -v -s pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer cd ../transducer
pytest -v -s pytest -v -s
cd ../transducer_stateless cd ../transducer_stateless2
pytest -v -s pytest -v -s
cd ../transducer_lstm cd ../transducer_lstm

View File

@ -29,6 +29,7 @@ from decoder import Decoder
def test_decoder(): def test_decoder():
vocab_size = 3 vocab_size = 3
blank_id = 0 blank_id = 0
unk_id = 2
embedding_dim = 128 embedding_dim = 128
context_size = 4 context_size = 4
@ -36,6 +37,7 @@ def test_decoder():
vocab_size=vocab_size, vocab_size=vocab_size,
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
blank_id=blank_id, blank_id=blank_id,
unk_id=unk_id,
context_size=context_size, context_size=context_size,
) )
N = 100 N = 100

View File

@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
) )
if False: if False:
# It is commented out as DPP complains that not all parameters are # It is commented out as DDP complains that not all parameters are
# used. Need more checks later for the reason. # used. Need more checks later for the reason.
# #
# Caution: We assume the dataloader returns utterances with # Caution: We assume the dataloader returns utterances with
@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
) )
packed_rnn_out, _ = self.rnn(packed_x) packed_rnn_out, _ = self.rnn(packed_x)
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True) rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
else: else:
rnn_out, _ = self.rnn(x) rnn_out, _ = self.rnn(x)

View File

@ -97,8 +97,7 @@ class Transducer(nn.Module):
y_lens = row_splits[1:] - row_splits[:-1] y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id sos_y = add_sos(y, sos_id=blank_id)
sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64) sos_y_padded = sos_y_padded.to(torch.int64)