diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc2/__init__.py b/egs/librispeech/ASR/tdnn_lstm_ctc2/__init__.py
new file mode 120000
index 000000000..6ab208c6d
--- /dev/null
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc2/__init__.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/__init__.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py
index 5e04c11b4..f8e5f5e9b 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc2/model.py
@@ -15,13 +15,77 @@
 # limitations under the License.
 
+from typing import Tuple
+
 import torch
 import torch.nn as nn
 
 
+class GradientFilterFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        batch_dim: int,  # e.g., 1
+        threshold: float,  # e.g., 10.0
+    ) -> torch.Tensor:
+        if x.requires_grad:
+            if batch_dim < 0:
+                batch_dim += x.ndim
+            ctx.batch_dim = batch_dim
+            ctx.threshold = threshold
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+        dim = ctx.batch_dim
+        if x_grad.shape[dim] == 1:
+            return x_grad, None, None
+        norm_dims = [d for d in range(x_grad.ndim) if d != dim]
+        norm_of_batch = x_grad.norm(dim=norm_dims, keepdim=True)
+        norm_of_batch_sorted = norm_of_batch.sort(dim=dim)[0]
+        median_idx = (x_grad.shape[dim] - 1) // 2
+        median_norm = norm_of_batch_sorted.narrow(
+            dim=dim, start=median_idx, length=1
+        )
+        mask = norm_of_batch <= ctx.threshold * median_norm
+        return x_grad * mask, None, None
+
+
+class GradientFilter(torch.nn.Module):
+    """This is used to filter out batch elements that have extremely
+    large gradients.
+
+    Args:
+      batch_dim (int):
+        The batch dimension.
+      threshold (float):
+        For each element in the batch, its gradient will be
+        filtered out if the gradient norm is larger than
+        `threshold * median`, where `median` is the median
+        value of gradient norms of all elements in the batch.
+    """
+
+    def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
+        super(GradientFilter, self).__init__()
+        self.batch_dim = batch_dim
+        self.threshold = threshold
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return GradientFilterFunction.apply(
+            x,
+            self.batch_dim,
+            self.threshold,
+        )
+
+
 class TdnnLstm(nn.Module):
     def __init__(
-        self, num_features: int, num_classes: int, subsampling_factor: int = 3
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 3,
+        grad_norm_threshold: float = 10.0,
     ) -> None:
         """
         Args:
@@ -31,6 +95,11 @@ class TdnnLstm(nn.Module):
             The output dimension of the model.
           subsampling_factor:
             It reduces the number of output frames by this factor.
+          grad_norm_threshold:
+            For each sequence element in the batch, its gradient will be
+            filtered out if the gradient norm is larger than
+            `grad_norm_threshold * median`, where `median` is the median
+            value of gradient norms of all elements in the batch.
""" super().__init__() self.num_features = num_features @@ -74,6 +143,10 @@ class TdnnLstm(nn.Module): self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] ) + self.grad_filters = nn.ModuleList( + [GradientFilter(batch_dim=1, threshold=grad_norm_threshold)] + ) + self.dropout = nn.Dropout(0.2) self.linear = nn.Linear(in_features=500, out_features=self.num_classes) @@ -88,8 +161,10 @@ class TdnnLstm(nn.Module): """ x = self.tdnn(x) x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it - for lstm, bnorm in zip(self.lstms, self.lstm_bnorms): - x_new, _ = lstm(x) + for lstm, bnorm, grad_filter in zip( + self.lstms, self.lstm_bnorms, self.grad_filters + ): + x_new, _ = lstm(grad_filter(x)) x_new = bnorm(x_new.permute(1, 2, 0)).permute( 2, 0, 1 ) # (T, N, C) -> (N, C, T) -> (T, N, C) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py index 6b37d5c23..9497ca09e 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc2/train.py @@ -112,6 +112,16 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--grad-norm-threshold", + type=float, + default=10.0, + help="""For each sequence element in batch, its gradient will be + filtered out if the gradient norm is larger than + `grad_norm_threshold * median`, where `median` is the median + value of gradient norms of all elememts in batch.""", + ) + return parser @@ -171,7 +181,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("tdnn_lstm_ctc/exp"), + "exp_dir": Path("tdnn_lstm_ctc2/exp"), "lang_dir": Path("data/lang_phone"), "lr": 1e-3, "feature_dim": 80, @@ -540,6 +550,7 @@ def run(rank, world_size, args): num_features=params.feature_dim, num_classes=max_phone_id + 1, # +1 for the blank symbol subsampling_factor=params.subsampling_factor, + grad_norm_threshold=params.grad_norm_threshold, ) checkpoints = load_checkpoint_if_available(params=params, model=model)