minor fix

This commit is contained in:
yaozengwei 2022-06-11 22:40:40 +08:00
parent e011812bab
commit 64c9053e39

View File

@@ -1646,15 +1646,23 @@ class Emformer(EncoderInterface):
self.right_context_length = right_context_length self.right_context_length = right_context_length
if subsampling_factor != 4: if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.") raise NotImplementedError("Support only 'subsampling_factor=4'.")
if chunk_length % 4 != 0: if chunk_length % subsampling_factor != 0:
raise NotImplementedError("chunk_length must be a mutiple of 4.")
if left_context_length != 0 and left_context_length % 4 != 0:
raise NotImplementedError( raise NotImplementedError(
"left_context_length must be 0 or a mutiple of 4." "chunk_length must be a mutiple of subsampling_factor."
) )
if right_context_length != 0 and right_context_length % 4 != 0: if (
left_context_length != 0
and left_context_length % subsampling_factor != 0
):
raise NotImplementedError( raise NotImplementedError(
"right_context_length must be 0 or a mutiple of 4." "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa
)
if (
right_context_length != 0
and right_context_length % subsampling_factor != 0
):
raise NotImplementedError(
"right_context_length must be 0 or a mutiple of subsampling_factor." # noqa
) )
# self.encoder_embed converts the input of shape (N, T, num_features) # self.encoder_embed converts the input of shape (N, T, num_features)
@@ -1665,7 +1673,7 @@ class Emformer(EncoderInterface):
self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder = EmformerEncoder( self.encoder = EmformerEncoder(
chunk_length=chunk_length // 4, chunk_length=chunk_length // subsampling_factor,
d_model=d_model, d_model=d_model,
nhead=nhead, nhead=nhead,
dim_feedforward=dim_feedforward, dim_feedforward=dim_feedforward,
@@ -1673,8 +1681,8 @@ class Emformer(EncoderInterface):
dropout=dropout, dropout=dropout,
layer_dropout=layer_dropout, layer_dropout=layer_dropout,
cnn_module_kernel=cnn_module_kernel, cnn_module_kernel=cnn_module_kernel,
left_context_length=left_context_length // 4, left_context_length=left_context_length // subsampling_factor,
right_context_length=right_context_length // 4, right_context_length=right_context_length // subsampling_factor,
memory_size=memory_size, memory_size=memory_size,
tanh_on_mem=tanh_on_mem, tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf, negative_inf=negative_inf,