mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Update tests.
This commit is contained in:
parent
896993714b
commit
59478b1ef3
34
.github/workflows/test.yml
vendored
34
.github/workflows/test.yml
vendored
@ -103,11 +103,26 @@ jobs:
|
|||||||
cd egs/librispeech/ASR/conformer_ctc
|
cd egs/librispeech/ASR/conformer_ctc
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless2
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless3
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless4
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../transducer_stateless
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
||||||
cd ../transducer
|
cd ../transducer
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_stateless
|
cd ../transducer_stateless2
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_lstm
|
cd ../transducer_lstm
|
||||||
@ -128,11 +143,26 @@ jobs:
|
|||||||
cd egs/librispeech/ASR/conformer_ctc
|
cd egs/librispeech/ASR/conformer_ctc
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless2
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless3
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless4
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../transducer_stateless
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
||||||
cd ../transducer
|
cd ../transducer
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_stateless
|
cd ../transducer_stateless2
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_lstm
|
cd ../transducer_lstm
|
||||||
|
@ -29,6 +29,7 @@ from decoder import Decoder
|
|||||||
def test_decoder():
|
def test_decoder():
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
blank_id = 0
|
blank_id = 0
|
||||||
|
unk_id = 2
|
||||||
embedding_dim = 128
|
embedding_dim = 128
|
||||||
context_size = 4
|
context_size = 4
|
||||||
|
|
||||||
@ -36,6 +37,7 @@ def test_decoder():
|
|||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
blank_id=blank_id,
|
blank_id=blank_id,
|
||||||
|
unk_id=unk_id,
|
||||||
context_size=context_size,
|
context_size=context_size,
|
||||||
)
|
)
|
||||||
N = 100
|
N = 100
|
||||||
|
@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
# It is commented out as DPP complains that not all parameters are
|
# It is commented out as DDP complains that not all parameters are
|
||||||
# used. Need more checks later for the reason.
|
# used. Need more checks later for the reason.
|
||||||
#
|
#
|
||||||
# Caution: We assume the dataloader returns utterances with
|
# Caution: We assume the dataloader returns utterances with
|
||||||
@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
packed_rnn_out, _ = self.rnn(packed_x)
|
packed_rnn_out, _ = self.rnn(packed_x)
|
||||||
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
|
rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
|
||||||
else:
|
else:
|
||||||
rnn_out, _ = self.rnn(x)
|
rnn_out, _ = self.rnn(x)
|
||||||
|
|
||||||
|
@ -97,8 +97,7 @@ class Transducer(nn.Module):
|
|||||||
y_lens = row_splits[1:] - row_splits[:-1]
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
|
||||||
blank_id = self.decoder.blank_id
|
blank_id = self.decoder.blank_id
|
||||||
sos_id = self.decoder.sos_id
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
sos_y = add_sos(y, sos_id=sos_id)
|
|
||||||
|
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
sos_y_padded = sos_y_padded.to(torch.int64)
|
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user