small fixes

This commit is contained in:
Fangjun Kuang 2022-09-30 16:21:56 +08:00
parent 809d92274e
commit 6becd80e95
4 changed files with 13 additions and 8 deletions

View File

@@ -303,7 +303,7 @@ class WenetSpeechAsrDataModule:
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
buffer_size=30000, buffer_size=300000,
drop_last=True, drop_last=True,
) )
else: else:

View File

@@ -207,7 +207,7 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""An interger indicating how many candidates we will keep for each help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or frame. Used only when --decoding-method is beam_search or
modified_beam_search.""", modified_beam_search.""",
) )

View File

@@ -162,7 +162,7 @@ def get_parser():
default=0, default=0,
help="""Resume training from from this epoch. help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from If it is positive, it will load checkpoint from
transducer_stateless2/exp/epoch-{start_epoch-1}.pt pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
""", """,
) )

View File

@@ -26,7 +26,7 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
import k2 import k2
import k2.version import k2.version
@@ -949,7 +949,7 @@ def tokenize_by_bpe_model(
def display_and_save_batch( def display_and_save_batch(
batch: dict, batch: dict,
params: AttributeDict, params: AttributeDict,
sp: spm.SentencePieceProcessor, sp: Optional[spm.SentencePieceProcessor] = None,
) -> None: ) -> None:
"""Display the batch statistics and save the batch into disk. """Display the batch statistics and save the batch into disk.
@@ -960,7 +960,7 @@ def display_and_save_batch(
params: params:
Parameters for training. See :func:`get_params`. Parameters for training. See :func:`get_params`.
sp: sp:
The BPE model. Optional. The BPE model.
""" """
from lhotse.utils import uuid4 from lhotse.utils import uuid4
@@ -972,7 +972,12 @@ def display_and_save_batch(
features = batch["inputs"] features = batch["inputs"]
logging.info(f"features shape: {features.shape}") logging.info(f"features shape: {features.shape}")
text = supervisions["text"]
y = sp.encode(supervisions["text"], out_type=int) if sp is not None:
y = sp.encode(text, out_type=int)
num_tokens = sum(len(i) for i in y) num_tokens = sum(len(i) for i in y)
else:
num_tokens = sum(len(i) for i in text)
logging.info(f"num tokens: {num_tokens}") logging.info(f"num tokens: {num_tokens}")