Use typeguard.check_argument_types() to validate type annotations.

This commit is contained in:
Fangjun Kuang 2021-12-04 00:08:08 +08:00
parent 3d38f7bd31
commit 273c48d94d

View File

@ -23,11 +23,12 @@ as a reference.
""" """
import math import math
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typeguard import check_argument_types
class LayerNormLSTMCell(nn.Module): class LayerNormLSTMCell(nn.Module):
@ -58,7 +59,7 @@ class LayerNormLSTMCell(nn.Module):
input_size: int, input_size: int,
hidden_size: int, hidden_size: int,
bias: bool = True, bias: bool = True,
ln: nn.Module = nn.LayerNorm, ln: Type[nn.Module] = nn.LayerNorm,
proj_size: int = 0, proj_size: int = 0,
device=None, device=None,
dtype=None, dtype=None,
@ -85,6 +86,7 @@ class LayerNormLSTMCell(nn.Module):
case, the shape of `h` is (batch_size, proj_size). case, the shape of `h` is (batch_size, proj_size).
See https://arxiv.org/pdf/1402.1128.pdf See https://arxiv.org/pdf/1402.1128.pdf
""" """
assert check_argument_types()
super().__init__() super().__init__()
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
self.input_size = input_size self.input_size = input_size
@ -163,6 +165,7 @@ class LayerNormLSTMCell(nn.Module):
- `next_c`: It is of shape (batch_size, hidden_size) containing the - `next_c`: It is of shape (batch_size, hidden_size) containing the
next cell state for each element in the batch. next cell state for each element in the batch.
""" """
assert check_argument_types()
if state is None: if state is None:
zeros = torch.zeros( zeros = torch.zeros(
input.size(0), input.size(0),
@ -233,7 +236,7 @@ class LayerNormLSTMLayer(nn.Module):
input_size: int, input_size: int,
hidden_size: int, hidden_size: int,
bias: bool = True, bias: bool = True,
ln: nn.Module = nn.LayerNorm, ln: Type[nn.Module] = nn.LayerNorm,
proj_size: int = 0, proj_size: int = 0,
device=None, device=None,
dtype=None, dtype=None,
@ -241,6 +244,7 @@ class LayerNormLSTMLayer(nn.Module):
""" """
See the args in LayerNormLSTMCell See the args in LayerNormLSTMCell
""" """
assert check_argument_types()
super().__init__() super().__init__()
self.cell = LayerNormLSTMCell( self.cell = LayerNormLSTMCell(
input_size=input_size, input_size=input_size,
@ -309,13 +313,14 @@ class LayerNormLSTM(nn.Module):
num_layers: int, num_layers: int,
bias: bool = True, bias: bool = True,
proj_size: int = 0, proj_size: int = 0,
ln: nn.Module = nn.LayerNorm, ln: Type[nn.Module] = nn.LayerNorm,
device=None, device=None,
dtype=None, dtype=None,
): ):
""" """
See the args in LayerNormLSTMLayer. See the args in LayerNormLSTMLayer.
""" """
assert check_argument_types()
super().__init__() super().__init__()
assert num_layers >= 1 assert num_layers >= 1
factory_kwargs = dict( factory_kwargs = dict(
@ -398,7 +403,7 @@ class LayerNormGRUCell(nn.Module):
input_size: int, input_size: int,
hidden_size: int, hidden_size: int,
bias: bool = True, bias: bool = True,
ln: nn.Module = nn.LayerNorm, ln: Type[nn.Module] = nn.LayerNorm,
device=None, device=None,
dtype=None, dtype=None,
): ):
@ -418,6 +423,7 @@ class LayerNormGRUCell(nn.Module):
by `ln`. We pass it as an argument so that we can replace it by `ln`. We pass it as an argument so that we can replace it
with `nn.Identity` at the testing time. with `nn.Identity` at the testing time.
""" """
assert check_argument_types()
super().__init__() super().__init__()
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
self.input_size = input_size self.input_size = input_size
@ -525,13 +531,14 @@ class LayerNormGRULayer(nn.Module):
input_size: int, input_size: int,
hidden_size: int, hidden_size: int,
bias: bool = True, bias: bool = True,
ln: nn.Module = nn.LayerNorm, ln: Type[nn.Module] = nn.LayerNorm,
device=None, device=None,
dtype=None, dtype=None,
): ):
""" """
See the args in LayerNormGRUCell See the args in LayerNormGRUCell
""" """
assert check_argument_types()
super().__init__() super().__init__()
self.cell = LayerNormGRUCell( self.cell = LayerNormGRUCell(
input_size=input_size, input_size=input_size,
@ -591,13 +598,14 @@ class LayerNormGRU(nn.Module):
hidden_size: int, hidden_size: int,
num_layers: int, num_layers: int,
bias: bool = True, bias: bool = True,
ln: nn.Module = nn.LayerNorm, ln: Type[nn.Module] = nn.LayerNorm,
device=None, device=None,
dtype=None, dtype=None,
): ):
""" """
See the args in LayerNormGRULayer. See the args in LayerNormGRULayer.
""" """
assert check_argument_types()
super().__init__() super().__init__()
assert num_layers >= 1 assert num_layers >= 1
factory_kwargs = dict( factory_kwargs = dict(