diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py index 062a27645..92ef17dcc 100644 --- a/egs/librispeech/ASR/transducer/rnn.py +++ b/egs/librispeech/ASR/transducer/rnn.py @@ -23,11 +23,12 @@ as a reference. """ import math -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type import torch import torch.nn as nn import torch.nn.functional as F +from typeguard import check_argument_types class LayerNormLSTMCell(nn.Module): @@ -58,7 +59,7 @@ class LayerNormLSTMCell(nn.Module): input_size: int, hidden_size: int, bias: bool = True, - ln: nn.Module = nn.LayerNorm, + ln: Type[nn.Module] = nn.LayerNorm, proj_size: int = 0, device=None, dtype=None, @@ -85,6 +86,7 @@ class LayerNormLSTMCell(nn.Module): case, the shape of `h` is (batch_size, proj_size). See https://arxiv.org/pdf/1402.1128.pdf """ + assert check_argument_types() super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.input_size = input_size @@ -163,6 +165,7 @@ class LayerNormLSTMCell(nn.Module): - `next_c`: It is of shape (batch_size, hidden_size) containing the next cell state for each element in the batch. """ + assert check_argument_types() if state is None: zeros = torch.zeros( input.size(0), @@ -233,7 +236,7 @@ class LayerNormLSTMLayer(nn.Module): input_size: int, hidden_size: int, bias: bool = True, - ln: nn.Module = nn.LayerNorm, + ln: Type[nn.Module] = nn.LayerNorm, proj_size: int = 0, device=None, dtype=None, @@ -241,6 +244,7 @@ class LayerNormLSTMLayer(nn.Module): """ See the args in LayerNormLSTMCell """ + assert check_argument_types() super().__init__() self.cell = LayerNormLSTMCell( input_size=input_size, @@ -309,13 +313,14 @@ class LayerNormLSTM(nn.Module): num_layers: int, bias: bool = True, proj_size: int = 0, - ln: nn.Module = nn.LayerNorm, + ln: Type[nn.Module] = nn.LayerNorm, device=None, dtype=None, ): """ See the args in LayerNormLSTMLayer. """ + assert check_argument_types() super().__init__() assert num_layers >= 1 factory_kwargs = dict( @@ -398,7 +403,7 @@ class LayerNormGRUCell(nn.Module): input_size: int, hidden_size: int, bias: bool = True, - ln: nn.Module = nn.LayerNorm, + ln: Type[nn.Module] = nn.LayerNorm, device=None, dtype=None, ): @@ -418,6 +423,7 @@ class LayerNormGRUCell(nn.Module): by `ln`. We pass it as an argument so that we can replace it with `nn.Identity` at the testing time. """ + assert check_argument_types() super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.input_size = input_size @@ -525,13 +531,14 @@ class LayerNormGRULayer(nn.Module): input_size: int, hidden_size: int, bias: bool = True, - ln: nn.Module = nn.LayerNorm, + ln: Type[nn.Module] = nn.LayerNorm, device=None, dtype=None, ): """ See the args in LayerNormGRUCell """ + assert check_argument_types() super().__init__() self.cell = LayerNormGRUCell( input_size=input_size, @@ -591,13 +598,14 @@ class LayerNormGRU(nn.Module): hidden_size: int, num_layers: int, bias: bool = True, - ln: nn.Module = nn.LayerNorm, + ln: Type[nn.Module] = nn.LayerNorm, device=None, dtype=None, ): """ See the args in LayerNormGRULayer. """ + assert check_argument_types() super().__init__() assert num_layers >= 1 factory_kwargs = dict(