Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 08:04:18 +00:00
Update tests.
parent 896993714b
commit 59478b1ef3
.github/workflows/test.yml (vendored): 34 changed lines
@@ -103,11 +103,26 @@ jobs:
           cd egs/librispeech/ASR/conformer_ctc
           pytest -v -s

+          cd ../pruned_transducer_stateless
+          pytest -v -s
+
+          cd ../pruned_transducer_stateless2
+          pytest -v -s
+
+          cd ../pruned_transducer_stateless3
+          pytest -v -s
+
+          cd ../pruned_transducer_stateless4
+          pytest -v -s
+
           cd ../transducer_stateless
           pytest -v -s

           if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
             cd ../transducer
             pytest -v -s

-            cd ../transducer_stateless
+            cd ../transducer_stateless2
             pytest -v -s

             cd ../transducer_lstm
@@ -128,11 +143,26 @@ jobs:
           cd egs/librispeech/ASR/conformer_ctc
           pytest -v -s

+          cd ../pruned_transducer_stateless
+          pytest -v -s
+
+          cd ../pruned_transducer_stateless2
+          pytest -v -s
+
+          cd ../pruned_transducer_stateless3
+          pytest -v -s
+
+          cd ../pruned_transducer_stateless4
+          pytest -v -s
+
           cd ../transducer_stateless
           pytest -v -s

           if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
             cd ../transducer
             pytest -v -s

-            cd ../transducer_stateless
+            cd ../transducer_stateless2
             pytest -v -s

             cd ../transducer_lstm
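Each step above just changes into a recipe directory and runs pytest, so any test_*.py module in that directory is collected. As a rough, hypothetical illustration (not part of this commit or of the repository), a minimal module that such a pytest -v -s run would pick up could look like this:

    # test_minimal.py -- hypothetical example of a module pytest would collect
    # when the workflow runs `pytest -v -s` inside a recipe directory.
    import torch


    def test_feature_shape():
        # Stand-in for the real per-recipe tests (test_decoder.py, test_model.py,
        # ...), which build and exercise the recipe's actual modules.
        features = torch.randn(2, 100, 80)  # (batch, time, feature_dim)
        assert features.shape == (2, 100, 80)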
@@ -29,6 +29,7 @@ from decoder import Decoder
 def test_decoder():
     vocab_size = 3
     blank_id = 0
+    unk_id = 2
     embedding_dim = 128
     context_size = 4

@@ -36,6 +37,7 @@ def test_decoder():
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
+        unk_id=unk_id,
         context_size=context_size,
     )
     N = 100
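The test now also passes unk_id into the Decoder. The decoder in these stateless recipes is essentially an embedding followed by a 1-D convolution over the previous context_size labels; the self-contained sketch below illustrates that idea and the kind of shape check the test performs. It is an illustration only: the class, its forward signature, and the output shape are assumptions, not the icefall implementation.

    # Hedged, self-contained sketch of the idea behind the stateless decoder
    # being tested: an embedding followed by a 1-D convolution over the last
    # `context_size` labels. Names and shapes are illustrative.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ToyStatelessDecoder(nn.Module):
        def __init__(self, vocab_size, embedding_dim, blank_id, unk_id, context_size):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.blank_id = blank_id
            self.unk_id = unk_id
            self.context_size = context_size
            # Depthwise conv mixes the embeddings of the previous context_size labels.
            self.conv = nn.Conv1d(
                embedding_dim, embedding_dim,
                kernel_size=context_size, groups=embedding_dim,
            )

        def forward(self, y):
            # y: (N, U) label indices; output: (N, U, embedding_dim)
            emb = self.embedding(y).permute(0, 2, 1)      # (N, C, U)
            emb = F.pad(emb, (self.context_size - 1, 0))  # causal left padding
            return self.conv(emb).permute(0, 2, 1)        # (N, U, C)


    def test_toy_decoder():
        vocab_size, blank_id, unk_id = 3, 0, 2
        embedding_dim, context_size = 128, 4
        decoder = ToyStatelessDecoder(
            vocab_size, embedding_dim, blank_id, unk_id, context_size
        )
        N, U = 100, 20
        y = torch.randint(low=0, high=vocab_size, size=(N, U))
        out = decoder(y)
        assert out.shape == (N, U, embedding_dim)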
@@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
             )

         if False:
-            # It is commented out as DPP complains that not all parameters are
+            # It is commented out as DDP complains that not all parameters are
             # used. Need more checks later for the reason.
             #
             # Caution: We assume the dataloader returns utterances with
@@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
             )

             packed_rnn_out, _ = self.rnn(packed_x)
-            rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
+            rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
         else:
             rnn_out, _ = self.rnn(x)
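The second hunk above fixes a real bug in the (currently disabled) packed-sequence branch: pad_packed_sequence must be applied to the RNN's packed output rather than to the packed input. A small self-contained PyTorch sketch of the intended pattern, with illustrative names:

    # Minimal sketch of the pack -> RNN -> unpack pattern that the fixed line
    # implements; the module and tensor names are illustrative.
    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    rnn = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    x = torch.randn(3, 10, 8)           # (batch, time, features), zero-padded
    x_lens = torch.tensor([10, 7, 4])   # true length of each utterance

    packed_x = pack_padded_sequence(
        x, x_lens, batch_first=True, enforce_sorted=False
    )
    packed_rnn_out, _ = rnn(packed_x)

    # Unpack the RNN *output*; unpacking packed_x would just recover the input.
    rnn_out, out_lens = pad_packed_sequence(packed_rnn_out, batch_first=True)
    assert rnn_out.shape == (3, 10, 16)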
@@ -97,8 +97,7 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]

         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)

         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
         sos_y_padded = sos_y_padded.to(torch.int64)
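This last hunk drops the separate sos_id and reuses the blank symbol as SOS. Conceptually, add_sos prepends the given symbol to every label sequence in the ragged tensor y, and the following pad call right-pads the result to a rectangular int64 tensor filled with blank_id. The snippet below mimics that effect in plain Python/PyTorch; it is an illustration, not the k2/icefall implementation:

    # Illustration only: mimic what prepending an SOS symbol and padding with
    # blank_id produce, using plain Python lists / torch instead of k2 ragged
    # tensors. blank_id doubles as both the SOS symbol and the padding value.
    import torch

    blank_id = 0
    y = [[3, 5, 2], [7], [4, 4]]      # variable-length label sequences

    # "add_sos": prepend blank_id to every sequence.
    sos_y = [[blank_id] + seq for seq in y]

    # "pad": right-pad with blank_id to the longest sequence.
    max_len = max(len(seq) for seq in sos_y)
    sos_y_padded = torch.tensor(
        [seq + [blank_id] * (max_len - len(seq)) for seq in sos_y],
        dtype=torch.int64,
    )
    print(sos_y_padded)
    # tensor([[0, 3, 5, 2],
    #         [0, 7, 0, 0],
    #         [0, 4, 4, 0]])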