Commit

Multiple databases (#10)
* Manage several databases on server side

* update tests

* Adds a side panel to display a summary of the databases

* Get database schema

* Add a database icon

* Link a database to a cell

* prettier

* lint

* lint

* fix python test

* get database schema for sync engine

* Add server and application tests

* Adds tests on sidepanels

* lint

* Specify database used for tests
brichet authored Oct 23, 2023
1 parent 95511e2 commit b0fc307
Showing 18 changed files with 1,364 additions and 51 deletions.
61 changes: 50 additions & 11 deletions jupyter_sql_cell/app.py
@@ -2,9 +2,9 @@

from jupyter_server.extension.application import ExtensionApp
from jupyter_server.utils import url_path_join
from traitlets import Unicode
from traitlets import Dict, Integer, List, Unicode

from .handlers import ExampleHandler, ExecuteHandler
from .handlers import DatabasesHandler, DatabaseSchemaHandler, ExampleHandler, ExecuteHandler
from .sqlconnector import SQLConnector


@@ -13,9 +13,29 @@ class JupyterSqlCell(ExtensionApp):
name = "JupyterSqlCell"
default_url = "/jupyter-sql-cell"

db_url = Unicode(
"",
help="The database URL"
database = Dict(per_key_traits={
"alias": Unicode(default_value=None, allow_none=True),
"database": Unicode(),
"dbms": Unicode(),
"driver": Unicode(default_value=None, allow_none=True),
"host": Unicode(default_value=None, allow_none=True),
"port": Integer(default_value=None, allow_none=True)
},
default_value={},
help="The databases description"
).tag(config=True)

databases = List(
Dict(per_key_traits={
"alias": Unicode(default_value=None, allow_none=True),
"database": Unicode(),
"dbms": Unicode(),
"driver": Unicode(default_value=None, allow_none=True),
"host": Unicode(default_value=None, allow_none=True),
"port": Integer(default_value=None, allow_none=True)
}),
default_value=[],
help="The databases description",
).tag(config=True)


@@ -24,17 +44,36 @@ def __init__(self) -> None:

def initialize(self):
path = pathlib.Path(__file__)
if not self.db_url:
if self.database:
self.databases.append(self.database)

if not self.databases:
path = pathlib.Path(__file__).parent / "tests" / "data" / "world.sqlite"
self.db_url = f"sqlite+aiosqlite:///{path}"
SQLConnector.db_url = self.db_url
self.databases = [{
"alias": "default",
"database": str(path),
"dbms": "sqlite",
"driver": None,
"host": None,
"port": None
}]
for database in self.databases:
for option in ["alias", "driver", "host", "port"]:
if not option in database.keys():
database[option] = None
SQLConnector.add_database(database)

return super().initialize()

def initialize_handlers(self):
super().initialize_handlers()
example_pattern = url_path_join("/jupyter-sql-cell", "get-example")
execute_pattern = url_path_join("/jupyter-sql-cell", "execute")
example_pattern = url_path_join(self.default_url, "get-example")
databases_pattern = url_path_join(self.default_url, "databases")
execute_pattern = url_path_join(self.default_url, "execute")
schema_pattern = url_path_join(self.default_url, "schema")
self.handlers.extend([
(databases_pattern, DatabasesHandler),
(example_pattern, ExampleHandler),
(execute_pattern, ExecuteHandler)
(execute_pattern, ExecuteHandler),
(schema_pattern, DatabaseSchemaHandler)
])
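The databases registered above are configurable through the new database / databases traits. A minimal sketch of what a jupyter_server_config.py entry could look like, assuming the trait keys shown in this diff (alias, database, dbms, driver, host, port); the paths and connection values are placeholders, not part of the commit:

# jupyter_server_config.py -- hypothetical values for illustration
c.JupyterSqlCell.databases = [
    {
        # SQLite only needs a file path; the optional keys (alias, driver,
        # host, port) are filled with None in initialize() when omitted.
        "alias": "world",
        "dbms": "sqlite",
        "database": "/path/to/world.sqlite",
    },
    {
        # "postgres" is listed in ASYNC_DRIVERS, so SQLConnector.add_database()
        # tries an async engine (asyncpg here) first.
        "alias": "analytics",
        "dbms": "postgres",
        "driver": "asyncpg",
        "host": "localhost",
        "port": 5432,
        "database": "analytics",
    },
]

The single-database trait, database, takes one dict with the same keys and is appended to databases in initialize().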
80 changes: 77 additions & 3 deletions jupyter_sql_cell/handlers.py
@@ -6,25 +6,99 @@
from .sqlconnector import SQLConnector


def reply_error(api: APIHandler, msg: StopIteration):
api.set_status(500)
api.log.error(msg)
reply = {"message": msg}
api.finish(json.dumps(reply))


class DatabasesHandler(APIHandler):
@tornado.web.authenticated
def get(self):
try:
databases = SQLConnector.get_databases()
self.finish(json.dumps(databases))
except Exception as e:
self.log.error(f"Databases error\n{e}")
self.write_error(500, exec_info=e)


class ExecuteHandler(APIHandler):
# The following decorator should be present on all verb methods (head, get, post,
# patch, put, delete, options) to ensure only authorized user can request the
# Jupyter server
@tornado.gen.coroutine
@tornado.web.authenticated
def post(self):
query = json.loads(self.request.body).get("query", None)
body = json.loads(self.request.body)
id = body.get("id", None)
query = body.get("query", None)

if id is None:
reply_error(self, "The database id has not been provided")
return
if not query:
reply_error(self, "No query has been provided")
return
try:
connector = SQLConnector()
connector = SQLConnector(int(id))
if connector.errors:
reply_error(self, connector.errors[0])
return
except Exception as e:
self.log.error(f"Connector error\n{e}")
self.write_error(500, exec_info=e)
return

try:
result = yield connector.execute(query)
self.finish(json.dumps({
"data": result
"alias": connector.database["alias"],
"data": result,
"id": id,
"query": query,
}))
except Exception as e:
self.log.error(f"Query error\n{e}")
self.write_error(500, exec_info=e)


class DatabaseSchemaHandler(APIHandler):
@tornado.gen.coroutine
@tornado.web.authenticated
def get(self):
id = self.get_argument("id", "")
target = self.get_argument("target", "tables")
table = self.get_argument("table", "")

if not id:
reply_error(self, "The database id has not been provided")
return
if target not in ["tables", "columns"]:
reply_error(self, "Target must be \"tables\" or \"columns\"")
return
if target == "columns" and not table:
reply_error(self, "The table has not been provided")
return

try:
connector = SQLConnector(int(id))
if connector.errors:
reply_error(self, connector.errors[0])
return
except Exception as e:
self.log.error(f"Connector error\n{e}")
self.write_error(500, exec_info=e)
return

try:
data = yield connector.get_schema(target, table)
self.finish(json.dumps({
"data": data,
"id": id,
"table": table,
"target": target
}))
except Exception as e:
self.log.error(f"Query error\n{e}")
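With the routes registered in app.py, the extension now answers on /jupyter-sql-cell/databases, /jupyter-sql-cell/schema and /jupyter-sql-cell/execute. A rough client-side sketch using requests; the server URL, token, table and SQL are placeholder assumptions, not taken from the commit:

import requests

BASE = "http://localhost:8888/jupyter-sql-cell"       # placeholder server URL
HEADERS = {"Authorization": "token <jupyter-token>"}  # placeholder token

# List the registered databases (alias, driver, id, is_async, ...).
databases = requests.get(f"{BASE}/databases", headers=HEADERS).json()

# Ask for the table names of database 0, then for the columns of one table.
tables = requests.get(
    f"{BASE}/schema", params={"id": 0, "target": "tables"}, headers=HEADERS
).json()
columns = requests.get(
    f"{BASE}/schema",
    params={"id": 0, "target": "columns", "table": tables["data"][0]},
    headers=HEADERS,
).json()

# Execute a query against the same database; the SQL is only an example.
result = requests.post(
    f"{BASE}/execute",
    json={"id": 0, "query": "SELECT * FROM city LIMIT 5"},
    headers=HEADERS,
).json()
print(result["data"])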
145 changes: 135 additions & 10 deletions jupyter_sql_cell/sqlconnector.py
@@ -1,19 +1,71 @@
from jupyter_server import log
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import CursorResult, text
from typing import Any, Dict, List
from sqlalchemy.exc import InvalidRequestError, NoSuchModuleError
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy import CursorResult, Inspector, URL, create_engine, inspect, text
from typing import Any, Dict, List, Optional, TypedDict

ASYNC_DRIVERS = {
"mariadb": ["asyncmy", "aiomysql"],
"mysql": ["asyncmy", "aiomysql"],
"postgres": ["asyncpg", "psycopg"],
"sqlite": ["aiosqlite"],
}


class DatabaseDesc(TypedDict):
alias: Optional[str]
database: str
dbms: str
driver: Optional[str]
host: Optional[str]
port: Optional[int]


class Database(TypedDict):
alias: str
id: int
is_async: bool
url: URL


class DatabaseSummary(DatabaseDesc):
id: int
is_async: bool


class SQLConnector:

db_url: str = ""
databases: [Database] = []
warnings = []

def __init__(self, database_id: int):
self.engine = None
self.errors = []
self.database: Database = next(filter(lambda db: db["id"] == database_id, self.databases), None)

engine = None
if not self.database:
self.errors.append(f"There is no registered database with id {database_id}")
else:
if self.database["is_async"]:
self.engine = create_async_engine(self.database["url"])
else:
self.engine = create_engine(self.database["url"])

def __init__(self) -> None:
if not self.db_url:
log.warn("The database URL is not set")
self.engine = create_async_engine(self.db_url)
async def get_schema(self, target: str, table: str = "") -> [str]:
if self.database["is_async"]:
async with self.engine.connect() as conn:
schema = await conn.run_sync(self.use_inspector, target, table)
else:
with self.engine.connect() as conn:
schema = self.use_inspector(conn, target, table)
return schema

def use_inspector(self, conn: AsyncConnection, target: str, table: str) -> [str]:
inspector: Inspector = inspect(conn)
if target == "tables":
return inspector.get_table_names()
elif target == "columns":
columns = inspector.get_columns(table)
return sorted([column['name'] for column in columns])

async def execute(self, query: str) -> str:
if not self.engine:
@@ -27,6 +79,79 @@ async def execute_request(self, query: str) -> CursorResult[Any]:
cursor: CursorResult[Any] = await connection.execute(text(query))
return cursor

@classmethod
def add_database(cls, db_desc: DatabaseDesc):
id = 0
if cls.databases:
id = max([db["id"] for db in cls.databases]) + 1

if db_desc["alias"]:
alias = db_desc["alias"]
else:
alias = f"{db_desc['dbms']}_{id}"

if db_desc["driver"]:
drivers = [db_desc["driver"]]
else:
drivers = ASYNC_DRIVERS.get(db_desc["dbms"], [])

for driver in drivers:
url = URL.create(
drivername=f"{db_desc['dbms']}+{driver}",
host=db_desc["host"],
port=db_desc["port"],
database=db_desc["database"]
)
try:
create_async_engine(url)
cls.databases.append({
"alias": alias,
"id": id,
"url": url,
"is_async": True
})
return
except (InvalidRequestError, NoSuchModuleError):
# InvalidRequestError is raised if the driver is not async.
# NoSuchModuleError is raised if the driver is not installed.
continue

driver = f"+{db_desc['driver']}" if db_desc["driver"] else ""
url = URL.create(
drivername=f"{db_desc['dbms']}{driver}",
host=db_desc["host"],
port=db_desc["port"],
database=db_desc["database"]
)
create_engine(url)
cls.databases.append({
"alias": alias,
"id": id,
"url": url,
"is_async": False
})
cls.warnings.append("No async driver found, the query will be executed synchronously")
print(cls.warnings[-1])

@classmethod
def get_databases(cls):
summary_databases: [DatabaseSummary] = []
for database in cls.databases:
url: URL = database["url"]
summary: DatabaseSummary = {
"alias": database["alias"],
"database": url.database,
"driver": url.drivername,
"id": database["id"],
"is_async": database["is_async"]
}
if url.host:
summary["host"] = url.host
if url.port:
summary["port"] = url.port
summary_databases.append(summary)
return summary_databases

@staticmethod
def to_list(cursor: CursorResult[Any]) -> List[Dict]:
return [row._asdict() for row in cursor.fetchall()]
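Outside the server handlers, the reworked SQLConnector registry can also be exercised directly. A short usage sketch, assuming a SQLite file served through the async aiosqlite driver; the path, table and query are only illustrative:

import asyncio
from jupyter_sql_cell.sqlconnector import SQLConnector

# Register a database; an id is assigned and, with no alias given,
# "sqlite_0" is derived from the dbms and the id.
SQLConnector.add_database({
    "alias": None,
    "database": "/path/to/world.sqlite",  # placeholder path
    "dbms": "sqlite",
    "driver": None,
    "host": None,
    "port": None,
})
print(SQLConnector.get_databases())

async def main():
    connector = SQLConnector(0)
    if connector.errors:
        raise RuntimeError(connector.errors[0])
    tables = await connector.get_schema("tables")
    columns = await connector.get_schema("columns", tables[0])
    # execute() is truncated in this diff; judging from ExecuteHandler it is
    # awaited and its result is JSON-serialised, assumed here to be the rows.
    rows = await connector.execute(f"SELECT * FROM {tables[0]} LIMIT 3")
    print(tables, columns, rows)

asyncio.run(main())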
Binary file added jupyter_sql_cell/tests/data/chinook.db
