mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Use typeguard.check_argument_types() to validate type annotations.
This commit is contained in:
parent
3d38f7bd31
commit
273c48d94d
@ -23,11 +23,12 @@ as a reference.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
|
||||||
class LayerNormLSTMCell(nn.Module):
|
class LayerNormLSTMCell(nn.Module):
|
||||||
@ -58,7 +59,7 @@ class LayerNormLSTMCell(nn.Module):
|
|||||||
input_size: int,
|
input_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
ln: nn.Module = nn.LayerNorm,
|
ln: Type[nn.Module] = nn.LayerNorm,
|
||||||
proj_size: int = 0,
|
proj_size: int = 0,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -85,6 +86,7 @@ class LayerNormLSTMCell(nn.Module):
|
|||||||
case, the shape of `h` is (batch_size, proj_size).
|
case, the shape of `h` is (batch_size, proj_size).
|
||||||
See https://arxiv.org/pdf/1402.1128.pdf
|
See https://arxiv.org/pdf/1402.1128.pdf
|
||||||
"""
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
@ -163,6 +165,7 @@ class LayerNormLSTMCell(nn.Module):
|
|||||||
- `next_c`: It is of shape (batch_size, hidden_size) containing the
|
- `next_c`: It is of shape (batch_size, hidden_size) containing the
|
||||||
next cell state for each element in the batch.
|
next cell state for each element in the batch.
|
||||||
"""
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
if state is None:
|
if state is None:
|
||||||
zeros = torch.zeros(
|
zeros = torch.zeros(
|
||||||
input.size(0),
|
input.size(0),
|
||||||
@ -233,7 +236,7 @@ class LayerNormLSTMLayer(nn.Module):
|
|||||||
input_size: int,
|
input_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
ln: nn.Module = nn.LayerNorm,
|
ln: Type[nn.Module] = nn.LayerNorm,
|
||||||
proj_size: int = 0,
|
proj_size: int = 0,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -241,6 +244,7 @@ class LayerNormLSTMLayer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
See the args in LayerNormLSTMCell
|
See the args in LayerNormLSTMCell
|
||||||
"""
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cell = LayerNormLSTMCell(
|
self.cell = LayerNormLSTMCell(
|
||||||
input_size=input_size,
|
input_size=input_size,
|
||||||
@ -309,13 +313,14 @@ class LayerNormLSTM(nn.Module):
|
|||||||
num_layers: int,
|
num_layers: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
proj_size: int = 0,
|
proj_size: int = 0,
|
||||||
ln: nn.Module = nn.LayerNorm,
|
ln: Type[nn.Module] = nn.LayerNorm,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
See the args in LayerNormLSTMLayer.
|
See the args in LayerNormLSTMLayer.
|
||||||
"""
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert num_layers >= 1
|
assert num_layers >= 1
|
||||||
factory_kwargs = dict(
|
factory_kwargs = dict(
|
||||||
@ -398,7 +403,7 @@ class LayerNormGRUCell(nn.Module):
|
|||||||
input_size: int,
|
input_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
ln: nn.Module = nn.LayerNorm,
|
ln: Type[nn.Module] = nn.LayerNorm,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
@ -418,6 +423,7 @@ class LayerNormGRUCell(nn.Module):
|
|||||||
by `ln`. We pass it as an argument so that we can replace it
|
by `ln`. We pass it as an argument so that we can replace it
|
||||||
with `nn.Identity` at the testing time.
|
with `nn.Identity` at the testing time.
|
||||||
"""
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
@ -525,13 +531,14 @@ class LayerNormGRULayer(nn.Module):
|
|||||||
input_size: int,
|
input_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
ln: nn.Module = nn.LayerNorm,
|
ln: Type[nn.Module] = nn.LayerNorm,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
See the args in LayerNormGRUCell
|
See the args in LayerNormGRUCell
|
||||||
"""
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cell = LayerNormGRUCell(
|
self.cell = LayerNormGRUCell(
|
||||||
input_size=input_size,
|
input_size=input_size,
|
||||||
@ -591,13 +598,14 @@ class LayerNormGRU(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
ln: nn.Module = nn.LayerNorm,
|
ln: Type[nn.Module] = nn.LayerNorm,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
See the args in LayerNormGRULayer.
|
See the args in LayerNormGRULayer.
|
||||||
"""
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert num_layers >= 1
|
assert num_layers >= 1
|
||||||
factory_kwargs = dict(
|
factory_kwargs = dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user