Skip to content

Commit

Permalink
Updated code
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Sep 25, 2024
1 parent bb5b117 commit f60c185
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions ray_provider/decorators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,54 +31,23 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):

template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")

def __init__(self, config: dict[str, Any] | Callable[[Context], dict[str, Any]], **kwargs: Any) -> None:
    """Create the decorated operator, deferring config resolution.

    :param config: Either a ready-made configuration dict or a callable
        that produces one; a callable is invoked later (at execute time),
        where the runtime ``Context`` is available.
    :param kwargs: Passed through to the parent operator and also kept so
        they can be forwarded to a callable ``config``.
    """
    # Initialize the parent with placeholder connection/entrypoint values;
    # the real values are extracted from the resolved config in execute().
    super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)
    # Keep the raw config and kwargs for execute-time resolution.
    self.config = config
    self.kwargs = kwargs

def get_config(self, context: Context, config: Callable[..., dict[str, Any]], **kwargs: Any) -> dict[str, Any]:
    """Call *config* with only the keyword arguments its signature accepts.

    Any caller-supplied ``context`` kwarg is discarded; the Airflow
    ``context`` argument is injected instead, and only when the callable
    declares a ``context`` parameter.

    :param context: The Airflow task context to inject (if accepted).
    :param config: A callable returning the Ray job configuration dict.
    :param kwargs: Candidate keyword arguments; filtered against the
        callable's signature.
    :return: Whatever dict the *config* callable returns.
    """
    accepted = inspect.signature(config).parameters

    call_args: dict[str, Any] = {}
    for name, value in kwargs.items():
        # Drop anything the callable does not accept, and never forward a
        # caller-supplied "context" — the real one is injected below.
        if name == "context" or name not in accepted:
            continue
        call_args[name] = value

    if "context" in accepted:
        call_args["context"] = context

    return config(**call_args)

def execute(self, context: Context) -> Any:
"""
Expand All @@ -90,6 +59,40 @@ def execute(self, context: Context) -> Any:
"""
temp_dir = None
try:
# Generate the configuration
if callable(self.config):
config = self.get_config(context=context, config=self.config, **self.kwargs)
else:
config = self.config

# Prepare Ray job parameters
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory", None)
self.ray_resources: dict[str, Any] | None = config.get("resources", None)
self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
self.update_if_exists: bool = config.get("update_if_exists", False)
self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
self.gpu_device_plugin_yaml: str = config.get(
"gpu_device_plugin_yaml",
"https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)

if not isinstance(self.num_cpus, (int, float)):
raise TypeError("num_cpus should be an integer or float value")
if not isinstance(self.num_gpus, (int, float)):
raise TypeError("num_gpus should be an integer or float value")

if self.is_decorated_function:
self.log.info(
f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}"
Expand Down Expand Up @@ -159,14 +162,11 @@ def task(
"""
if config is None:
config = {}
elif callable(config):
config = config(**kwargs)
elif not isinstance(config, dict):
raise TypeError("config must be either a callable, a dictionary, or None")

return task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=_RayDecoratedOperator,
config=config,
**kwargs,
)

0 comments on commit f60c185

Please sign in to comment.