mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
don't do penalize_values_gt on simple_lm_proj and simple_am_proj; reduce --base-lr from 0.1 to 0.075
This commit is contained in:
parent
269b70122e
commit
2964628ae1
@ -142,10 +142,10 @@ class Transducer(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
if self.training and random.random() < 0.25:
|
||||
lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
if self.training and random.random() < 0.25:
|
||||
am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
#if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
#if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
|
||||
@ -232,7 +232,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--base-lr",
|
||||
type=float,
|
||||
default=0.1,
|
||||
default=0.075,
|
||||
help="The base learning rate."
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user