This commit is contained in:
yaozengwei 2022-09-06 10:23:40 +08:00
parent b18850721d
commit 890cd1ab75
4 changed files with 23 additions and 13 deletions

View File

@ -125,7 +125,7 @@ def get_parser():
def get_params() -> AttributeDict: def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("tdnn_lstm_ctc/exp/"), "exp_dir": Path("tdnn_lstm_ctc2/exp/"),
"lang_dir": Path("data/lang_phone"), "lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"), "lm_dir": Path("data/lm"),
"feature_dim": 80, "feature_dim": 80,
@ -136,6 +136,11 @@ def get_params() -> AttributeDict:
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
"env_info": get_env_info(), "env_info": get_env_info(),
"grad_norm_threshold": 10.0,
# For each sequence element in the batch, its gradient will be
# filtered out if its gradient norm is larger than
# `grad_norm_threshold * median`, where `median` is the median
# value of the gradient norms of all elements in the batch.
} }
) )
return params return params
@ -452,6 +457,7 @@ def main():
num_features=params.feature_dim, num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
grad_norm_threshold=params.grad_norm_threshold,
) )
if params.avg == 1: if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

View File

@ -144,7 +144,10 @@ class TdnnLstm(nn.Module):
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
) )
self.grad_filters = nn.ModuleList( self.grad_filters = nn.ModuleList(
[GradientFilter(batch_dim=1, threshold=grad_norm_threshold)] [
GradientFilter(batch_dim=1, threshold=grad_norm_threshold)
for _ in range(5)
]
) )
self.dropout = nn.Dropout(0.2) self.dropout = nn.Dropout(0.2)

View File

@ -124,6 +124,11 @@ def get_params() -> AttributeDict:
"min_active_states": 30, "min_active_states": 30,
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
"grad_norm_threshold": 10.0,
# For each sequence element in the batch, its gradient will be
# filtered out if its gradient norm is larger than
# `grad_norm_threshold * median`, where `median` is the median
# value of the gradient norms of all elements in the batch.
} }
) )
return params return params
@ -172,6 +177,7 @@ def main():
num_features=params.feature_dim, num_features=params.feature_dim,
num_classes=params.num_classes, num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
grad_norm_threshold=params.grad_norm_threshold,
) )
checkpoint = torch.load(args.checkpoint, map_location="cpu") checkpoint = torch.load(args.checkpoint, map_location="cpu")

View File

@ -19,7 +19,7 @@
""" """
Usage: Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./tdnn_lstm_ctc/train.py \ ./tdnn_lstm_ctc2/train.py \
--world-size 4 \ --world-size 4 \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 \ --max-duration 300 \
@ -112,16 +112,6 @@ def get_parser():
help="The seed for random generators intended for reproducibility", help="The seed for random generators intended for reproducibility",
) )
parser.add_argument(
"--grad-norm-threshold",
type=float,
default=10.0,
help="""For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.""",
)
return parser return parser
@ -199,6 +189,11 @@ def get_params() -> AttributeDict:
"reduction": "sum", "reduction": "sum",
"use_double_scores": True, "use_double_scores": True,
"env_info": get_env_info(), "env_info": get_env_info(),
"grad_norm_threshold": 10.0,
# For each sequence element in the batch, its gradient will be
# filtered out if its gradient norm is larger than
# `grad_norm_threshold * median`, where `median` is the median
# value of the gradient norms of all elements in the batch.
} }
) )