mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Pass in dropout from train.py
This commit is contained in:
parent
22617da725
commit
63e881f89b
@ -60,6 +60,7 @@ import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from zipformer import Zipformer
|
||||
from scaling import ScheduledFloat
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
@ -498,7 +499,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
attention_share_layers=to_int_tuple(params.attention_share_layers),
|
||||
feedforward_dim=to_int_tuple(params.feedforward_dim),
|
||||
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
|
||||
dropout=0.1,
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
)
|
||||
return encoder
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user