Pass in dropout from train.py

This commit is contained in:
Daniel Povey 2022-12-05 23:49:16 +08:00
parent 22617da725
commit 63e881f89b

View File

@ -60,6 +60,7 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from zipformer import Zipformer
from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -498,7 +499,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
attention_share_layers=to_int_tuple(params.attention_share_layers),
feedforward_dim=to_int_tuple(params.feedforward_dim),
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
dropout=0.1,
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
)
return encoder