From 64c9053e3960ffd5165f599ddeaab65a11158c3c Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 11 Jun 2022 22:40:40 +0800 Subject: [PATCH] minor fix --- .../emformer.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 7d7def879..39fc04186 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -1646,15 +1646,23 @@ class Emformer(EncoderInterface): self.right_context_length = right_context_length if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - if chunk_length % 4 != 0: - raise NotImplementedError("chunk_length must be a mutiple of 4.") - if left_context_length != 0 and left_context_length % 4 != 0: + if chunk_length % subsampling_factor != 0: raise NotImplementedError( - "left_context_length must be 0 or a mutiple of 4." + "chunk_length must be a mutiple of subsampling_factor." ) - if right_context_length != 0 and right_context_length % 4 != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( - "right_context_length must be 0 or a mutiple of 4." + "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa + ) + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): + raise NotImplementedError( + "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -1665,7 +1673,7 @@ class Emformer(EncoderInterface): self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder = EmformerEncoder( - chunk_length=chunk_length // 4, + chunk_length=chunk_length // subsampling_factor, d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, @@ -1673,8 +1681,8 @@ class Emformer(EncoderInterface): dropout=dropout, layer_dropout=layer_dropout, cnn_module_kernel=cnn_module_kernel, - left_context_length=left_context_length // 4, - right_context_length=right_context_length // 4, + left_context_length=left_context_length // subsampling_factor, + right_context_length=right_context_length // subsampling_factor, memory_size=memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf,