mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
small fixes
This commit is contained in:
parent
809d92274e
commit
6becd80e95
@ -303,7 +303,7 @@ class WenetSpeechAsrDataModule:
|
|||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
buffer_size=30000,
|
buffer_size=300000,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -207,7 +207,7 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""An interger indicating how many candidates we will keep for each
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
frame. Used only when --decoding-method is beam_search or
|
frame. Used only when --decoding-method is beam_search or
|
||||||
modified_beam_search.""",
|
modified_beam_search.""",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -162,7 +162,7 @@ def get_parser():
|
|||||||
default=0,
|
default=0,
|
||||||
help="""Resume training from this epoch.
|
help="""Resume training from this epoch.
|
||||||
If it is positive, it will load checkpoint from
|
If it is positive, it will load checkpoint from
|
||||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from collections import defaultdict
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Iterable, List, TextIO, Tuple, Union
|
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import k2.version
|
import k2.version
|
||||||
@ -949,7 +949,7 @@ def tokenize_by_bpe_model(
|
|||||||
def display_and_save_batch(
|
def display_and_save_batch(
|
||||||
batch: dict,
|
batch: dict,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: Optional[spm.SentencePieceProcessor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Display the batch statistics and save the batch into disk.
|
"""Display the batch statistics and save the batch into disk.
|
||||||
|
|
||||||
@ -960,7 +960,7 @@ def display_and_save_batch(
|
|||||||
params:
|
params:
|
||||||
Parameters for training. See :func:`get_params`.
|
Parameters for training. See :func:`get_params`.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
Optional. The BPE model.
|
||||||
"""
|
"""
|
||||||
from lhotse.utils import uuid4
|
from lhotse.utils import uuid4
|
||||||
|
|
||||||
@ -972,7 +972,12 @@ def display_and_save_batch(
|
|||||||
features = batch["inputs"]
|
features = batch["inputs"]
|
||||||
|
|
||||||
logging.info(f"features shape: {features.shape}")
|
logging.info(f"features shape: {features.shape}")
|
||||||
|
text = supervisions["text"]
|
||||||
|
|
||||||
y = sp.encode(supervisions["text"], out_type=int)
|
if sp is not None:
|
||||||
|
y = sp.encode(text, out_type=int)
|
||||||
num_tokens = sum(len(i) for i in y)
|
num_tokens = sum(len(i) for i in y)
|
||||||
|
else:
|
||||||
|
num_tokens = sum(len(i) for i in text)
|
||||||
|
|
||||||
logging.info(f"num tokens: {num_tokens}")
|
logging.info(f"num tokens: {num_tokens}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user