Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-05-23 18:18:04 +08:00
parent b7676ca1f2
commit 3e2dbc9ab5
3 changed files with 5 additions and 11 deletions

View File

@ -452,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
model_avg: nn.Module = None,
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:

View File

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Tuple
import torch
@ -79,9 +78,8 @@ class LstmEncoder(EncoderInterface):
x = self.encoder_embed(x)
# Caution: We assume the subsampling factor is 4!
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lengths = ((x_lens - 1) // 2 - 1) // 2
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(1) == lengths.max().item(), (
x.size(1),
lengths.max(),

View File

@ -20,11 +20,9 @@
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_lstm/test_model.py
python ./transducer_lstm/test_encoder.py
"""
import warnings
import torch
from train import get_encoder_model, get_params
@ -47,9 +45,7 @@ def test_encoder_model():
y, y_lens = encoder(x, x_lens)
print(y.shape)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
expected_y_lens = ((x_lens - 1) // 2 - 1) // 2
expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1
assert torch.all(torch.eq(y_lens, expected_y_lens)), (
y_lens,