From 63e881f89b7874299cebb798060f385c77140c95 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 5 Dec 2022 23:49:16 +0800 Subject: [PATCH] Pass in dropout from train.py --- egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index b801beccf..3a75a3d5a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -60,6 +60,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from zipformer import Zipformer +from scaling import ScheduledFloat from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -498,7 +499,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: attention_share_layers=to_int_tuple(params.attention_share_layers), feedforward_dim=to_int_tuple(params.feedforward_dim), cnn_module_kernel=to_int_tuple(params.cnn_module_kernel), - dropout=0.1, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, ) return encoder