from local

This commit is contained in:
dohe0342 2023-01-29 19:28:37 +09:00
parent 8f09b5000e
commit eabb297212
2 changed files with 20 additions and 10 deletions

View File

@ -693,17 +693,27 @@ class ConformerEncoder(nn.Module):
output = src
outputs = []
residual = None
for i, mod in enumerate(self.layers):
if random.random() < 0.05:
continue
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
if i in [2,5,8]:
residual = output
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
output += residual
else:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
#if i in self.aux_layers:
# outputs.append(output)