diff --git a/ray_provider/decorators/ray.py b/ray_provider/decorators/ray.py
index 38d84e0..7acdff5 100644
--- a/ray_provider/decorators/ray.py
+++ b/ray_provider/decorators/ray.py
@@ -31,54 +31,23 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):
 
     template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")
 
-    def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
-        self.conn_id: str = config.get("conn_id", "")
-        self.is_decorated_function = False if "entrypoint" in config else True
-        self.entrypoint: str = config.get("entrypoint", "python script.py")
-        self.runtime_env: dict[str, Any] = config.get("runtime_env", {})
-
-        self.num_cpus: int | float = config.get("num_cpus", 1)
-        self.num_gpus: int | float = config.get("num_gpus", 0)
-        self.memory: int | float = config.get("memory", None)
-        self.ray_resources: dict[str, Any] | None = config.get("resources", None)
-        self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
-        self.update_if_exists: bool = config.get("update_if_exists", False)
-        self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
-        self.gpu_device_plugin_yaml: str = config.get(
-            "gpu_device_plugin_yaml",
-            "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
-        )
-        self.fetch_logs: bool = config.get("fetch_logs", True)
-        self.wait_for_completion: bool = config.get("wait_for_completion", True)
-        self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
-        self.poll_interval: int = config.get("poll_interval", 60)
-        self.xcom_task_key: str | None = config.get("xcom_task_key", None)
+    def __init__(self, config: dict[str, Any] | Callable[[Context], dict[str, Any]], **kwargs: Any) -> None:
+        self.config = config
+        self.kwargs = kwargs
 
-        if not isinstance(self.num_cpus, (int, float)):
-            raise TypeError("num_cpus should be an integer or float value")
-        if not isinstance(self.num_gpus, (int, float)):
-            raise TypeError("num_gpus should be an integer or float value")
-
-        super().__init__(
-            conn_id=self.conn_id,
-            entrypoint=self.entrypoint,
-            runtime_env=self.runtime_env,
-            num_cpus=self.num_cpus,
-            num_gpus=self.num_gpus,
-            memory=self.memory,
-            resources=self.ray_resources,
-            ray_cluster_yaml=self.ray_cluster_yaml,
-            update_if_exists=self.update_if_exists,
-            kuberay_version=self.kuberay_version,
-            gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
-            fetch_logs=self.fetch_logs,
-            wait_for_completion=self.wait_for_completion,
-            job_timeout_seconds=self.job_timeout_seconds,
-            poll_interval=self.poll_interval,
-            xcom_task_key=self.xcom_task_key,
-            **kwargs,
-        )
+        super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)
+
+    def get_config(self, context: Context, config: Callable[..., dict[str, Any]], **kwargs: Any) -> dict[str, Any]:
+        config_params = inspect.signature(config).parameters
+
+        config_kwargs = {k: v for k, v in kwargs.items() if k in config_params and k != "context"}
+
+        if "context" in config_params:
+            config_kwargs["context"] = context
+
+        # Call config with the prepared arguments
+        return config(**config_kwargs)
 
     def execute(self, context: Context) -> Any:
         """
@@ -90,6 +59,40 @@ def execute(self, context: Context) -> Any:
         """
         temp_dir = None
         try:
+            # Generate the configuration
+            if callable(self.config):
+                config = self.get_config(context=context, config=self.config, **self.kwargs)
+            else:
+                config = self.config
+
+            # Prepare Ray job parameters
+            self.conn_id: str = config.get("conn_id", "")
+            self.is_decorated_function = False if "entrypoint" in config else True
+            self.entrypoint: str = config.get("entrypoint", "python script.py")
+            self.runtime_env: dict[str, Any] = config.get("runtime_env", {})
+
+            self.num_cpus: int | float = config.get("num_cpus", 1)
+            self.num_gpus: int | float = config.get("num_gpus", 0)
+            self.memory: int | float = config.get("memory", None)
+            self.ray_resources: dict[str, Any] | None = config.get("resources", None)
+            self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
+            self.update_if_exists: bool = config.get("update_if_exists", False)
+            self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
+            self.gpu_device_plugin_yaml: str = config.get(
+                "gpu_device_plugin_yaml",
+                "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
+            )
+            self.fetch_logs: bool = config.get("fetch_logs", True)
+            self.wait_for_completion: bool = config.get("wait_for_completion", True)
+            self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
+            self.poll_interval: int = config.get("poll_interval", 60)
+            self.xcom_task_key: str | None = config.get("xcom_task_key", None)
+
+            if not isinstance(self.num_cpus, (int, float)):
+                raise TypeError("num_cpus should be an integer or float value")
+            if not isinstance(self.num_gpus, (int, float)):
+                raise TypeError("num_gpus should be an integer or float value")
+
             if self.is_decorated_function:
                 self.log.info(
                     f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}"
                 )
@@ -159,14 +162,11 @@ def task(
     """
     if config is None:
         config = {}
-    elif callable(config):
-        config = config(**kwargs)
-    elif not isinstance(config, dict):
-        raise TypeError("config must be either a callable, a dictionary, or None")
 
     return task_decorator_factory(
        python_callable=python_callable,
        multiple_outputs=multiple_outputs,
        decorated_operator_class=_RayDecoratedOperator,
        config=config,
+       **kwargs,
    )