mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
support local_rank for multi-node
This commit is contained in:
parent
0e8c1db4d0
commit
e52581e69b
@ -48,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
|||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from speech_dataset import K2SpeechRecognitionDataset
|
from speech_dataset import K2SpeechRecognitionDataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from utils import get_rank, str2bool
|
from utils import get_local_rank, str2bool
|
||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
@ -271,7 +271,7 @@ class AsrDataModule:
|
|||||||
logging.info("Disable SpecAugment")
|
logging.info("Disable SpecAugment")
|
||||||
|
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
rank = get_rank()
|
rank = get_local_rank()
|
||||||
|
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(
|
input_strategy=OnTheFlyFeatures(
|
||||||
@ -331,7 +331,7 @@ class AsrDataModule:
|
|||||||
CutSet for validation.
|
CutSet for validation.
|
||||||
"""
|
"""
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
rank = get_rank()
|
rank = get_local_rank()
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(
|
input_strategy=OnTheFlyFeatures(
|
||||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
|
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
|
||||||
|
@ -75,6 +75,7 @@ from utils import ( # filter_uneven_sized_batch,
|
|||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
get_rank,
|
get_rank,
|
||||||
|
get_local_rank,
|
||||||
get_world_size,
|
get_world_size,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
@ -274,7 +275,7 @@ def get_params() -> AttributeDict:
|
|||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 50,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 3000,
|
"valid_interval": 1000,
|
||||||
# "env_info": get_env_info(),
|
# "env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -844,7 +845,7 @@ def run(rank, world_size, args):
|
|||||||
logging.info(f"{name}: {param.shape}")
|
logging.info(f"{name}: {param.shape}")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", get_local_rank())
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
@ -867,10 +868,10 @@ def run(rank, world_size, args):
|
|||||||
# You should use ../local/display_manifest_statistics.py to get
|
# You should use ../local/display_manifest_statistics.py to get
|
||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
if c.duration < 1.0 or c.duration > 25.0:
|
if c.duration < 0.8 or c.duration > 20.0:
|
||||||
logging.warning(
|
# logging.warning(
|
||||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||||
)
|
# )
|
||||||
return False
|
return False
|
||||||
if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
|
if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
|
||||||
codec_len = (
|
codec_len = (
|
||||||
|
@ -38,6 +38,14 @@ def get_rank():
|
|||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def get_local_rank():
    """Return the rank of this process on its own node (the local rank).

    The value is meant to be used as a per-node device index, e.g.
    ``torch.device("cuda", get_local_rank())``.

    Resolution order:
      1. The ``LOCAL_RANK`` environment variable, when it is set
         (launchers such as ``torchrun`` export it for every worker).
      2. If ``torch.distributed`` is initialized but ``LOCAL_RANK`` is not
         set, the global rank modulo the number of visible CUDA devices.
      3. ``0`` when not running in a distributed setting.

    Returns:
      The local rank as an ``int``.

    Raises:
      ValueError: if ``LOCAL_RANK`` is set but is not a valid integer.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        return int(local_rank)

    if dist.is_available() and dist.is_initialized():
        # torch.distributed has no get_local_rank(); calling it raised
        # AttributeError.  Derive the local rank from the global rank instead.
        # NOTE(review): the modulo assumes every node exposes the same number
        # of GPUs and that ranks are assigned node by node -- confirm for
        # launchers that do not export LOCAL_RANK.
        import torch

        global_rank = dist.get_rank()
        num_devices = torch.cuda.device_count()
        if num_devices > 0:
            return global_rank % num_devices
        return global_rank

    return 0
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
"""Used in argparse.ArgumentParser.add_argument to indicate
|
"""Used in argparse.ArgumentParser.add_argument to indicate
|
||||||
that a type is a bool type and user can enter
|
that a type is a bool type and user can enter
|
||||||
|
Loading…
x
Reference in New Issue
Block a user