diff --git a/dask_databricks/databrickscluster.py b/dask_databricks/databrickscluster.py
index 3b064b8..f02a16c 100644
--- a/dask_databricks/databrickscluster.py
+++ b/dask_databricks/databrickscluster.py
@@ -23,12 +23,18 @@ def __init__(
         loop: Optional[IOLoop] = None,
         asynchronous: bool = False,
     ):
-        self.spark_local_ip = os.getenv("SPARK_LOCAL_IP")
+        self.spark_local_ip = os.environ.get("SPARK_LOCAL_IP")
         if self.spark_local_ip is None:
             raise KeyError(
                 "Unable to find expected environment variable SPARK_LOCAL_IP. "
                 "Are you running this on a Databricks driver node?"
             )
+        if os.environ.get("MASTER") and "local[" in os.environ.get("MASTER"):
+            raise EnvironmentError(
+                "You appear to be trying to run a multi-node Dask cluster on a "
+                "single-node databricks cluster. Maybe you want "
+                "`dask.distributed.LocalCluster().get_client()` instead"
+            )
         try:
             name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
         except AttributeError:
diff --git a/dask_databricks/tests/test_databricks.py b/dask_databricks/tests/test_databricks.py
index 62e56c8..c0e09ae 100644
--- a/dask_databricks/tests/test_databricks.py
+++ b/dask_databricks/tests/test_databricks.py
@@ -38,6 +38,14 @@ def test_databricks_cluster_raises_key_error_when_initialised_outside_of_databri
     with pytest.raises(KeyError):
         DatabricksCluster()
 
+def test_databricks_cluster_raises_environment_error_when_master_variable_implies_single_node(
+    monkeypatch,
+    set_spark_local_ip,
+    dask_cluster,
+):
+    monkeypatch.setenv("MASTER", "local[8]")
+    with pytest.raises(EnvironmentError):
+        DatabricksCluster()
 
 def test_databricks_cluster_create(set_spark_local_ip, dask_cluster):
     cluster = DatabricksCluster()