mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes. (#554)
* Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes. * minor fixes
This commit is contained in:
parent
235eb0746f
commit
d68b8e9120
@ -23,6 +23,8 @@ from pathlib import Path
|
|||||||
from lhotse import CutSet, SupervisionSegment
|
from lhotse import CutSet, SupervisionSegment
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
|
from icefall import setup_logger
|
||||||
|
|
||||||
# Similar text filtering and normalization procedure as in:
|
# Similar text filtering and normalization procedure as in:
|
||||||
# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
|
# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
|
||||||
|
|
||||||
@ -48,13 +50,17 @@ def preprocess_wenet_speech():
|
|||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
output_dir.mkdir(exist_ok=True)
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Note: By default, we preprocess all sub-parts.
|
||||||
|
# You can delete those that you don't need.
|
||||||
|
# For instance, if you don't want to use the L subpart, just remove
|
||||||
|
# the line below containing "L"
|
||||||
dataset_parts = (
|
dataset_parts = (
|
||||||
"L",
|
|
||||||
"M",
|
|
||||||
"S",
|
|
||||||
"DEV",
|
"DEV",
|
||||||
"TEST_NET",
|
"TEST_NET",
|
||||||
"TEST_MEETING",
|
"TEST_MEETING",
|
||||||
|
"S",
|
||||||
|
"M",
|
||||||
|
"L",
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Loading manifest (may take 10 minutes)")
|
logging.info("Loading manifest (may take 10 minutes)")
|
||||||
@ -81,10 +87,13 @@ def preprocess_wenet_speech():
|
|||||||
logging.info(f"Normalizing text in {partition}")
|
logging.info(f"Normalizing text in {partition}")
|
||||||
for sup in m["supervisions"]:
|
for sup in m["supervisions"]:
|
||||||
text = str(sup.text)
|
text = str(sup.text)
|
||||||
logging.info(f"Original text: {text}")
|
orig_text = text
|
||||||
sup.text = normalize_text(sup.text)
|
sup.text = normalize_text(sup.text)
|
||||||
text = str(sup.text)
|
text = str(sup.text)
|
||||||
logging.info(f"Normalize text: {text}")
|
if len(orig_text) != len(text):
|
||||||
|
logging.info(
|
||||||
|
f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create long-recording cut manifests.
|
# Create long-recording cut manifests.
|
||||||
logging.info(f"Processing {partition}")
|
logging.info(f"Processing {partition}")
|
||||||
@ -109,12 +118,10 @@ def preprocess_wenet_speech():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
formatter = (
|
setup_logger(log_filename="./log-preprocess-wenetspeech")
|
||||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
)
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
|
|
||||||
preprocess_wenet_speech()
|
preprocess_wenet_speech()
|
||||||
|
logging.info("Done")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -81,7 +81,6 @@ For training with the S subset:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -120,8 +119,6 @@ LRSchedulerType = Union[
|
|||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||||
]
|
]
|
||||||
|
|
||||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -162,7 +159,7 @@ def get_parser():
|
|||||||
default=0,
|
default=0,
|
||||||
help="""Resume training from from this epoch.
|
help="""Resume training from from this epoch.
|
||||||
If it is positive, it will load checkpoint from
|
If it is positive, it will load checkpoint from
|
||||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -361,8 +358,8 @@ def get_params() -> AttributeDict:
|
|||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
"best_valid_epoch": -1,
|
"best_valid_epoch": -1,
|
||||||
"batch_idx_train": 10,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 1,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
@ -545,7 +542,7 @@ def compute_loss(
|
|||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
) -> Tuple[Tensor, MetricsTracker]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute RNN-T loss given the model and its inputs.
|
||||||
Args:
|
Args:
|
||||||
params:
|
params:
|
||||||
Parameters for training. See :func:`get_params`.
|
Parameters for training. See :func:`get_params`.
|
||||||
@ -573,7 +570,7 @@ def compute_loss(
|
|||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
|
||||||
y = graph_compiler.texts_to_ids(texts)
|
y = graph_compiler.texts_to_ids(texts)
|
||||||
if type(y) == list:
|
if isinstance(y, list):
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
else:
|
else:
|
||||||
y = y.to(device)
|
y = y.to(device)
|
||||||
@ -697,7 +694,6 @@ def train_one_epoch(
|
|||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
|
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
|
@ -61,7 +61,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -103,8 +102,6 @@ LRSchedulerType = Union[
|
|||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||||
]
|
]
|
||||||
|
|
||||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
|
||||||
|
|
||||||
|
|
||||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -684,7 +681,7 @@ def compute_loss(
|
|||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
|
||||||
y = graph_compiler.texts_to_ids(texts)
|
y = graph_compiler.texts_to_ids(texts)
|
||||||
if type(y) == list:
|
if isinstance(y, list):
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
else:
|
else:
|
||||||
y = y.to(device)
|
y = y.to(device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user