mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
small fixes
This commit is contained in:
parent
809d92274e
commit
6becd80e95
@ -303,7 +303,7 @@ class WenetSpeechAsrDataModule:
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=30000,
|
||||
buffer_size=300000,
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@ -207,7 +207,7 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
@ -162,7 +162,7 @@ def get_parser():
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||
pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, TextIO, Tuple, Union
|
||||
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||
|
||||
import k2
|
||||
import k2.version
|
||||
@ -949,7 +949,7 @@ def tokenize_by_bpe_model(
|
||||
def display_and_save_batch(
|
||||
batch: dict,
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
sp: Optional[spm.SentencePieceProcessor] = None,
|
||||
) -> None:
|
||||
"""Display the batch statistics and save the batch into disk.
|
||||
|
||||
@ -960,7 +960,7 @@ def display_and_save_batch(
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
sp:
|
||||
The BPE model.
|
||||
Optional. The BPE model.
|
||||
"""
|
||||
from lhotse.utils import uuid4
|
||||
|
||||
@ -972,7 +972,12 @@ def display_and_save_batch(
|
||||
features = batch["inputs"]
|
||||
|
||||
logging.info(f"features shape: {features.shape}")
|
||||
text = supervisions["text"]
|
||||
|
||||
y = sp.encode(supervisions["text"], out_type=int)
|
||||
if sp is not None:
|
||||
y = sp.encode(text, out_type=int)
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
else:
|
||||
num_tokens = sum(len(i) for i in text)
|
||||
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user