mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Use typeguard.check_argument_types() to validate type annotations.
This commit is contained in:
parent
3d38f7bd31
commit
273c48d94d
@ -23,11 +23,12 @@ as a reference.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
|
||||
class LayerNormLSTMCell(nn.Module):
|
||||
@ -58,7 +59,7 @@ class LayerNormLSTMCell(nn.Module):
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
ln: Type[nn.Module] = nn.LayerNorm,
|
||||
proj_size: int = 0,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -85,6 +86,7 @@ class LayerNormLSTMCell(nn.Module):
|
||||
case, the shape of `h` is (batch_size, proj_size).
|
||||
See https://arxiv.org/pdf/1402.1128.pdf
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.input_size = input_size
|
||||
@ -163,6 +165,7 @@ class LayerNormLSTMCell(nn.Module):
|
||||
- `next_c`: It is of shape (batch_size, hidden_size) containing the
|
||||
next cell state for each element in the batch.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
if state is None:
|
||||
zeros = torch.zeros(
|
||||
input.size(0),
|
||||
@ -233,7 +236,7 @@ class LayerNormLSTMLayer(nn.Module):
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
ln: Type[nn.Module] = nn.LayerNorm,
|
||||
proj_size: int = 0,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -241,6 +244,7 @@ class LayerNormLSTMLayer(nn.Module):
|
||||
"""
|
||||
See the args in LayerNormLSTMCell
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self.cell = LayerNormLSTMCell(
|
||||
input_size=input_size,
|
||||
@ -309,13 +313,14 @@ class LayerNormLSTM(nn.Module):
|
||||
num_layers: int,
|
||||
bias: bool = True,
|
||||
proj_size: int = 0,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
ln: Type[nn.Module] = nn.LayerNorm,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
See the args in LayerNormLSTMLayer.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
assert num_layers >= 1
|
||||
factory_kwargs = dict(
|
||||
@ -398,7 +403,7 @@ class LayerNormGRUCell(nn.Module):
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
ln: Type[nn.Module] = nn.LayerNorm,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
@ -418,6 +423,7 @@ class LayerNormGRUCell(nn.Module):
|
||||
by `ln`. We pass it as an argument so that we can replace it
|
||||
with `nn.Identity` at the testing time.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.input_size = input_size
|
||||
@ -525,13 +531,14 @@ class LayerNormGRULayer(nn.Module):
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
ln: Type[nn.Module] = nn.LayerNorm,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
See the args in LayerNormGRUCell
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self.cell = LayerNormGRUCell(
|
||||
input_size=input_size,
|
||||
@ -591,13 +598,14 @@ class LayerNormGRU(nn.Module):
|
||||
hidden_size: int,
|
||||
num_layers: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
ln: Type[nn.Module] = nn.LayerNorm,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
See the args in LayerNormGRULayer.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
assert num_layers >= 1
|
||||
factory_kwargs = dict(
|
||||
|
Loading…
x
Reference in New Issue
Block a user