support local_rank for multi-node

This commit is contained in:
root 2025-05-16 00:02:12 -07:00
parent 0e8c1db4d0
commit e52581e69b
3 changed files with 18 additions and 9 deletions

View File

@ -48,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from speech_dataset import K2SpeechRecognitionDataset from speech_dataset import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from utils import get_rank, str2bool from utils import get_local_rank, str2bool
class _SeedWorkers: class _SeedWorkers:
@ -271,7 +271,7 @@ class AsrDataModule:
logging.info("Disable SpecAugment") logging.info("Disable SpecAugment")
logging.info("About to create train dataset") logging.info("About to create train dataset")
rank = get_rank() rank = get_local_rank()
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(
@ -331,7 +331,7 @@ class AsrDataModule:
CutSet for validation. CutSet for validation.
""" """
logging.info("About to create dev dataset") logging.info("About to create dev dataset")
rank = get_rank() rank = get_local_rank()
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(
WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}")) WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))

View File

@ -75,6 +75,7 @@ from utils import ( # filter_uneven_sized_batch,
AttributeDict, AttributeDict,
MetricsTracker, MetricsTracker,
get_rank, get_rank,
get_local_rank,
get_world_size, get_world_size,
setup_logger, setup_logger,
str2bool, str2bool,
@ -274,7 +275,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 50, "log_interval": 50,
"reset_interval": 200, "reset_interval": 200,
"valid_interval": 3000, "valid_interval": 1000,
# "env_info": get_env_info(), # "env_info": get_env_info(),
} }
) )
@ -844,7 +845,7 @@ def run(rank, world_size, args):
logging.info(f"{name}: {param.shape}") logging.info(f"{name}: {param.shape}")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", rank) device = torch.device("cuda", get_local_rank())
else: else:
device = torch.device("cpu") device = torch.device("cpu")
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
@ -867,10 +868,10 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get # You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select # an utterance duration distribution for your dataset to select
# the threshold # the threshold
if c.duration < 1.0 or c.duration > 25.0: if c.duration < 0.8 or c.duration > 20.0:
logging.warning( # logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
) # )
return False return False
if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom: if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
codec_len = ( codec_len = (

View File

@ -38,6 +38,14 @@ def get_rank():
else: else:
return 0 return 0
def get_local_rank():
    """Return this process's rank within its own node.

    Callers use the result as a CUDA device index (``cuda:{rank}``), so it
    has to be the per-node index rather than the global rank.

    Resolution order:

    1. The ``LOCAL_RANK`` environment variable, when it is set.
    2. If a ``torch.distributed`` process group is initialized, the global
       rank folded onto the number of CUDA devices visible to this process.
    3. ``0`` otherwise (no launcher variable and no process group).

    Returns:
        int: the node-local rank.

    Raises:
        ValueError: if ``LOCAL_RANK`` is set but is not an integer string.
    """
    env_local_rank = os.environ.get("LOCAL_RANK")
    if env_local_rank is not None:
        return int(env_local_rank)

    if dist.is_available() and dist.is_initialized():
        # torch.distributed exposes no get_local_rank(); calling it raised
        # AttributeError. Derive the local index from the global rank.
        import torch

        global_rank = dist.get_rank()
        devices_per_node = torch.cuda.device_count()
        if devices_per_node > 0:
            # NOTE(review): assumes every node exposes the same number of
            # GPUs and that ranks are assigned node by node -- confirm for
            # heterogeneous clusters.
            return global_rank % devices_per_node
        # No CUDA devices visible: nothing to fold onto, keep the global rank.
        return global_rank

    return 0
def str2bool(v): def str2bool(v):
"""Used in argparse.ArgumentParser.add_argument to indicate """Used in argparse.ArgumentParser.add_argument to indicate
that a type is a bool type and user can enter that a type is a bool type and user can enter