mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
don't do penalize_values_gt on simple_lm_proj and simple_am_proj; reduce --base-lr from 0.1 to 0.075
This commit is contained in:
parent
269b70122e
commit
2964628ae1
@ -142,10 +142,10 @@ class Transducer(nn.Module):
|
|||||||
lm = self.simple_lm_proj(decoder_out)
|
lm = self.simple_lm_proj(decoder_out)
|
||||||
am = self.simple_am_proj(encoder_out)
|
am = self.simple_am_proj(encoder_out)
|
||||||
|
|
||||||
if self.training and random.random() < 0.25:
|
#if self.training and random.random() < 0.25:
|
||||||
lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||||
if self.training and random.random() < 0.25:
|
#if self.training and random.random() < 0.25:
|
||||||
am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
|
|||||||
@ -232,7 +232,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base-lr",
|
"--base-lr",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.1,
|
default=0.075,
|
||||||
help="The base learning rate."
|
help="The base learning rate."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user