Worker gets elastic run config from master (#1277)
* Worker gets the elastic run config from the master

* Integrate the training node configuration

* Fix review comments

---------

Co-authored-by: bsang <[email protected]>
samplise and bsang authored Oct 8, 2024
1 parent 6764a09 commit f03b769
Showing 9 changed files with 140 additions and 7 deletions.
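
In short, the worker-side MasterClient now asks the master for the elastic run configuration, and the master answers with a flat key/value map of the features to enable. A minimal worker-side sketch of that call, assuming a reachable DLRover master (the address and the second positional argument below are placeholders, mirroring the setup used by the new tests in this diff):

    from dlrover.python.elastic_agent.master_client import (
        MasterClient,
        build_master_client,
    )

    # Build the client the same way the new tests do; the address and the
    # second argument are illustrative placeholders only.
    MasterClient._instance = build_master_client("127.0.0.1:12345", 1)
    client = MasterClient.singleton_instance()

    # Ask the master which elastic-run features to turn on.
    # Returns a Dict[str, str], e.g. {"network_check": "True"}.
    configs = client.get_elastic_run_config()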
10 changes: 10 additions & 0 deletions dlrover/python/common/grpc.py
@@ -507,3 +507,13 @@ class DiagnosisChipMetrics(Message):
class SyncTrainingPort(Message):
port: int = 0
newport: int = 0


@dataclass
class ElasticRunConfigRequest(Message):
pass


@dataclass
class ElasticRunConfig(Message):
configs: Dict[str, str] = field(default_factory=dict)
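
The two new dataclasses are plain request/response messages: ElasticRunConfigRequest carries no fields and only signals "send me the run config", while ElasticRunConfig carries the flags as a string-to-string map. A small sketch of how they pair up (serialize() is the Message helper the master servicer already uses later in this diff):

    from dlrover.python.common import grpc

    # Worker -> master: an empty request.
    request = grpc.ElasticRunConfigRequest()

    # Master -> worker: the reply listing the flags to enable.
    response = grpc.ElasticRunConfig(configs={"network_check": "True"})
    payload = response.serialize()  # what the servicer places into response.data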
6 changes: 6 additions & 0 deletions dlrover/python/elastic_agent/master_client.py
@@ -16,6 +16,7 @@
import threading
import time
from contextlib import closing
from typing import Dict

from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc
from dlrover.python.common import env_utils, grpc
@@ -422,6 +423,11 @@ def sync_training_ports(self, port) -> grpc.SyncTrainingPort:
response: grpc.SyncTrainingPort = self._get(request)
return response

def get_elastic_run_config(self) -> Dict[str, str]:
request = grpc.ElasticRunConfigRequest()
response: grpc.ElasticRunConfig = self._get(request)
return response.configs

@classmethod
def singleton_instance(cls, *args, **kwargs):
if not cls._instance:
10 changes: 8 additions & 2 deletions dlrover/python/master/node/dist_job_manager.py
@@ -100,8 +100,14 @@ def __init__(
node_watcher: Optional[NodeWatcher] = None,
job_scaler=None,
error_monitor=None,
external_config=None,
):
super().__init__(job_args, speed_monitor, error_monitor)
super().__init__(
job_args=job_args,
speed_monitor=speed_monitor,
error_monitor=error_monitor,
external_config=external_config,
)
self._remove_exited_node = job_args.remove_exited_node
node_restart_count: Dict[str, int] = {}
for type, node_args in job_args.node_args.items():
@@ -199,7 +205,7 @@ def start(self):
if NodeType.CHIEF in plan.node_group_resources:
worker_num += plan.node_group_resources[NodeType.CHIEF].count
self._speed_monitor.set_target_worker_num(worker_num)
self._training_node_configure.set_node_num(worker_num)
self._training_node_config.set_node_num(worker_num)
threading.Thread(
target=self._monitor_nodes, name="node_monitor", daemon=True
).start()
10 changes: 7 additions & 3 deletions dlrover/python/master/node/job_manager.py
@@ -23,7 +23,7 @@
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
from dlrover.python.master.node.training_node import (
SyncNodeTrainingPorts,
TrainingNodeConfigure,
TrainingNodeConfig,
)
from dlrover.python.master.resource.job import JobResource
from dlrover.python.scheduler.job import JobArgs
@@ -40,6 +40,7 @@ def __init__(
job_args: JobArgs,
speed_monitor=None,
error_monitor=None,
external_config=None,
):
self._job_resource = JobResource()
self._job_args = job_args
@@ -55,7 +56,7 @@ def __init__(
self._job_nodes: Dict[str, Dict[int, Node]] = {}
self._nodes_required = (0, 0, 0)

self._training_node_configure = TrainingNodeConfigure()
self._training_node_config = TrainingNodeConfig(external_config)

@abstractmethod
def start(self):
@@ -199,7 +200,7 @@ def collect_node_heart_beat(self, node_type, node_id, timestamp):
pass

def sync_node_training_port(self, node_id, port) -> SyncNodeTrainingPorts:
return self._training_node_configure.sync_node_training_port(
return self._training_node_config.sync_node_training_port(
node_id, port
)

@@ -227,3 +228,6 @@ def update_node_required_info_callback(self):
"""Callback when 'update_node_required_info' is invoked."""

pass

def get_elastic_run_configs(self) -> Dict[str, str]:
return self._training_node_config.get_elastic_run_configs()
2 changes: 1 addition & 1 deletion dlrover/python/master/node/local_job_manager.py
@@ -38,7 +38,7 @@ def __init__(
def start(self):
self._job_nodes[NodeType.WORKER] = {}
worker = self._job_args.node_args[NodeType.WORKER].group_resource
self._training_node_configure.set_node_num(worker.count)
self._training_node_config.set_node_num(worker.count)
for i in range(worker.count):
self._job_nodes[NodeType.WORKER][i] = Node(
name=NodeType.WORKER + f"-{i}",
18 changes: 17 additions & 1 deletion dlrover/python/master/node/training_node.py
@@ -16,6 +16,7 @@
import math
import threading
import time
from abc import ABCMeta, abstractmethod
from collections import Counter
from dataclasses import dataclass
from threading import Lock
@@ -400,13 +401,23 @@ class SyncNodeTrainingPorts:
next_check_port: int = 0


class TrainingNodeConfigure:
class ExternalConfig(metaclass=ABCMeta):
def __init__(self):
pass

@abstractmethod
def get_elastic_run_configs(self) -> Dict[str, str]:
pass


class TrainingNodeConfig:
def __init__(self, external_config: ExternalConfig = None):
self._lock = Lock()
self._recv_node_training_ports: Dict[int, int] = {}
self._node_training_port = 0
self._next_check_node_training_port = 0
self._n_node = 0
self._external_config = external_config

def set_node_num(self, num):
logger.info(f"set worker count: {num}")
@@ -457,3 +468,8 @@ def sync_node_training_port(self, node_id, port) -> SyncNodeTrainingPorts:
return SyncNodeTrainingPorts(
training_port=0, next_check_port=0
)

def get_elastic_run_configs(self) -> Dict[str, str]:
if not self._external_config:
return {}
return self._external_config.get_elastic_run_configs()
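
TrainingNodeConfig only has something to hand back if an ExternalConfig was supplied, so a platform integration plugs in by subclassing it; the same object is what a job manager receives through its new external_config argument. A minimal sketch, where MyPlatformConfig and its hard-coded flags are purely illustrative:

    from typing import Dict

    from dlrover.python.master.node.training_node import (
        ExternalConfig,
        TrainingNodeConfig,
    )


    class MyPlatformConfig(ExternalConfig):
        """Illustrative source of elastic-run flags for the master."""

        def get_elastic_run_configs(self) -> Dict[str, str]:
            return {"network_check": "True", "auto_config": "True"}


    node_config = TrainingNodeConfig(external_config=MyPlatformConfig())
    print(node_config.get_elastic_run_configs())
    # {'network_check': 'True', 'auto_config': 'True'}; without an external
    # config the same call simply returns {}.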
3 changes: 3 additions & 0 deletions dlrover/python/master/servicer.py
@@ -137,6 +137,9 @@ def get(self, request, _):
message = self._need_to_restart_training(node_type, node_id)
elif isinstance(req_message, grpc.SyncTrainingPort):
message = self._sync_training_ports(node_id, req_message)
elif isinstance(req_message, grpc.ElasticRunConfigRequest):
configs = self._job_manager.get_elastic_run_configs()
message = grpc.ElasticRunConfig(configs=configs)

if message:
response.data = message.serialize()
35 changes: 35 additions & 0 deletions dlrover/trainer/tests/torch/elastic_run_test.py
@@ -16,6 +16,11 @@
import unittest
from unittest.mock import patch

from dlrover.python.elastic_agent.master_client import (
MasterClient,
build_master_client,
)
from dlrover.python.tests.test_utils import start_local_master
from dlrover.trainer.torch.elastic_run import (
_check_dlrover_master_available,
_check_to_use_dlrover_run,
@@ -56,6 +61,8 @@ def test_check_to_use_dlrover_run(self):
_check_to_use_dlrover_run("127.0.0.1:12345", 2, 3)

def test_elastic_config_from_args(self):
self._master, addr = start_local_master()
MasterClient._instance = build_master_client(addr, 1)
args = [
"--network_check",
"--comm_perf_test",
@@ -80,3 +87,31 @@ def test_elastic_config_from_args(self):
self.assertEqual(config.training_port, 1000)
self.assertEqual(cmd, "/usr/local/bin/python")
self.assertListEqual(cmd_args, ["-u", "test.py", "--batch_size", "16"])

@patch(f"{MC_PATH}.get_elastic_run_config")
def test_elastic_config_from_master(self, mock_func):
self._master, addr = start_local_master()
MasterClient._instance = build_master_client(addr, 1)
mock_func.return_value = {
"network_check": "True",
"comm_perf_test": "True",
"auto_tunning": "True",
"auto_config": "True",
"exclude_straggler": "True",
"save_at_breakpoint": "True",
}
args = [
"--training_port",
"1000",
"test.py",
"--batch_size",
"16",
]
args = parse_args(args)
config, cmd, cmd_args = _elastic_config_from_args(args)
self.assertTrue(config.network_check)
self.assertTrue(config.comm_perf_test)
self.assertTrue(config.auto_tunning)
self.assertTrue(config.auto_config)
self.assertTrue(config.exclude_straggler)
self.assertTrue(config.save_at_breakpoint)
53 changes: 53 additions & 0 deletions dlrover/trainer/torch/elastic_run.py
@@ -304,26 +304,46 @@ def _elastic_config_from_args(
args,
) -> Tuple[ElasticLaunchConfig, Union[Callable, str], List[str]]:
config, cmd, cmd_args = config_from_args(args)

master_config = _elastic_config_from_master(config)
elastic_config = ElasticLaunchConfig(**config.__dict__)

# PyTorch >= 2.3.0 remove log_dir in the LaunchConfig.
if not version_less_than_230():
elastic_config.log_dir = config.logs_specs.root_log_dir

elastic_config.network_check = getattr(args, "network_check", False)
if master_config.network_check:
elastic_config.network_check = True

elastic_config.comm_perf_test = getattr(args, "comm_perf_test", False)
if master_config.comm_perf_test:
elastic_config.comm_perf_test = True

elastic_config.auto_tunning = getattr(args, "auto_tunning", False)
if master_config.auto_tunning:
elastic_config.auto_tunning = True

elastic_config.auto_config = getattr(args, "auto_config", False)
if master_config.auto_config:
elastic_config.auto_config = True

elastic_config.accelerator = getattr(
args, "accelerator", Accelerators.NVIDIA_GPU
)

elastic_config.exclude_straggler = getattr(
args, "exclude_straggler", False
)
if master_config.exclude_straggler:
elastic_config.exclude_straggler = True
elastic_config.set_node_unit(getattr(args, "node_unit", 1))
elastic_config.training_port = getattr(args, "training_port", 60000)
elastic_config.save_at_breakpoint = getattr(
args, "save_at_breakpoint", False
)
if master_config.save_at_breakpoint:
elastic_config.save_at_breakpoint = True
elastic_config.auto_configure_params()
elastic_config.rdzv_backend = "dlrover-master"
elastic_config.rdzv_endpoint = ""
@@ -332,6 +352,39 @@ def _elastic_config_from_args(
return elastic_config, cmd, cmd_args


def _elastic_config_from_master(config) -> ElasticLaunchConfig:
elastic_config = ElasticLaunchConfig(**config.__dict__)

_client = MasterClient.singleton_instance()
master_configs = _client.get_elastic_run_config()

elastic_config.network_check = False
if "network_check" in master_configs:
elastic_config.network_check = True

elastic_config.comm_perf_test = False
if "comm_perf_test" in master_configs:
elastic_config.comm_perf_test = True

elastic_config.auto_tunning = False
if "auto_tunning" in master_configs:
elastic_config.auto_tunning = True

elastic_config.auto_config = False
if "auto_config" in master_configs:
elastic_config.auto_config = True

elastic_config.exclude_straggler = False
if "exclude_straggler" in master_configs:
elastic_config.exclude_straggler = True

elastic_config.save_at_breakpoint = False
if "save_at_breakpoint" in master_configs:
elastic_config.save_at_breakpoint = True

return elastic_config


def _check_to_use_dlrover_run(master_addr, max_nodes, timeout=120):
if _check_dlrover_master_available(master_addr, timeout):
return True
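
On the launcher side, _elastic_config_from_master marks a feature as enabled as soon as its key appears in the master's reply (the string value is not parsed), and _elastic_config_from_args then ORs that with the corresponding CLI switch, so the master can switch a feature on but not off. A tiny illustration of that precedence, with made-up inputs:

    # Illustrative inputs: no CLI flag, but the master returned the key.
    cli_network_check = False
    master_configs = {"network_check": "True"}

    # Mirrors the logic above: CLI flag OR key-presence in the master reply.
    network_check = cli_network_check or ("network_check" in master_configs)
    print(network_check)  # True; with this parsing, {"network_check": "False"}
                          # would also enable it, since only presence is checked.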
