icefall (https://github.com/k2-fsa/icefall.git)
commit 890cd1ab75 (parent: b18850721d)

    fix bugs

@@ -125,7 +125,7 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("tdnn_lstm_ctc/exp/"),
+            "exp_dir": Path("tdnn_lstm_ctc2/exp/"),
             "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,

@@ -136,6 +136,11 @@ def get_params() -> AttributeDict:
             "max_active_states": 10000,
             "use_double_scores": True,
             "env_info": get_env_info(),
+            "grad_norm_threshold": 10.0,
+            # For each sequence element in batch, its gradient will be
+            # filtered out if the gradient norm is larger than
+            # `grad_norm_threshold * median`, where `median` is the median
+            # value of gradient norms of all elements in batch.
         }
     )
     return params
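
Note for readers: the `grad_norm_threshold` entry added above feeds the median-based gradient filtering described in the comment. Below is a minimal sketch of such a filter, assuming it works exactly as that comment states: identity in the forward pass, and in the backward pass a batch element's gradient is zeroed when its norm exceeds `threshold * median`. The `GradientFilter` name matches the class used in this diff, but the sketch is illustrative, not icefall's actual implementation.

# A minimal sketch of median-based gradient filtering, matching the
# comment above. Illustrative only, not icefall's actual implementation.
import torch


class GradientFilterFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, batch_dim, threshold):
        ctx.batch_dim = batch_dim
        ctx.threshold = threshold
        return x  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad):
        # Norm of each batch element's gradient over all non-batch dims.
        dims = [d for d in range(grad.dim()) if d != ctx.batch_dim]
        norms = grad.norm(dim=dims)
        median = norms.median()
        # Zero gradients whose norm exceeds threshold * median.
        mask = (norms <= ctx.threshold * median).to(grad.dtype)
        shape = [1] * grad.dim()
        shape[ctx.batch_dim] = -1
        return grad * mask.view(shape), None, None


class GradientFilter(torch.nn.Module):
    def __init__(self, batch_dim: int = 0, threshold: float = 10.0):
        super().__init__()
        self.batch_dim = batch_dim
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return GradientFilterFunction.apply(x, self.batch_dim, self.threshold)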

@@ -452,6 +457,7 @@ def main():
         num_features=params.feature_dim,
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
+        grad_norm_threshold=params.grad_norm_threshold,
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

@@ -144,7 +144,10 @@ class TdnnLstm(nn.Module):
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
         )
         self.grad_filters = nn.ModuleList(
-            [GradientFilter(batch_dim=1, threshold=grad_norm_threshold)]
+            [
+                GradientFilter(batch_dim=1, threshold=grad_norm_threshold)
+                for _ in range(5)
+            ]
         )

         self.dropout = nn.Dropout(0.2)
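
The fix above allocates one `GradientFilter` per layer instead of a single-element list, so each of the five LSTM/batch-norm stages filters its own gradients (the call site uses `batch_dim=1`, i.e. a (T, N, C) layout). A hypothetical smoke test against the sketch shown earlier:

# Inflate one batch element's gradient far beyond the median and check
# that the filter zeroes it while leaving the others untouched.
x = torch.randn(20, 4, 8, requires_grad=True)  # (T, N, C), batch dim is 1
f = GradientFilter(batch_dim=1, threshold=10.0)
y = f(x)

g = torch.ones_like(y)
g[:, 0, :] *= 1000.0  # element 0's grad norm becomes ~1000x the median
y.backward(g)

print(x.grad[:, 0, :].abs().max())  # tensor(0.)  -> filtered out
print(x.grad[:, 1, :].abs().max())  # tensor(1.)  -> kept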

@@ -124,6 +124,11 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "grad_norm_threshold": 10.0,
+            # For each sequence element in batch, its gradient will be
+            # filtered out if the gradient norm is larger than
+            # `grad_norm_threshold * median`, where `median` is the median
+            # value of gradient norms of all elements in batch.
         }
     )
     return params

@@ -172,6 +177,7 @@ def main():
         num_features=params.feature_dim,
         num_classes=params.num_classes,
         subsampling_factor=params.subsampling_factor,
+        grad_norm_threshold=params.grad_norm_threshold,
     )

     checkpoint = torch.load(args.checkpoint, map_location="cpu")

@@ -19,7 +19,7 @@
 """
 Usage:
     export CUDA_VISIBLE_DEVICES="0,1,2,3"
-    ./tdnn_lstm_ctc/train.py \
+    ./tdnn_lstm_ctc2/train.py \
         --world-size 4 \
         --full-libri 1 \
         --max-duration 300 \

@@ -112,16 +112,6 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )

-    parser.add_argument(
-        "--grad-norm-threshold",
-        type=float,
-        default=10.0,
-        help="""For each sequence element in batch, its gradient will be
-        filtered out if the gradient norm is larger than
-        `grad_norm_threshold * median`, where `median` is the median
-        value of gradient norms of all elements in batch.""",
-    )
-
     return parser

@@ -199,6 +189,11 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "env_info": get_env_info(),
+            "grad_norm_threshold": 10.0,
+            # For each sequence element in batch, its gradient will be
+            # filtered out if the gradient norm is larger than
+            # `grad_norm_threshold * median`, where `median` is the median
+            # value of gradient norms of all elements in batch.
         }
     )