mirror of https://github.com/k2-fsa/icefall.git
add gradient filter
This commit is contained in:
parent 2cc6137934
commit b18850721d
egs/librispeech/ASR/tdnn_lstm_ctc2/__init__.py (new symbolic link)
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/__init__.py
@@ -15,13 +15,77 @@
 # limitations under the License.
 
+from typing import Tuple
+
 import torch
 import torch.nn as nn
 
 
+class GradientFilterFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        batch_dim: int,  # e.g., 1
+        threshold: float,  # e.g., 10.0
+    ) -> torch.Tensor:
+        if x.requires_grad:
+            if batch_dim < 0:
+                batch_dim += x.ndim
+            ctx.batch_dim = batch_dim
+            ctx.threshold = threshold
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+        dim = ctx.batch_dim
+        if x_grad.shape[dim] == 1:
+            return x_grad, None, None
+        norm_dims = [d for d in range(x_grad.ndim) if d != dim]
+        norm_of_batch = x_grad.norm(dim=norm_dims, keepdim=True)
+        norm_of_batch_sorted = norm_of_batch.sort(dim=dim)[0]
+        median_idx = (x_grad.shape[dim] - 1) // 2
+        median_norm = norm_of_batch_sorted.narrow(
+            dim=dim, start=median_idx, length=1
+        )
+        mask = norm_of_batch <= ctx.threshold * median_norm
+        return x_grad * mask, None, None
+
+
+class GradientFilter(torch.nn.Module):
+    """This is used to filter out elements that have extremely large gradients
+    in a batch.
+
+    Args:
+      batch_dim (int):
+        The batch dimension.
+      threshold (float):
+        For each element in the batch, its gradient will be
+        filtered out if the gradient norm is larger than
+        `threshold * median`, where `median` is the median
+        value of the gradient norms of all elements in the batch.
+    """
+
+    def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
+        super(GradientFilter, self).__init__()
+        self.batch_dim = batch_dim
+        self.threshold = threshold
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GradientFilterFunction.apply(
+            x,
+            self.batch_dim,
+            self.threshold,
+        )
+
+
 class TdnnLstm(nn.Module):
     def __init__(
-        self, num_features: int, num_classes: int, subsampling_factor: int = 3
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 3,
+        grad_norm_threshold: float = 10.0,
     ) -> None:
         """
         Args:
@@ -31,6 +95,11 @@ class TdnnLstm(nn.Module):
             The output dimension of the model.
           subsampling_factor:
             It reduces the number of output frames by this factor.
+          grad_norm_threshold:
+            For each sequence element in the batch, its gradient will be
+            filtered out if the gradient norm is larger than
+            `grad_norm_threshold * median`, where `median` is the median
+            value of the gradient norms of all elements in the batch.
         """
         super().__init__()
         self.num_features = num_features
@@ -74,6 +143,10 @@ class TdnnLstm(nn.Module):
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
         )
+        self.grad_filters = nn.ModuleList(
+            [GradientFilter(batch_dim=1, threshold=grad_norm_threshold)] * 5
+        )
 
         self.dropout = nn.Dropout(0.2)
         self.linear = nn.Linear(in_features=500, out_features=self.num_classes)
@@ -88,8 +161,10 @@ class TdnnLstm(nn.Module):
         """
         x = self.tdnn(x)
         x = x.permute(2, 0, 1)  # (N, C, T) -> (T, N, C) -> how LSTM expects it
-        for lstm, bnorm in zip(self.lstms, self.lstm_bnorms):
-            x_new, _ = lstm(x)
+        for lstm, bnorm, grad_filter in zip(
+            self.lstms, self.lstm_bnorms, self.grad_filters
+        ):
+            x_new, _ = lstm(grad_filter(x))
             x_new = bnorm(x_new.permute(1, 2, 0)).permute(
                 2, 0, 1
             )  # (T, N, C) -> (N, C, T) -> (T, N, C)
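The filter is the identity in the forward pass; only in the backward pass does it zero the gradient of any batch element whose gradient norm exceeds `threshold` times the median norm across the batch. Below is a minimal sketch of that behaviour. It assumes the GradientFilter class from the hunk above can be imported as `from model import GradientFilter`; the module path is a guess based on the usual icefall recipe layout and is not shown in this diff.

import torch

from model import GradientFilter  # assumed path; the class is defined in the hunk above

torch.manual_seed(0)

# (T, N, C) layout, as fed to the LSTM stack above, so the batch dimension is 1.
x = torch.randn(4, 3, 5, requires_grad=True)
y = GradientFilter(batch_dim=1, threshold=10.0)(x)

# Craft an upstream gradient in which batch element 2 is an extreme outlier.
grad = torch.ones_like(y)
grad[:, 2, :] = 1000.0
y.backward(grad)

# Per-element gradient norms after filtering: elements 0 and 1 keep their
# gradients (norm sqrt(20) ~= 4.47), while element 2, whose norm is far above
# 10x the median, has been zeroed out.
print(x.grad.norm(dim=(0, 2)))  # tensor([4.4721, 4.4721, 0.0000])

Because the forward pass is untouched, the filter adds no cost at inference time; it only discards outlier per-sequence gradients during training.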
@@ -112,6 +112,16 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )
 
+    parser.add_argument(
+        "--grad-norm-threshold",
+        type=float,
+        default=10.0,
+        help="""For each sequence element in the batch, its gradient will be
+        filtered out if the gradient norm is larger than
+        `grad_norm_threshold * median`, where `median` is the median
+        value of the gradient norms of all elements in the batch.""",
+    )
+
     return parser
 
 
@@ -171,7 +181,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("tdnn_lstm_ctc/exp"),
+            "exp_dir": Path("tdnn_lstm_ctc2/exp"),
             "lang_dir": Path("data/lang_phone"),
             "lr": 1e-3,
             "feature_dim": 80,
@@ -540,6 +550,7 @@ def run(rank, world_size, args):
         num_features=params.feature_dim,
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
+        grad_norm_threshold=params.grad_norm_threshold,
     )
 
     checkpoints = load_checkpoint_if_available(params=params, model=model)
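The remaining hunks (presumably in the recipe's training script, whose file name is lost in this view) simply thread the new option from the command line into the model. A hedged sketch of that flow, using placeholder values for quantities the real script computes (`feature_dim`, `max_phone_id`) and assuming TdnnLstm is importable from the recipe's model file:

import argparse

from model import TdnnLstm  # assumed path for the modified model class above

parser = argparse.ArgumentParser()
parser.add_argument("--grad-norm-threshold", type=float, default=10.0)
args = parser.parse_args(["--grad-norm-threshold", "5.0"])

# In run() the parsed value lands in params and is forwarded to the model,
# so a GradientFilter with this threshold is applied before every LSTM layer.
model = TdnnLstm(
    num_features=80,    # params.feature_dim in the recipe
    num_classes=72,     # max_phone_id + 1 in the recipe; 72 is a placeholder
    subsampling_factor=3,
    grad_norm_threshold=args.grad_norm_threshold,
)
print(model.grad_filters[0].threshold)  # 5.0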