This commit is contained in:
yaozengwei 2022-09-06 10:23:40 +08:00
parent b18850721d
commit 890cd1ab75
4 changed files with 23 additions and 13 deletions

View File

@ -125,7 +125,7 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp/"),
"exp_dir": Path("tdnn_lstm_ctc2/exp/"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
@ -136,6 +136,11 @@ def get_params() -> AttributeDict:
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
"grad_norm_threshold": 10.0,
# For each sequence element in batch, its gradient will be
# filtered out if the gradient norm is larger than
# `grad_norm_threshold * median`, where `median` is the median
# value of gradient norms of all elememts in batch.
}
)
return params
@ -452,6 +457,7 @@ def main():
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor,
grad_norm_threshold=params.grad_norm_threshold,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

View File

@ -144,7 +144,10 @@ class TdnnLstm(nn.Module):
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
)
self.grad_filters = nn.ModuleList(
[GradientFilter(batch_dim=1, threshold=grad_norm_threshold)]
[
GradientFilter(batch_dim=1, threshold=grad_norm_threshold)
for _ in range(5)
]
)
self.dropout = nn.Dropout(0.2)

View File

@ -124,6 +124,11 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"grad_norm_threshold": 10.0,
# For each sequence element in batch, its gradient will be
# filtered out if the gradient norm is larger than
# `grad_norm_threshold * median`, where `median` is the median
# value of gradient norms of all elememts in batch.
}
)
return params
@ -172,6 +177,7 @@ def main():
num_features=params.feature_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
grad_norm_threshold=params.grad_norm_threshold,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")

View File

@ -19,7 +19,7 @@
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./tdnn_lstm_ctc/train.py \
./tdnn_lstm_ctc2/train.py \
--world-size 4 \
--full-libri 1 \
--max-duration 300 \
@ -112,16 +112,6 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--grad-norm-threshold",
type=float,
default=10.0,
help="""For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.""",
)
return parser
@ -199,6 +189,11 @@ def get_params() -> AttributeDict:
"reduction": "sum",
"use_double_scores": True,
"env_info": get_env_info(),
"grad_norm_threshold": 10.0,
# For each sequence element in batch, its gradient will be
# filtered out if the gradient norm is larger than
# `grad_norm_threshold * median`, where `median` is the median
# value of gradient norms of all elememts in batch.
}
)