Use typeguard.check_argument_types() to validate type annotations.

Fangjun Kuang 2021-12-04 00:08:08 +08:00
parent 3d38f7bd31
commit 273c48d94d


@@ -23,11 +23,12 @@ as a reference.
"""
import math
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
+from typeguard import check_argument_types
class LayerNormLSTMCell(nn.Module):
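
The pattern the rest of this diff applies is a single `assert check_argument_types()` at the top of each annotated function. Below is a minimal sketch of what that call does, not part of this commit and assuming the typeguard 2.x API imported above (`make_norm` and its arguments are invented for illustration): check_argument_types() inspects the annotations of the calling function and raises TypeError on a mismatch, so the assert turns the annotations into runtime checks.

    from typing import Type

    import torch.nn as nn
    from typeguard import check_argument_types  # typeguard 2.x API


    def make_norm(ln: Type[nn.Module], size: int) -> nn.Module:
        # Hypothetical helper, not from the diff: validates `ln` and `size`
        # against this function's annotations; returns True or raises TypeError.
        assert check_argument_types()
        return ln(size)


    make_norm(nn.LayerNorm, 4)  # OK: nn.LayerNorm is a class, i.e. Type[nn.Module]
    try:
        make_norm(nn.LayerNorm(4), 4)  # an instance, not a class
    except TypeError as e:
        print("rejected:", e)
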
@@ -58,7 +59,7 @@ class LayerNormLSTMCell(nn.Module):
input_size: int,
hidden_size: int,
bias: bool = True,
-ln: nn.Module = nn.LayerNorm,
+ln: Type[nn.Module] = nn.LayerNorm,
proj_size: int = 0,
device=None,
dtype=None,
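
The annotation change above (`ln: nn.Module` to `ln: Type[nn.Module]`) goes hand in hand with the new runtime checks: the default value `nn.LayerNorm` is a class, not an instance, so the old annotation would be rejected once `__init__` starts calling `check_argument_types()`, while `Type[nn.Module]` accepts any `nn.Module` subclass. A quick illustration of the distinction (not from the commit):

    import torch.nn as nn

    # nn.LayerNorm itself is a class object ...
    print(isinstance(nn.LayerNorm, nn.Module))     # False: not valid for `ln: nn.Module`
    print(issubclass(nn.LayerNorm, nn.Module))     # True:  valid for `ln: Type[nn.Module]`

    # ... whereas nn.LayerNorm(8) is an instance of nn.Module.
    print(isinstance(nn.LayerNorm(8), nn.Module))  # True
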
@@ -85,6 +86,7 @@ class LayerNormLSTMCell(nn.Module):
case, the shape of `h` is (batch_size, proj_size).
See https://arxiv.org/pdf/1402.1128.pdf
"""
+assert check_argument_types()
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.input_size = input_size
@@ -163,6 +165,7 @@ class LayerNormLSTMCell(nn.Module):
- `next_c`: It is of shape (batch_size, hidden_size) containing the
next cell state for each element in the batch.
"""
+assert check_argument_types()
if state is None:
zeros = torch.zeros(
input.size(0),
@@ -233,7 +236,7 @@ class LayerNormLSTMLayer(nn.Module):
input_size: int,
hidden_size: int,
bias: bool = True,
-ln: nn.Module = nn.LayerNorm,
+ln: Type[nn.Module] = nn.LayerNorm,
proj_size: int = 0,
device=None,
dtype=None,
@@ -241,6 +244,7 @@ class LayerNormLSTMLayer(nn.Module):
"""
See the args in LayerNormLSTMCell
"""
+assert check_argument_types()
super().__init__()
self.cell = LayerNormLSTMCell(
input_size=input_size,
@@ -309,13 +313,14 @@ class LayerNormLSTM(nn.Module):
num_layers: int,
bias: bool = True,
proj_size: int = 0,
-ln: nn.Module = nn.LayerNorm,
+ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
"""
See the args in LayerNormLSTMLayer.
"""
+assert check_argument_types()
super().__init__()
assert num_layers >= 1
factory_kwargs = dict(
@@ -398,7 +403,7 @@ class LayerNormGRUCell(nn.Module):
input_size: int,
hidden_size: int,
bias: bool = True,
-ln: nn.Module = nn.LayerNorm,
+ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
@@ -418,6 +423,7 @@ class LayerNormGRUCell(nn.Module):
by `ln`. We pass it as an argument so that we can replace it
with `nn.Identity` at the testing time.
"""
+assert check_argument_types()
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.input_size = input_size
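
The docstring above explains why `ln` is a constructor argument at all: tests can pass `nn.Identity` instead of `nn.LayerNorm` to disable normalization. A hedged usage sketch of that follows, assuming `LayerNormGRUCell` is imported from the module this diff edits (its path is not shown here) and using made-up sizes; passing the class rather than an instance is exactly what the new `Type[nn.Module]` annotation and `check_argument_types()` expect.

    import torch.nn as nn

    # from <module edited in this diff> import LayerNormGRUCell  # assumed import

    cell = LayerNormGRUCell(
        input_size=2,
        hidden_size=3,
        ln=nn.Identity,  # the class itself; internal layer norms become no-ops
    )
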
@@ -525,13 +531,14 @@ class LayerNormGRULayer(nn.Module):
input_size: int,
hidden_size: int,
bias: bool = True,
-ln: nn.Module = nn.LayerNorm,
+ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
"""
See the args in LayerNormGRUCell
"""
+assert check_argument_types()
super().__init__()
self.cell = LayerNormGRUCell(
input_size=input_size,
@@ -591,13 +598,14 @@ class LayerNormGRU(nn.Module):
hidden_size: int,
num_layers: int,
bias: bool = True,
-ln: nn.Module = nn.LayerNorm,
+ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
"""
See the args in LayerNormGRULayer.
"""
+assert check_argument_types()
super().__init__()
assert num_layers >= 1
factory_kwargs = dict(