mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
add gradient filter
This commit is contained in:
parent
2cc6137934
commit
b18850721d
1
egs/librispeech/ASR/tdnn_lstm_ctc2/__init__.py
Symbolic link
1
egs/librispeech/ASR/tdnn_lstm_ctc2/__init__.py
Symbolic link
@ -0,0 +1 @@
|
||||
../tdnn_lstm_ctc/__init__.py
|
@ -15,13 +15,77 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class GradientFilterFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: torch.Tensor,
|
||||
batch_dim: int, # e.g., 1
|
||||
threshold: float, # e.g., 10.0
|
||||
) -> torch.Tensor:
|
||||
if x.requires_grad:
|
||||
if batch_dim < 0:
|
||||
batch_dim += x.ndim
|
||||
ctx.batch_dim = batch_dim
|
||||
ctx.threshold = threshold
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
|
||||
dim = ctx.batch_dim
|
||||
if x_grad.shape[dim] == 1:
|
||||
return x_grad, None, None
|
||||
norm_dims = [d for d in range(x_grad.ndim) if d != dim]
|
||||
norm_of_batch = x_grad.norm(dim=norm_dims, keepdim=True)
|
||||
norm_of_batch_sorted = norm_of_batch.sort(dim=dim)[0]
|
||||
median_idx = (x_grad.shape[dim] - 1) // 2
|
||||
median_norm = norm_of_batch_sorted.narrow(
|
||||
dim=dim, start=median_idx, length=1
|
||||
)
|
||||
mask = norm_of_batch <= ctx.threshold * median_norm
|
||||
return x_grad * mask, None, None
|
||||
|
||||
|
||||
class GradientFilter(torch.nn.Module):
|
||||
"""This is used to filter out elements that have extremely large gradients
|
||||
in batch.
|
||||
|
||||
Args:
|
||||
batch_dim (int):
|
||||
The batch dimension.
|
||||
threshold (float):
|
||||
For each element in batch, its gradient will be
|
||||
filtered out if the gradient norm is larger than
|
||||
`grad_norm_threshold * median`, where `median` is the median
|
||||
value of gradient norms of all elememts in batch.
|
||||
"""
|
||||
|
||||
def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
|
||||
super(GradientFilter, self).__init__()
|
||||
self.batch_dim = batch_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return GradientFilterFunction.apply(
|
||||
x,
|
||||
self.batch_dim,
|
||||
self.threshold,
|
||||
)
|
||||
|
||||
|
||||
class TdnnLstm(nn.Module):
|
||||
def __init__(
|
||||
self, num_features: int, num_classes: int, subsampling_factor: int = 3
|
||||
self,
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
subsampling_factor: int = 3,
|
||||
grad_norm_threshold: float = 10.0,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -31,6 +95,11 @@ class TdnnLstm(nn.Module):
|
||||
The output dimension of the model.
|
||||
subsampling_factor:
|
||||
It reduces the number of output frames by this factor.
|
||||
grad_norm_threshold:
|
||||
For each sequence element in batch, its gradient will be
|
||||
filtered out if the gradient norm is larger than
|
||||
`grad_norm_threshold * median`, where `median` is the median
|
||||
value of gradient norms of all elememts in batch.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
@ -74,6 +143,10 @@ class TdnnLstm(nn.Module):
|
||||
self.lstm_bnorms = nn.ModuleList(
|
||||
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
|
||||
)
|
||||
self.grad_filters = nn.ModuleList(
|
||||
[GradientFilter(batch_dim=1, threshold=grad_norm_threshold)]
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(0.2)
|
||||
self.linear = nn.Linear(in_features=500, out_features=self.num_classes)
|
||||
|
||||
@ -88,8 +161,10 @@ class TdnnLstm(nn.Module):
|
||||
"""
|
||||
x = self.tdnn(x)
|
||||
x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it
|
||||
for lstm, bnorm in zip(self.lstms, self.lstm_bnorms):
|
||||
x_new, _ = lstm(x)
|
||||
for lstm, bnorm, grad_filter in zip(
|
||||
self.lstms, self.lstm_bnorms, self.grad_filters
|
||||
):
|
||||
x_new, _ = lstm(grad_filter(x))
|
||||
x_new = bnorm(x_new.permute(1, 2, 0)).permute(
|
||||
2, 0, 1
|
||||
) # (T, N, C) -> (N, C, T) -> (T, N, C)
|
||||
|
@ -112,6 +112,16 @@ def get_parser():
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--grad-norm-threshold",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help="""For each sequence element in batch, its gradient will be
|
||||
filtered out if the gradient norm is larger than
|
||||
`grad_norm_threshold * median`, where `median` is the median
|
||||
value of gradient norms of all elememts in batch.""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -171,7 +181,7 @@ def get_params() -> AttributeDict:
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("tdnn_lstm_ctc/exp"),
|
||||
"exp_dir": Path("tdnn_lstm_ctc2/exp"),
|
||||
"lang_dir": Path("data/lang_phone"),
|
||||
"lr": 1e-3,
|
||||
"feature_dim": 80,
|
||||
@ -540,6 +550,7 @@ def run(rank, world_size, args):
|
||||
num_features=params.feature_dim,
|
||||
num_classes=max_phone_id + 1, # +1 for the blank symbol
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
grad_norm_threshold=params.grad_norm_threshold,
|
||||
)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
Loading…
x
Reference in New Issue
Block a user