diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py index 72f26a803..da337791a 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py @@ -48,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.utils import fix_random_seed from speech_dataset import K2SpeechRecognitionDataset from torch.utils.data import DataLoader -from utils import get_rank, str2bool +from utils import get_local_rank, str2bool class _SeedWorkers: @@ -271,7 +271,7 @@ class AsrDataModule: logging.info("Disable SpecAugment") logging.info("About to create train dataset") - rank = get_rank() + rank = get_local_rank() train = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures( @@ -331,7 +331,7 @@ class AsrDataModule: CutSet for validation. """ logging.info("About to create dev dataset") - rank = get_rank() + rank = get_local_rank() validate = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures( WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}")) diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py index c523c92a5..a11ae4b76 100755 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py @@ -75,6 +75,7 @@ from utils import ( # filter_uneven_sized_batch, AttributeDict, MetricsTracker, get_rank, + get_local_rank, get_world_size, setup_logger, str2bool, @@ -274,7 +275,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 3000, + "valid_interval": 1000, # "env_info": get_env_info(), } ) @@ -844,7 +845,7 @@ def run(rank, world_size, args): logging.info(f"{name}: {param.shape}") if torch.cuda.is_available(): - device = torch.device("cuda", rank) + device = torch.device("cuda", get_local_rank()) else: device = 
torch.device("cpu") logging.info(f"Device: {device}") @@ -867,10 +868,10 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 1.0 or c.duration > 25.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) + if c.duration < 0.8 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) return False if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom: codec_len = ( diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py index fe65a8042..7c6f6c0a6 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py @@ -38,6 +38,14 @@ def get_rank(): else: return 0 +def get_local_rank(): + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + elif dist.is_available() and dist.is_initialized(): + return dist.get_rank() % max(1, torch.cuda.device_count()) + else: + return 0 + def str2bool(v): """Used in argparse.ArgumentParser.add_argument to indicate that a type is a bool type and user can enter