from local

This commit is contained in:
dohe0342 2023-01-09 19:39:39 +09:00
parent 8df7e1fda4
commit 9e21e1f0e1
2 changed files with 15 additions and 21 deletions

View File

@ -480,27 +480,21 @@ class Tempformer(EncoderInterface):
self.short_chunk_size = short_chunk_size self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks self.num_left_chunks = num_left_chunks
encoder_layer = ConformerEncoderLayer( def build_conformer(d_model, nhead, dim_feedforward, dropout, layer_dropout, cnn_module_kernel, causal):
d_model=d_model, encoder_layer = ConformerEncoderLayer(
nhead=nhead, d_model=d_model,
dim_feedforward=dim_feedforward, nhead=nhead,
dropout=dropout, dim_feedforward=dim_feedforward,
layer_dropout=layer_dropout, dropout=dropout,
cnn_module_kernel=cnn_module_kernel, layer_dropout=layer_dropout,
causal=causal, cnn_module_kernel=cnn_module_kernel,
) causal=causal,
# aux_layers from 1/3 )
self.encoder = ConformerEncoder( return encoder_layer
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
aux_layers=list( self.encoder_layers = nn.ModuleList(
range(
num_encoder_layers // 3,
num_encoder_layers - 1,
aux_layer_period,
)
),
)
self._init_state: List[torch.Tensor] = [torch.empty(0)] self._init_state: List[torch.Tensor] = [torch.empty(0)]
def forward( def forward(