diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 7252ee436..5d198f806 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -452,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, - model_avg: nn.Module = None, + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py index 296d0979f..ec00b0a7a 100644 --- a/egs/librispeech/ASR/transducer_lstm/encoder.py +++ b/egs/librispeech/ASR/transducer_lstm/encoder.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import Tuple import torch @@ -79,9 +78,8 @@ class LstmEncoder(EncoderInterface): x = self.encoder_embed(x) # Caution: We assume the subsampling factor is 4! - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = ((x_lens - 1) // 2 - 1) // 2 + + lengths = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(1) == lengths.max().item(), ( x.size(1), lengths.max(), diff --git a/egs/librispeech/ASR/transducer_lstm/test_encoder.py b/egs/librispeech/ASR/transducer_lstm/test_encoder.py index 2689011a3..e0e2b2747 100755 --- a/egs/librispeech/ASR/transducer_lstm/test_encoder.py +++ b/egs/librispeech/ASR/transducer_lstm/test_encoder.py @@ -20,11 +20,9 @@ To run this file, do: cd icefall/egs/librispeech/ASR - python ./transducer_lstm/test_model.py + python ./transducer_lstm/test_encoder.py """ -import warnings - import torch from train import get_encoder_model, get_params @@ -47,9 +45,7 @@ def test_encoder_model(): y, y_lens = encoder(x, x_lens) print(y.shape) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - expected_y_lens = ((x_lens - 1) // 2 - 1) // 2 + expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert torch.all(torch.eq(y_lens, expected_y_lens)), ( y_lens,