from local

This commit is contained in:
dohe0342 2023-02-14 17:53:48 +09:00
parent f7950e6bdb
commit b6a0ae8b61
3 changed files with 3 additions and 1 deletions

View File

@ -144,6 +144,8 @@ class Conformer(Transformer):
self.interctc_condition = interctc_condition
if self.interctc_condition:
self.condition_layer = ScaledLinear(500, d_model)
else:
self.condition_layer = None
def run_encoder(
self,
@ -179,7 +181,7 @@ class Conformer(Transformer):
mask = mask.to(x.device) if mask is not None else None
x, layer_outputs = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
x, pos_emb, src_key_padding_mask=mask, warmup=warmup,
) # (S, N, C)
if self.group_num != 0: