From d11e01e190d70a74ca48540c1db3aef34764323a Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 10 May 2022 22:06:34 +0800
Subject: [PATCH] Minor fixes.

---
 .../ASR/transducer_lstm/encoder.py           | 34 +++-------
 .../ASR/transducer_lstm/test_encoder.py      | 65 +++++++++++++++++++
 .../ASR/transducer_lstm/test_model.py        |  2 +-
 egs/librispeech/ASR/transducer_lstm/train.py |  8 +--
 4 files changed, 80 insertions(+), 29 deletions(-)
 create mode 100755 egs/librispeech/ASR/transducer_lstm/test_encoder.py

diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 50c31275c..7610019b2 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -30,13 +30,11 @@ class LstmEncoder(EncoderInterface):
         hidden_size: int,
         output_dim: int,
         subsampling_factor: int = 4,
-        num_encoder_layers: int = 12,
+        num_encoder_layers: int = 6,
         dropout: float = 0.1,
         vgg_frontend: bool = False,
-        proj_size: int = 0,
     ):
         super().__init__()
-        real_hidden_size = proj_size if proj_size > 0 else hidden_size
         assert (
             subsampling_factor == 4
         ), "Only subsampling_factor==4 is supported at present"
@@ -47,28 +45,21 @@ class LstmEncoder(EncoderInterface):
         # (1) subsampling: T -> T//subsampling_factor
         # (2) embedding: num_features -> d_model
         if vgg_frontend:
-            self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
+            self.encoder_embed = VggSubsampling(num_features, output_dim)
         else:
-            self.encoder_embed = Conv2dSubsampling(
-                num_features, real_hidden_size
-            )
+            self.encoder_embed = Conv2dSubsampling(num_features, output_dim)
 
         self.rnn = nn.LSTM(
-            input_size=hidden_size,
+            input_size=output_dim,
             hidden_size=hidden_size,
             num_layers=num_encoder_layers,
             bias=True,
-            proj_size=proj_size,
+            proj_size=output_dim,
             batch_first=True,
             dropout=dropout,
             bidirectional=False,
         )
 
-        self.encoder_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout),
-            nn.Linear(real_hidden_size, output_dim),
-        )
-
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -96,23 +87,18 @@ class LstmEncoder(EncoderInterface):
             lengths.max(),
         )
 
-        if False:
-            # It is commented out as DPP complains that not all parameters are
-            # used. Need more checks later for the reason.
-            #
-            # Caution: We assume the dataloader returns utterances with
-            # duration being sorted in decreasing order
+        if True:
+            # This branch is more efficient than the else branch
            packed_x = pack_padded_sequence(
                input=x,
                lengths=lengths.cpu(),
                batch_first=True,
-                enforce_sorted=True,
+                enforce_sorted=False,
            )
            packed_rnn_out, _ = self.rnn(packed_x)
-            rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
+            rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
         else:
             rnn_out, _ = self.rnn(x)
 
-        logits = self.encoder_output_layer(rnn_out)
 
-        return logits, lengths
+        return rnn_out, lengths
diff --git a/egs/librispeech/ASR/transducer_lstm/test_encoder.py b/egs/librispeech/ASR/transducer_lstm/test_encoder.py
new file mode 100755
index 000000000..2689011a3
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/test_encoder.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer_lstm/test_encoder.py
+"""
+
+import warnings
+
+import torch
+from train import get_encoder_model, get_params
+
+
+def test_encoder_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    encoder = get_encoder_model(params)
+    num_param = sum([p.numel() for p in encoder.parameters()])
+    print(f"Number of encoder model parameters: {num_param}")
+
+    N = 3
+    T = 500
+    C = 80
+
+    x = torch.rand(N, T, C)
+    x_lens = torch.tensor([100, 500, 300])
+
+    y, y_lens = encoder(x, x_lens)
+    print(y.shape)
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        expected_y_lens = ((x_lens - 1) // 2 - 1) // 2
+
+    assert torch.all(torch.eq(y_lens, expected_y_lens)), (
+        y_lens,
+        expected_y_lens,
+    )
+
+
+def main():
+    test_encoder_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/transducer_lstm/test_model.py b/egs/librispeech/ASR/transducer_lstm/test_model.py
index acd71455d..071671f27 100755
--- a/egs/librispeech/ASR/transducer_lstm/test_model.py
+++ b/egs/librispeech/ASR/transducer_lstm/test_model.py
@@ -20,7 +20,7 @@
 To run this file, do:
 
     cd icefall/egs/librispeech/ASR
-    python ./pruned_transducer_stateless4/test_model.py
+    python ./transducer_lstm/test_model.py
 """
 
 from train import get_params, get_transducer_model
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 2d520f230..269f631f0 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -42,7 +42,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 """
 
-
 import argparse
 import logging
 import warnings
@@ -339,9 +338,9 @@
             "feature_dim": 80,
             "subsampling_factor": 4,
             "encoder_dim": 512,
-            "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
-            "proj_size": 512,
+            "encoder_hidden_size": 2048,
+            "num_encoder_layers": 6,
+            "dropout": 0.1,
             "vgg_frontend": False,
             # parameters for decoder
             "decoder_dim": 512,
@@ -363,6 +362,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         output_dim=params.encoder_dim,
         subsampling_factor=params.subsampling_factor,
         num_encoder_layers=params.num_encoder_layers,
+        dropout=params.dropout,
         vgg_frontend=params.vgg_frontend,
     )
     return encoder
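
Note (not part of the patch): the change relies on two PyTorch behaviors. First, an nn.LSTM built with proj_size > 0 projects its hidden state, so each output frame has dimension proj_size (here output_dim) rather than hidden_size, which is why the separate encoder_output_layer could be dropped. Second, pack_padded_sequence with enforce_sorted=False accepts batches that are not sorted by decreasing duration, matching the switch from enforce_sorted=True. Below is a minimal standalone sketch of both behaviors; it assumes only that torch is installed, and the sizes mirror encoder_hidden_size=2048 and encoder_dim=512 from train.py.

    #!/usr/bin/env python3
    # Standalone sketch, not part of the patch: illustrates the LSTM projection
    # and the unsorted packing that the modified encoder.py depends on.

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    hidden_size = 2048  # "encoder_hidden_size" in train.py
    output_dim = 512    # "encoder_dim" in train.py

    # proj_size > 0 adds a projection inside the LSTM, so each output frame has
    # dimension proj_size (= output_dim here) instead of hidden_size.
    rnn = nn.LSTM(
        input_size=output_dim,
        hidden_size=hidden_size,
        num_layers=6,
        bias=True,
        proj_size=output_dim,
        batch_first=True,
        dropout=0.1,
        bidirectional=False,
    )

    x = torch.rand(3, 125, output_dim)     # (N, T', C) after 4x subsampling
    lengths = torch.tensor([25, 125, 75])  # deliberately not sorted

    # enforce_sorted=False means the batch need not be sorted by decreasing
    # duration before packing.
    packed_x = pack_padded_sequence(
        input=x, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False
    )
    packed_out, _ = rnn(packed_x)
    out, out_lens = pad_packed_sequence(packed_out, batch_first=True)

    print(out.shape)  # torch.Size([3, 125, 512]); last dim is proj_size
    print(out_lens)   # tensor([ 25, 125,  75]); original order is preserved

The printed shape shows why the model output dimension now equals output_dim directly, and the restored length order shows why the old "sorted in decreasing order" assumption is no longer needed.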