Skip to content

Commit

Permalink
support M2M auth
Browse files Browse the repository at this point in the history
  • Loading branch information
rbiseck3 committed Dec 20, 2024
1 parent 49d2717 commit cd07dba
Showing 1 changed file with 46 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

if TYPE_CHECKING:
from databricks.sdk.core import oauth_service_principal
from databricks.sql import Connection as DeltaTableConnection
from databricks.sql.client import Cursor as DeltaTableCursor

Expand All @@ -23,23 +24,63 @@

class DatabrickDeltaTablesAccessConfig(SQLAccessConfig):
token: Optional[str] = Field(default=None, description="Databricks Personal Access Token")
client_id: Optional[str] = Field(default=None, description="Client ID of the OAuth app.")
client_secret: Optional[str] = Field(
default=None, description="Client Secret of the OAuth app."
)


class DatabrickDeltaTablesConnectionConfig(SQLConnectionConfig):
access_config: Secret[DatabrickDeltaTablesAccessConfig]
server_hostname: str = Field(description="server hostname connection config value")
http_path: str = Field(description="http path connection config value")

@requires_dependencies(["databricks"], extras="databricks-delta-tables")
def get_credentials_provider(self) -> "oauth_service_principal":
from databricks.sdk.core import Config, oauth_service_principal

host = f"https://{self.server_hostname}"
access_configs = self.access_config.get_secret_value()
if (client_id := access_configs.client_id) and (
client_secret := access_configs.client_secret
):
return oauth_service_principal(
Config(
host=host,
client_id=client_id,
client_secret=client_secret,
)
)
return False

def model_post_init(self, __context: Any) -> None:
access_config = self.access_config.get_secret_value()
if access_config.token and access_config.client_secret and access_config.client_id:
raise ValueError(
"One one for of auth can be provided, either token or client id and secret"
)
if not access_config.token and not (
access_config.client_secret and access_config.client_id
):
raise ValueError(
"One form of auth must be provided, either token or client id and secret"
)

@contextmanager
@requires_dependencies(["databricks"], extras="databricks-delta-tables")
def get_connection(self) -> Generator["DeltaTableConnection", None, None]:
from databricks.sql import connect

with connect(
server_hostname=self.server_hostname,
http_path=self.http_path,
access_token=self.access_config.get_secret_value().token,
) as connection:
connect_kwargs = {
"server_hostname": self.server_hostname,
"http_path": self.http_path,
}

if credential_provider := self.get_credentials_provider():
connect_kwargs["credentials_provider"] = credential_provider
else:
connect_kwargs["access_token"] = self.access_config.get_secret_value().token
with connect(**connect_kwargs) as connection:
yield connection

@contextmanager
Expand Down

0 comments on commit cd07dba

Please sign in to comment.