Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[issue-458] Log if the user passes a NoneType access token #462

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/databricks/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import re

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
# Use this import purely for type annotations, a la https://mypy.readthedocs.io/en/latest/runtime_troubles.html#import-cycles
Expand Down Expand Up @@ -84,7 +84,32 @@ def TimestampFromTicks(ticks):
return Timestamp(*time.localtime(ticks)[:6])


def connect(server_hostname, http_path, access_token=None, **kwargs) -> "Connection":
def singleton(class_):
instances = {}

def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]

return getinstance


@singleton
class DefaultNone(object):
"""Used to represent a default value of None so that this code can distinguish between
the user passing None versus a default value of None being used.
"""

pass


def connect(
server_hostname,
http_path,
access_token: Optional[Union[str, DefaultNone]] = DefaultNone,
**kwargs
) -> "Connection":
from .client import Connection

return Connection(server_hostname, http_path, access_token, **kwargs)
12 changes: 9 additions & 3 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import decimal
from uuid import UUID

from databricks.sql import __version__
from databricks.sql import __version__, DefaultNone
from databricks.sql import *
from databricks.sql.exc import (
OperationalError,
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
self,
server_hostname: str,
http_path: str,
access_token: Optional[str] = None,
access_token: Optional[Union[str, DefaultNone]] = None,
http_headers: Optional[List[Tuple[str, str]]] = None,
session_configuration: Optional[Dict[str, Any]] = None,
catalog: Optional[str] = None,
Expand Down Expand Up @@ -204,7 +204,13 @@ def read(self) -> Optional[OAuthToken]:
# use_cloud_fetch
# Enable use of cloud fetch to extract large query results in parallel via cloud storage

if access_token:
if access_token is DefaultNone:
access_token = None
elif access_token is None:
logger.info(
"Connection access_token was passed a None value. U2M OAuth will be attempted"
)
else:
access_token_kv = {"access_token": access_token}
kwargs = {**kwargs, **access_token_kv}

Expand Down
Loading