mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Update conformer.py (#1200)
* Update conformer.py * Update zipformer.py fix bug in get_dynamic_dropout_rate
This commit is contained in:
parent
bbb03f7962
commit
45d60ef262
@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
return final_dropout_rate
|
return final_dropout_rate
|
||||||
else:
|
else:
|
||||||
return initial_dropout_rate - (
|
return initial_dropout_rate - (
|
||||||
initial_dropout_rate * final_dropout_rate
|
initial_dropout_rate - final_dropout_rate
|
||||||
) * (self.batch_count / warmup_period)
|
) * (self.batch_count / warmup_period)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -230,7 +230,7 @@ class Conformer(Transformer):
|
|||||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||||
) # (T, B, F)
|
) # (T, B, F)
|
||||||
else:
|
else:
|
||||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
|
x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F)
|
||||||
|
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.after_norm(x)
|
x = self.after_norm(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user