Fix warnings.

This commit is contained in:
Fangjun Kuang 2022-05-05 09:58:40 +08:00
parent e85c2eaba0
commit 893aecaaa2
2 changed files with 5 additions and 3 deletions

View File

@ -51,9 +51,10 @@ class Transducer(nn.Module):
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)

View File

@ -1314,6 +1314,7 @@ def _test_random_combine_main():
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
f # to remove flake8 warnings
if __name__ == "__main__":