Mirror of https://github.com/k2-fsa/icefall.git
minor fix

commit 64c9053e39
parent e011812bab
@@ -1646,15 +1646,23 @@ class Emformer(EncoderInterface):
         self.right_context_length = right_context_length
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
-        if chunk_length % 4 != 0:
-            raise NotImplementedError("chunk_length must be a mutiple of 4.")
-        if left_context_length != 0 and left_context_length % 4 != 0:
+        if chunk_length % subsampling_factor != 0:
             raise NotImplementedError(
-                "left_context_length must be 0 or a mutiple of 4."
+                "chunk_length must be a mutiple of subsampling_factor."
             )
-        if right_context_length != 0 and right_context_length % 4 != 0:
+        if (
+            left_context_length != 0
+            and left_context_length % subsampling_factor != 0
+        ):
             raise NotImplementedError(
-                "right_context_length must be 0 or a mutiple of 4."
+                "left_context_length must be 0 or a mutiple of subsampling_factor."  # noqa
+            )
+        if (
+            right_context_length != 0
+            and right_context_length % subsampling_factor != 0
+        ):
+            raise NotImplementedError(
+                "right_context_length must be 0 or a mutiple of subsampling_factor."  # noqa
             )
 
         # self.encoder_embed converts the input of shape (N, T, num_features)
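For readers skimming the hunk above: after this change the three length checks compare against subsampling_factor instead of a hard-coded 4 (subsampling_factor itself is still required to be 4 by the check just before them). A minimal standalone sketch of the same validation logic, with example values that are assumptions chosen only for illustration:

    # Sketch of the validation done in Emformer.__init__ after this change.
    # The concrete values below are assumed for illustration.
    subsampling_factor = 4        # the only value currently supported
    chunk_length = 32             # 32 % 4 == 0, so this passes
    left_context_length = 64      # 0 would also be accepted
    right_context_length = 8

    if chunk_length % subsampling_factor != 0:
        raise NotImplementedError(
            "chunk_length must be a multiple of subsampling_factor."
        )
    if left_context_length != 0 and left_context_length % subsampling_factor != 0:
        raise NotImplementedError(
            "left_context_length must be 0 or a multiple of subsampling_factor."
        )
    if right_context_length != 0 and right_context_length % subsampling_factor != 0:
        raise NotImplementedError(
            "right_context_length must be 0 or a multiple of subsampling_factor."
        )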
@@ -1665,7 +1673,7 @@ class Emformer(EncoderInterface):
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
         self.encoder = EmformerEncoder(
-            chunk_length=chunk_length // 4,
+            chunk_length=chunk_length // subsampling_factor,
             d_model=d_model,
             nhead=nhead,
             dim_feedforward=dim_feedforward,
@@ -1673,8 +1681,8 @@ class Emformer(EncoderInterface):
             dropout=dropout,
             layer_dropout=layer_dropout,
             cnn_module_kernel=cnn_module_kernel,
-            left_context_length=left_context_length // 4,
-            right_context_length=right_context_length // 4,
+            left_context_length=left_context_length // subsampling_factor,
+            right_context_length=right_context_length // subsampling_factor,
             memory_size=memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
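The divisions in the last two hunks convert the chunk and context lengths from input frames to frames at the subsampled rate that Conv2dSubsampling produces before EmformerEncoder runs. A rough arithmetic sketch, using assumed input values:

    # Sketch (assumed values) of the unit conversion behind the
    # `// subsampling_factor` divisions: the encoder operates on the
    # frame rate after Conv2dSubsampling, which is subsampling_factor
    # times lower than the input frame rate.
    subsampling_factor = 4
    chunk_length = 32             # lengths given in input frames
    left_context_length = 64
    right_context_length = 8

    print(chunk_length // subsampling_factor)          # 8 frames at the encoder
    print(left_context_length // subsampling_factor)   # 16 frames at the encoder
    print(right_context_length // subsampling_factor)  # 2 frames at the encoder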