From 6becd80e9509d773f1815766c04b0b5c89481ee3 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 30 Sep 2022 16:21:56 +0800
Subject: [PATCH] small fixes

---
 .../asr_datamodule.py                          |  2 +-
 .../ASR/pruned_transducer_stateless2/decode.py |  2 +-
 .../ASR/pruned_transducer_stateless2/train.py  |  2 +-
 icefall/utils.py                               | 15 ++++++++++-----
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 10c953e3b..7f277a9fb 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -303,7 +303,7 @@ class WenetSpeechAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                buffer_size=30000,
+                buffer_size=300000,
                 drop_last=True,
             )
         else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index bbd8680b2..370bfdfbd 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -207,7 +207,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 5208dbefe..cff9e7d21 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -162,7 +162,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
         If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
         """,
     )
 
diff --git a/icefall/utils.py b/icefall/utils.py
index ad079222e..a044f618a 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -26,7 +26,7 @@ from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Tuple, Union
+from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
 
 import k2
 import k2.version
@@ -949,7 +949,7 @@ def tokenize_by_bpe_model(
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
+    sp: Optional[spm.SentencePieceProcessor] = None,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
 
@@ -960,7 +960,7 @@
       params:
         Parameters for training. See :func:`get_params`.
       sp:
-        The BPE model.
+        Optional. The BPE model.
     """
     from lhotse.utils import uuid4
 
@@ -972,7 +972,12 @@
 
     features = batch["inputs"]
 
     logging.info(f"features shape: {features.shape}")
-    y = sp.encode(supervisions["text"], out_type=int)
-    num_tokens = sum(len(i) for i in y)
+    text = supervisions["text"]
+
+    if sp is not None:
+        y = sp.encode(text, out_type=int)
+        num_tokens = sum(len(i) for i in y)
+    else:
+        num_tokens = sum(len(i) for i in text)
     logging.info(f"num tokens: {num_tokens}")
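
Note on the asr_datamodule.py hunk: the larger buffer_size is passed through to
lhotse's DynamicBucketingSampler, which keeps that many cuts in memory while it
groups utterances of similar duration into buckets, so a bigger buffer gives
better shuffling and bucket balance for a large corpus like WenetSpeech at the
cost of extra memory. A minimal sketch of the resulting sampler configuration,
assuming lhotse's DynamicBucketingSampler API; the cuts path and the other
values are illustrative, not taken from this patch:

    from lhotse import CutSet
    from lhotse.dataset import DynamicBucketingSampler

    # Hypothetical manifest path; the recipe loads its own training cuts.
    cuts_train = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")

    train_sampler = DynamicBucketingSampler(
        cuts_train,
        max_duration=200.0,   # seconds of audio per batch
        shuffle=True,
        num_buckets=30,
        buffer_size=300000,   # the value this patch raises from 30000
        drop_last=True,
    )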
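
Note on the icefall/utils.py hunks: display_and_save_batch can now be called
without a BPE model, in which case the token count falls back to the character
count of each supervision text (useful for recipes, like WenetSpeech, that do
not use BPE). A short usage sketch; the toy batch merely stands in for what
lhotse's K2SpeechRecognitionDataset yields, and the exp_dir value and the BPE
model path are assumptions, not part of this patch:

    import sentencepiece as spm
    import torch

    from icefall.utils import AttributeDict, display_and_save_batch

    # Toy batch with just the fields display_and_save_batch reads.
    batch = {
        "inputs": torch.randn(2, 100, 80),  # (N, T, C) fbank features
        "supervisions": {"text": ["甚至 出现 交易 几乎 停滞", "HELLO WORLD"]},
    }
    params = AttributeDict({"exp_dir": "pruned_transducer_stateless2/exp"})

    # With a BPE model: tokens are counted via sp.encode(text, out_type=int).
    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # hypothetical model path
    display_and_save_batch(batch, params=params, sp=sp)

    # Without one (sp now defaults to None): falls back to len() of each text.
    display_and_save_batch(batch, params=params)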