Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
support local_rank for multi-node
commit e52581e69b (parent 0e8c1db4d0)
@@ -48,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
 from lhotse.utils import fix_random_seed
 from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-from utils import get_rank, str2bool
+from utils import get_local_rank, str2bool


 class _SeedWorkers:
@@ -271,7 +271,7 @@ class AsrDataModule:
             logging.info("Disable SpecAugment")

         logging.info("About to create train dataset")
-        rank = get_rank()
+        rank = get_local_rank()

         train = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(
@@ -331,7 +331,7 @@ class AsrDataModule:
             CutSet for validation.
         """
         logging.info("About to create dev dataset")
-        rank = get_rank()
+        rank = get_local_rank()
         validate = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(
                 WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
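For context, the three hunks above switch the data module from the global rank to the node-local rank when choosing a CUDA device for on-the-fly WhisperFbank extraction. The sketch below is illustrative only: it assumes the job is launched with torchrun (which exports LOCAL_RANK for every worker process), and pick_feature_device is a hypothetical helper, not part of this repository.

import os

import torch


def pick_feature_device() -> torch.device:
    # On a 2-node x 8-GPU job, global ranks run 0..15, but each node only
    # exposes cuda:0..cuda:7, so indexing devices by the global rank would
    # fail on every node except the first.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if torch.cuda.is_available():
        return torch.device("cuda", local_rank)
    return torch.device("cpu")

On a single node the global and local ranks coincide, which is why indexing by the global rank only breaks once training spans more than one node.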
@@ -75,6 +75,7 @@ from utils import ( # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
     get_rank,
+    get_local_rank,
     get_world_size,
     setup_logger,
     str2bool,
@@ -274,7 +275,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,
+            "valid_interval": 1000,
             # "env_info": get_env_info(),
         }
     )
@@ -844,7 +845,7 @@ def run(rank, world_size, args):
             logging.info(f"{name}: {param.shape}")

     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", get_local_rank())
     else:
         device = torch.device("cpu")
     logging.info(f"Device: {device}")
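The same reasoning applies to model placement: the process group is addressed by the global rank, while the CUDA device (and DDP's device_ids) must use the node-local GPU index. The following is a sketch of the usual multi-node DDP pattern under that assumption, not code taken from this commit's training script; setup_ddp_model is a hypothetical name.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_ddp_model(model: torch.nn.Module) -> torch.nn.Module:
    # torchrun exports LOCAL_RANK per process; it indexes GPUs within one node.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    model = model.to(device)
    if dist.is_available() and dist.is_initialized():
        # device_ids must be the local GPU index, not the global rank.
        model = DDP(model, device_ids=[local_rank])
    return model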
@@ -867,10 +868,10 @@ def run(rank, world_size, args):
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 25.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+        if c.duration < 0.8 or c.duration > 20.0:
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False
         if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
             codec_len = (
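For reference, a predicate like the one above is normally applied to the lhotse CutSet before the training dataloader is built. The snippet below is an illustrative sketch using the new 0.8 s / 20.0 s bounds; it is not the exact helper from this commit, which additionally checks speech-token lengths.

from lhotse import CutSet


def keep_cut(c) -> bool:
    # Drop utterances that are too short or too long for stable training.
    return 0.8 <= c.duration <= 20.0


def filter_train_cuts(cuts: CutSet) -> CutSet:
    # CutSet.filter is lazy, so the predicate runs while the data is streamed.
    return cuts.filter(keep_cut)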
@@ -38,6 +38,14 @@ def get_rank():
     else:
         return 0

+
+def get_local_rank():
+    if "LOCAL_RANK" in os.environ:
+        return int(os.environ["LOCAL_RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_local_rank()
+    else:
+        return 0

 def str2bool(v):
     """Used in argparse.ArgumentParser.add_argument to indicate
     that a type is a bool type and user can enter
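A note on how the new helper behaves under different launchers: torchrun exports LOCAL_RANK for every worker process, so the environment-variable branch is the one taken in multi-node jobs, while a plain single-process run falls back to 0. A tiny self-contained check of that env-var branch, under those assumptions (local_rank_from_env is a hypothetical stand-in, not a function in this repo):

import os


def local_rank_from_env() -> int:
    # Same logic as the env-var branch of get_local_rank() above.
    return int(os.environ.get("LOCAL_RANK", "0"))


if __name__ == "__main__":
    os.environ.pop("LOCAL_RANK", None)
    assert local_rank_from_env() == 0   # plain `python train.py`, no launcher
    os.environ["LOCAL_RANK"] = "3"
    assert local_rank_from_env() == 3   # 4th process on its node under torchrun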