diff --git a/jupyter_sql_cell/app.py b/jupyter_sql_cell/app.py index 4f626ba..221cb16 100644 --- a/jupyter_sql_cell/app.py +++ b/jupyter_sql_cell/app.py @@ -2,9 +2,9 @@ from jupyter_server.extension.application import ExtensionApp from jupyter_server.utils import url_path_join -from traitlets import Unicode +from traitlets import Dict, Integer, List, Unicode -from .handlers import ExampleHandler, ExecuteHandler +from .handlers import DatabasesHandler, DatabaseSchemaHandler, ExampleHandler, ExecuteHandler from .sqlconnector import SQLConnector @@ -13,9 +13,29 @@ class JupyterSqlCell(ExtensionApp): name = "JupyterSqlCell" default_url = "/jupyter-sql-cell" - db_url = Unicode( - "", - help="The database URL" + database = Dict(per_key_traits={ + "alias": Unicode(default_value=None, allow_none=True), + "database": Unicode(), + "dbms": Unicode(), + "driver": Unicode(default_value=None, allow_none=True), + "host": Unicode(default_value=None, allow_none=True), + "port": Integer(default_value=None, allow_none=True) + }, + default_value={}, + help="The databases description" + ).tag(config=True) + + databases = List( + Dict(per_key_traits={ + "alias": Unicode(default_value=None, allow_none=True), + "database": Unicode(), + "dbms": Unicode(), + "driver": Unicode(default_value=None, allow_none=True), + "host": Unicode(default_value=None, allow_none=True), + "port": Integer(default_value=None, allow_none=True) + }), + default_value=[], + help="The databases description", ).tag(config=True) @@ -24,17 +44,36 @@ def __init__(self) -> None: def initialize(self): path = pathlib.Path(__file__) - if not self.db_url: + if self.database: + self.databases.append(self.database) + + if not self.databases: path = pathlib.Path(__file__).parent / "tests" / "data" / "world.sqlite" - self.db_url = f"sqlite+aiosqlite:///{path}" - SQLConnector.db_url = self.db_url + self.databases = [{ + "alias": "default", + "database": str(path), + "dbms": "sqlite", + "driver": None, + "host": None, + "port": None + }] + for database in self.databases: + for option in ["alias", "driver", "host", "port"]: + if not option in database.keys(): + database[option] = None + SQLConnector.add_database(database) + return super().initialize() def initialize_handlers(self): super().initialize_handlers() - example_pattern = url_path_join("/jupyter-sql-cell", "get-example") - execute_pattern = url_path_join("/jupyter-sql-cell", "execute") + example_pattern = url_path_join(self.default_url, "get-example") + databases_pattern = url_path_join(self.default_url, "databases") + execute_pattern = url_path_join(self.default_url, "execute") + schema_pattern = url_path_join(self.default_url, "schema") self.handlers.extend([ + (databases_pattern, DatabasesHandler), (example_pattern, ExampleHandler), - (execute_pattern, ExecuteHandler) + (execute_pattern, ExecuteHandler), + (schema_pattern, DatabaseSchemaHandler) ]) diff --git a/jupyter_sql_cell/handlers.py b/jupyter_sql_cell/handlers.py index d8f9d28..eeb99d3 100644 --- a/jupyter_sql_cell/handlers.py +++ b/jupyter_sql_cell/handlers.py @@ -6,6 +6,24 @@ from .sqlconnector import SQLConnector +def reply_error(api: APIHandler, msg: StopIteration): + api.set_status(500) + api.log.error(msg) + reply = {"message": msg} + api.finish(json.dumps(reply)) + + +class DatabasesHandler(APIHandler): + @tornado.web.authenticated + def get(self): + try: + databases = SQLConnector.get_databases() + self.finish(json.dumps(databases)) + except Exception as e: + self.log.error(f"Databases error\n{e}") + self.write_error(500, exec_info=e) + + class ExecuteHandler(APIHandler): # The following decorator should be present on all verb methods (head, get, post, # patch, put, delete, options) to ensure only authorized user can request the @@ -13,18 +31,74 @@ class ExecuteHandler(APIHandler): @tornado.gen.coroutine @tornado.web.authenticated def post(self): - query = json.loads(self.request.body).get("query", None) + body = json.loads(self.request.body) + id = body.get("id", None) + query = body.get("query", None) + if id is None: + reply_error(self, "The database id has not been provided") + return + if not query: + reply_error(self, "No query has been provided") + return try: - connector = SQLConnector() + connector = SQLConnector(int(id)) + if connector.errors: + reply_error(self, connector.errors[0]) + return except Exception as e: self.log.error(f"Connector error\n{e}") self.write_error(500, exec_info=e) + return try: result = yield connector.execute(query) self.finish(json.dumps({ - "data": result + "alias": connector.database["alias"], + "data": result, + "id": id, + "query": query, + })) + except Exception as e: + self.log.error(f"Query error\n{e}") + self.write_error(500, exec_info=e) + + +class DatabaseSchemaHandler(APIHandler): + @tornado.gen.coroutine + @tornado.web.authenticated + def get(self): + id = self.get_argument("id", "") + target = self.get_argument("target", "tables") + table = self.get_argument("table", "") + + if not id: + reply_error(self, "The database id has not been provided") + return + if target not in ["tables", "columns"]: + reply_error(self, "Target must be \"tables\" or \"columns\"") + return + if target == "columns" and not table: + reply_error(self, "The table has not been provided") + return + + try: + connector = SQLConnector(int(id)) + if connector.errors: + reply_error(self, connector.errors[0]) + return + except Exception as e: + self.log.error(f"Connector error\n{e}") + self.write_error(500, exec_info=e) + return + + try: + data = yield connector.get_schema(target, table) + self.finish(json.dumps({ + "data": data, + "id": id, + "table": table, + "target": target })) except Exception as e: self.log.error(f"Query error\n{e}") diff --git a/jupyter_sql_cell/sqlconnector.py b/jupyter_sql_cell/sqlconnector.py index a250bcb..893b386 100644 --- a/jupyter_sql_cell/sqlconnector.py +++ b/jupyter_sql_cell/sqlconnector.py @@ -1,19 +1,71 @@ -from jupyter_server import log -from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy import CursorResult, text -from typing import Any, Dict, List +from sqlalchemy.exc import InvalidRequestError, NoSuchModuleError +from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine +from sqlalchemy import CursorResult, Inspector, URL, create_engine, inspect, text +from typing import Any, Dict, List, Optional, TypedDict + +ASYNC_DRIVERS = { + "mariadb": ["asyncmy", "aiomysql"], + "mysql": ["asyncmy", "aiomysql"], + "postgres": ["asyncpg", "psycopg"], + "sqlite": ["aiosqlite"], +} + + +class DatabaseDesc(TypedDict): + alias: Optional[str] + database: str + dbms: str + driver: Optional[str] + host: Optional[str] + port: Optional[int] + + +class Database(TypedDict): + alias: str + id: int + is_async: bool + url: URL + + +class DatabaseSummary(DatabaseDesc): + id: int + is_async: bool class SQLConnector: - db_url: str = "" + databases: [Database] = [] + warnings = [] + + def __init__(self, database_id: int): + self.engine = None + self.errors = [] + self.database: Database = next(filter(lambda db: db["id"] == database_id, self.databases), None) - engine = None + if not self.database: + self.errors.append(f"There is no registered database with id {database_id}") + else: + if self.database["is_async"]: + self.engine = create_async_engine(self.database["url"]) + else: + self.engine = create_engine(self.database["url"]) - def __init__(self) -> None: - if not self.db_url: - log.warn("The database URL is not set") - self.engine = create_async_engine(self.db_url) + async def get_schema(self, target: str, table: str = "") -> [str]: + if self.database["is_async"]: + async with self.engine.connect() as conn: + schema = await conn.run_sync(self.use_inspector, target, table) + else: + with self.engine.connect() as conn: + schema = self.use_inspector(conn, target, table) + return schema + + def use_inspector(self, conn: AsyncConnection, target: str, table: str) -> [str]: + inspector: Inspector = inspect(conn) + if target == "tables": + return inspector.get_table_names() + elif target == "columns": + columns = inspector.get_columns(table) + return sorted([column['name'] for column in columns]) async def execute(self, query: str) -> str: if not self.engine: @@ -27,6 +79,79 @@ async def execute_request(self, query: str) -> CursorResult[Any]: cursor: CursorResult[Any] = await connection.execute(text(query)) return cursor + @classmethod + def add_database(cls, db_desc: DatabaseDesc): + id = 0 + if cls.databases: + id = max([db["id"] for db in cls.databases]) + 1 + + if db_desc["alias"]: + alias = db_desc["alias"] + else: + alias = f"{db_desc['dbms']}_{id}" + + if db_desc["driver"]: + drivers = [db_desc["driver"]] + else: + drivers = ASYNC_DRIVERS.get(db_desc["dbms"], []) + + for driver in drivers: + url = URL.create( + drivername=f"{db_desc['dbms']}+{driver}", + host=db_desc["host"], + port=db_desc["port"], + database=db_desc["database"] + ) + try: + create_async_engine(url) + cls.databases.append({ + "alias": alias, + "id": id, + "url": url, + "is_async": True + }) + return + except (InvalidRequestError, NoSuchModuleError): + # InvalidRequestError is raised if the driver is not async. + # NoSuchModuleError is raised if the driver is not installed. + continue + + driver = f"+{db_desc['driver']}" if db_desc["driver"] else "" + url = URL.create( + drivername=f"{db_desc['dbms']}{driver}", + host=db_desc["host"], + port=db_desc["port"], + database=db_desc["database"] + ) + create_engine(url) + cls.databases.append({ + "alias": alias, + "id": id, + "url": url, + "is_async": False + }) + cls.warnings.append("No async driver found, the query will be executed synchronously") + print(cls.warnings[-1]) + + @classmethod + def get_databases(cls): + summary_databases: [DatabaseSummary] = [] + for database in cls.databases: + url: URL = database["url"] + summary: DatabaseSummary = { + "alias": database["alias"], + "database": url.database, + "driver": url.drivername, + "id": database["id"], + "is_async": database["is_async"] + } + if url.host: + summary["host"] = url.host + if url.port: + summary["port"] = url.port + summary_databases.append(summary) + return summary_databases + @staticmethod def to_list(cursor: CursorResult[Any]) -> List[Dict]: return [row._asdict() for row in cursor.fetchall()] diff --git a/jupyter_sql_cell/tests/data/chinook.db b/jupyter_sql_cell/tests/data/chinook.db new file mode 100644 index 0000000..38a98b3 Binary files /dev/null and b/jupyter_sql_cell/tests/data/chinook.db differ diff --git a/jupyter_sql_cell/tests/test_connector.py b/jupyter_sql_cell/tests/test_connector.py new file mode 100644 index 0000000..dd620f3 --- /dev/null +++ b/jupyter_sql_cell/tests/test_connector.py @@ -0,0 +1,107 @@ +import pytest +from pathlib import Path +from jupyter_sql_cell.sqlconnector import SQLConnector + +def teardown_function(): + SQLConnector.databases = [] + + +@pytest.fixture +def db_path() -> Path: + return Path(__file__).parent / "data" / "world.sqlite" + + +@pytest.fixture +def add_database(db_path): + SQLConnector.add_database({ + "alias": "default", + "database": str(db_path), + "dbms": "sqlite", + "driver": None, + "host": None, + "port": None + }) + + +@pytest.fixture +def add_sync_database(db_path): + SQLConnector.add_database({ + "alias": "default", + "database": str(db_path), + "dbms": "sqlite", + "driver": "pysqlite", + "host": None, + "port": None + }) + + +""" +Should create an SqlConnector object without database. +""" +async def test_sql_connector_init(): + assert len(SQLConnector.databases) == 0 + connector = SQLConnector(0) + assert type(connector) == SQLConnector + assert len(connector.databases) == 0 + assert len(connector.errors) == 1 + assert "no registered database with id 0" in connector.errors[0] + + +""" +Should add an async database. +""" +async def test_sql_connector_without_driver(add_database): + assert len(SQLConnector.databases) == 1 + assert len(SQLConnector.warnings) == 0 + connector = SQLConnector(0) + assert len(connector.errors) == 0 + + +""" +Should add a sync database. +""" +async def test_sql_connector_with_sync_driver(add_sync_database): + assert len(SQLConnector.databases) == 1 + assert len(SQLConnector.warnings) == 1 + connector = SQLConnector(0) + assert len(connector.errors) == 0 + + +""" +Should return tables list on async database. +""" +async def test_schema_tables(add_database): + connector = SQLConnector(0) + schema = await connector.get_schema("tables") + assert len(schema) == 1 + assert schema == ["world"] + + +""" +Should return tables list on sync database. +""" +async def test_schema_tables_sync(add_sync_database): + connector = SQLConnector(0) + schema = await connector.get_schema("tables") + assert len(schema) == 1 + assert schema == ["world"] + + +""" +Should return columns list on async database. +""" +async def test_schema_columns(add_database): + connector = SQLConnector(0) + schema = await connector.get_schema("columns", "world") + assert len(schema) == 35 + assert "Abbreviation" in schema + + +""" +Should return columns list on sync database. +""" +async def test_schema_columns_sync(add_sync_database): + connector = SQLConnector(0) + schema = await connector.get_schema("columns", "world") + assert len(schema) == 35 + assert "Abbreviation" in schema diff --git a/jupyter_sql_cell/tests/test_handlers.py b/jupyter_sql_cell/tests/test_handlers.py index 92853b0..2b4cbbf 100644 --- a/jupyter_sql_cell/tests/test_handlers.py +++ b/jupyter_sql_cell/tests/test_handlers.py @@ -1,4 +1,6 @@ +import pytest import json +from tornado.httpclient import HTTPClientError async def test_get_example(jp_fetch): @@ -12,16 +14,121 @@ async def test_get_example(jp_fetch): "data": "This is /jupyter-sql-cell/get-example endpoint!" } -async def test_execute(jp_fetch): + +""" +Should load and query the default database when none has been provided in config. +""" +async def test_execute_default(jp_fetch): + query = "SELECT Abbreviation FROM world WHERE Country='France'" response = await jp_fetch( "jupyter-sql-cell", "execute", - body=json.dumps({"query": "SELECT Abbreviation FROM world WHERE Country='France'"}), + body=json.dumps({ + "query": query, + "id": "0" + }), method="POST" ) assert response.code == 200 payload = json.loads(response.body) - assert payload == { - "data": [{"Abbreviation": "FR"}] - } + assert list(payload.keys()) == ["alias", "data", "id", "query"] + assert payload["data"] == [{"Abbreviation": "FR"}] + assert payload["query"] == query + assert payload["id"] == "0" + + +""" +Should raise if the database ID has not been provided. +""" +async def test_execute_no_id(jp_fetch): + with pytest.raises(HTTPClientError): + response = await jp_fetch( + "jupyter-sql-cell", + "execute", + body=json.dumps({ + "query": "SELECT Abbreviation FROM world WHERE Country='France'" + }), + method="POST" + ) + + +""" +Should raise if the query has not been provided. +""" +async def test_execute_no_query(jp_fetch): + with pytest.raises(HTTPClientError): + response = await jp_fetch( + "jupyter-sql-cell", + "execute", + body=json.dumps({ + "id": "0" + }), + method="POST" + ) + + +""" +Should return the tables list. +""" +async def test_get_tables(jp_fetch): + response = await jp_fetch( + "jupyter-sql-cell", + "schema", + params=[("id", "0"), ("target", "tables")] + ) + assert response.code == 200 + payload = json.loads(response.body) + assert payload["data"] == ["world"] + + +""" +Should return the tables list if no target is defined. +""" +async def test_get_tables_no_target(jp_fetch): + response = await jp_fetch( + "jupyter-sql-cell", + "schema", + params=[("id", "0")] + ) + assert response.code == 200 + payload = json.loads(response.body) + assert payload["data"] == ["world"] + + +""" +Should return the column names. +""" +async def test_get_columns(jp_fetch): + response = await jp_fetch( + "jupyter-sql-cell", + "schema", + params=[("id", "0"), ("target", "columns"), ("table", "world")] + ) + assert response.code == 200 + payload = json.loads(response.body) + assert "Abbreviation" in payload["data"] + + +""" +Should raise if the table is not provided. +""" +async def test_get_columns_no_table(jp_fetch): + with pytest.raises(HTTPClientError): + response = await jp_fetch( + "jupyter-sql-cell", + "schema", + params=[("id", "0"), ("target", "columns")] + ) + + +""" +Should raise if the target is wrong. +""" +async def test_get_schema_wrong_target(jp_fetch): + with pytest.raises(HTTPClientError): + response = await jp_fetch( + "jupyter-sql-cell", + "schema", + params=[("id", "0"), ("target", "fake")] + ) diff --git a/src/index.ts b/src/index.ts index 900bf83..5f19a18 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,7 @@ import { Parser } from '@json2csv/plainjs'; import { + ILabShell, + ILayoutRestorer, JupyterFrontEnd, JupyterFrontEndPlugin } from '@jupyterlab/application'; @@ -9,13 +11,20 @@ import { IDefaultFileBrowser } from '@jupyterlab/filebrowser'; import { INotebookTracker, NotebookPanel } from '@jupyterlab/notebook'; import { Contents, ContentsManager } from '@jupyterlab/services'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; +import { ITranslator, nullTranslator } from '@jupyterlab/translation'; import { runIcon } from '@jupyterlab/ui-components'; import { CustomContentFactory } from './cellfactory'; import { requestAPI } from './handler'; import { CommandIDs, SQL_MIMETYPE, SqlCell } from './common'; +import { Databases, DATABASE_METADATA } from './sidepanel'; import { SqlWidget } from './widget'; +/** + * The sql-cell namespace token. + */ +const namespace = 'sql-cell'; + /** * Load the commands and the cell toolbar buttons (from settings). */ @@ -45,11 +54,17 @@ const plugin: JupyterFrontEndPlugin = { if (!(activeCell?.model.type === 'raw')) { return; } + const database_id = + activeCell.model.getMetadata(DATABASE_METADATA)['id']; + + if (database_id === undefined) { + console.error('The database has not been set.'); + } const date = new Date(); const source = activeCell?.model.sharedModel.getSource(); requestAPI('execute', { method: 'POST', - body: JSON.stringify({ query: source }) + body: JSON.stringify({ query: source, id: database_id }) }) .then(data => { Private.saveData(path, data.data, date, fileBrowser) @@ -57,9 +72,7 @@ const plugin: JupyterFrontEndPlugin = { .catch(undefined); }) .catch(reason => { - console.error( - `The jupyter_sql_cell server extension appears to be missing.\n${reason}` - ); + console.error(reason); }); }, isEnabled: () => SqlCell.isSqlCell(tracker.activeCell?.model), @@ -118,6 +131,44 @@ const cellFactory: JupyterFrontEndPlugin = { } }; +/** + * The side panel to handle the list of databases. + */ +const databasesList: JupyterFrontEndPlugin = { + id: '@jupyter/sql-cell:databases-list', + description: 'The side panel which handle databases list.', + autoStart: true, + optional: [ILabShell, ILayoutRestorer, INotebookTracker, ITranslator], + activate: ( + app: JupyterFrontEnd, + labShell: ILabShell, + restorer: ILayoutRestorer | null, + tracker: INotebookTracker | null, + translator: ITranslator | null + ) => { + const { shell } = app; + if (!translator) { + translator = nullTranslator; + } + const panel = new Databases({ tracker, translator }); + + // Restore the widget state + if (restorer) { + restorer.add(panel, namespace); + } + + if (labShell) { + labShell.currentChanged.connect( + (_: ILabShell, args: ILabShell.IChangedArgs) => { + panel.mainAreaWidgetChanged(args.newValue); + } + ); + } + + shell.add(panel, 'left'); + } +}; + /** * The notebook toolbar widget. */ @@ -161,7 +212,7 @@ const notebookToolbarWidget: JupyterFrontEndPlugin = { } }; -export default [cellFactory, notebookToolbarWidget, plugin]; +export default [cellFactory, databasesList, notebookToolbarWidget, plugin]; namespace Private { /** diff --git a/src/sidepanel.tsx b/src/sidepanel.tsx new file mode 100644 index 0000000..2ad1a4d --- /dev/null +++ b/src/sidepanel.tsx @@ -0,0 +1,511 @@ +import { Cell, ICellModel } from '@jupyterlab/cells'; +import { INotebookTracker } from '@jupyterlab/notebook'; +import { ITranslator } from '@jupyterlab/translation'; +import { + LabIcon, + PanelWithToolbar, + ReactWidget, + SidePanel, + ToolbarButton, + UseSignal, + caretDownIcon, + caretRightIcon, + deleteIcon, + tableRowsIcon +} from '@jupyterlab/ui-components'; +import { Signal } from '@lumino/signaling'; +import { AccordionPanel, Panel, Widget } from '@lumino/widgets'; +import * as React from 'react'; + +import { SqlCell } from './common'; +import { requestAPI } from './handler'; +import databaseSvgstr from '../style/icons/database.svg'; + +/** + * The metadata key to store the database. + */ +export const DATABASE_METADATA = 'sqlcell-database'; + +/** + * The class of the side panel. + */ +const DATABASES_CLASS = 'jp-sqlcell-databases-panel'; +/** + * The class of the side panel. + */ +const DATABASE_CLASS = 'jp-sqlcell-database'; +/** + * The class of the database toolbar. + */ +const TOOLBAR_CLASS = 'jp-sqlcell-database-toolbar'; +/** + * The class of the button in database toolbar. + */ +const SELECT_BUTTON_CLASS = 'jp-sqlcell-database-selectbutton'; +/** + * The class of the body of the database. + */ +const DATABASE_BODY_CLASS = 'jp-sqlcell-database-body'; +/** + * The class of tables list. + */ +const TABLE_ITEMS_CLASS = 'jp-sqlcell-table-items'; +/** + * The class of table name. + */ +const TABLE_TITLE_CLASS = 'jp-sqlcell-table-title'; +/** + * The class of the column item. + */ +const COLUMN_ITEMS_CLASS = 'jp-sqlcell-column-items'; + +/** + * The database icon. + */ +const databaseIcon = new LabIcon({ + name: 'sql-cell:database', + svgstr: databaseSvgstr +}); + +/** + * The side panel containing the list of the databases. + */ +export class Databases extends SidePanel { + /** + * Constructor of the databases list. + */ + constructor(options: Databases.IOptions) { + super({ translator: options.translator }); + this.id = 'jp-sql-cell-sidebar'; + this.addClass(DATABASES_CLASS); + this.title.icon = databaseIcon; + this.title.caption = 'Databases'; + this._tracker = options.tracker; + + requestAPI('databases') + .then(data => { + this._buildDatabaseSections(data, options.tracker); + }) + .catch(reason => { + console.error(reason); + }); + + const content = this.content as AccordionPanel; + content.expansionToggled.connect(this._onExpansionToogled, this); + this._tracker?.activeCellChanged.connect(this.activeCellChanged, this); + } + + /** + * Triggered when the main area widget changes. + * + * @param widget - the current main area widget. + */ + mainAreaWidgetChanged(widget: Widget | null) { + if (widget && widget === this._tracker?.currentWidget) { + if (!this._isNotebook) { + this._isNotebook = true; + this.updateSelectButtons(this._tracker?.activeCell?.model); + } + } else { + if (this._isNotebook) { + this._isNotebook = false; + this.updateSelectButtons(this._tracker?.activeCell?.model); + } + } + } + + /** + * Triggered when the active cell changes. + */ + activeCellChanged = (_: INotebookTracker, cell: Cell | null) => { + this._currentCell?.model.metadataChanged.disconnect( + this.cellMetadataChanged, + this + ); + + this._currentCell = cell; + this.updateSelectButtons(cell?.model); + + this._currentCell?.model.metadataChanged.connect( + this.cellMetadataChanged, + this + ); + }; + + /** + * Triggered when the active cell metadata changes. + */ + cellMetadataChanged = (cellModel: ICellModel) => { + this.updateSelectButtons(cellModel); + }; + + /** + * Updates the status of the toolbar button to select the cell database. + * + * @param cellModel - the active cell model. + */ + updateSelectButtons = (cellModel: ICellModel | undefined) => { + const enabled = this._isNotebook && SqlCell.isSqlCell(cellModel); + this.widgets.forEach(widget => { + (widget as DatabaseSection).updateSelectButton( + this._isNotebook, + enabled, + cellModel + ); + }); + }; + + /** + * Build the database sections. + * + * @param databases - the databases description. + * @param tracker - the notebook tracker. + */ + private _buildDatabaseSections( + databases: Databases.IDatabase[], + tracker: INotebookTracker | null + ) { + const content = this.content as AccordionPanel; + databases.forEach(database => { + this.addWidget(new DatabaseSection({ database, tracker })); + content.collapse(content.widgets.length - 1); + }); + } + + /** + * Triggered when the section is expanded. + */ + private _onExpansionToogled(_: AccordionPanel, index: number) { + const section = this.widgets[index] as DatabaseSection; + if (section.isVisible) { + section.onExpand(); + } + } + + private _isNotebook: boolean = false; + private _tracker: INotebookTracker | null; + private _currentCell: Cell | null = null; +} + +/** + * Namespace for the databases side panel. + */ +namespace Databases { + /** + * Options of the databases side panel's constructor. + */ + export interface IOptions { + /** + * The notebook tracker. + */ + tracker: INotebookTracker | null; + /** + * The translator. + */ + translator: ITranslator; + } + + /** + * Database object returned from server request. + */ + export interface IDatabase { + alias: string; + database: string; + driver: string; + id: number; + is_async: boolean; + host?: string; + port?: number; + } +} + +/** + * The database section containing the list of the tables. + */ +class DatabaseSection extends PanelWithToolbar { + constructor(options: DatabaseSection.IOptions) { + super(options); + this._database = options.database; + this._tracker = options.tracker; + this.addClass(DATABASE_CLASS); + this.title.label = this._database.alias; + this.title.caption = this._tooltip(); + this.toolbar.addClass(TOOLBAR_CLASS); + + this._selectButton = new ToolbarButton({ + label: 'SELECT', + className: `${SELECT_BUTTON_CLASS} jp-mod-styled`, + enabled: SqlCell.isSqlCell(this._tracker?.activeCell?.model), + onClick: () => { + const model = this._tracker?.activeCell?.model; + if (this._selectButton.pressed) { + model?.deleteMetadata(DATABASE_METADATA); + this.updateSelectButton(true, true, model); + } else if (model && SqlCell.isSqlCell(model)) { + model.setMetadata(DATABASE_METADATA, this._database); + (this.parent?.parent as Databases)?.updateSelectButtons( + this._tracker?.activeCell?.model + ); + } + } + }); + this.toolbar.addItem('SqlCell-database-select', this._selectButton); + + const deleteButton = new ToolbarButton({ + icon: deleteIcon, + className: 'jp-mod-styled', + onClick: () => { + console.log('should remove the database'); + } + }); + this.toolbar.addItem('SqlCell-database-delete', deleteButton); + + this._body = new TablesList({ database_id: this._database.id }); + this._body.addClass(DATABASE_BODY_CLASS); + this.addWidget(this._body); + } + + /** + * Update the select button status. + * + * @param enabled - whether the button is enabled or not. + * @param cellModel - the active cell model. + */ + updateSelectButton( + visible: boolean, + enabled: boolean, + cellModel: ICellModel | undefined + ) { + const button = this._selectButton; + + if (visible) { + button.removeClass('lm-mod-hidden'); + } else { + button.addClass('lm-mod-hidden'); + } + + button.enabled = enabled; + const metadata = cellModel?.getMetadata(DATABASE_METADATA); + button.pressed = Private.databaseMatch(this._database, metadata); + + // FIXME: should be implemented in ToolbarButton. + button.node.ariaPressed = button.pressed.toString(); + } + + /** + * request the server to get the table list on first expand. + */ + onExpand() { + if (!this._tables.length) { + const searchParams = new URLSearchParams({ + id: this._database.id.toString(), + target: 'tables' + }); + requestAPI(`schema?${searchParams.toString()}`) + .then(response => { + this._tables = (response as DatabaseSection.IDatabaseSchema).data; + this._body.updateTables(this._tables); + }) + .catch(reason => { + console.error(reason); + }); + } + } + + /** + * Build the tooltip text of the toolbar. + */ + private _tooltip() { + let tooltip = ''; + let key: keyof Databases.IDatabase; + for (key in this._database) { + tooltip = tooltip + `${key}: ${this._database[key]?.toString()}\n`; + } + return tooltip; + } + + private _database: Databases.IDatabase; + private _tracker: INotebookTracker | null; + private _body: TablesList; + private _tables: string[] = []; + private _selectButton: ToolbarButton; +} + +/** + * Namespace for the database section. + */ +namespace DatabaseSection { + /** + * Options for the DatabaseSection constructor. + */ + export interface IOptions extends Panel.IOptions { + /** + * The database description. + */ + database: Databases.IDatabase; + /** + * The notebook tracker. + */ + tracker: INotebookTracker | null; + } + + /** + * Schema object returned from server request. + */ + export interface IDatabaseSchema { + data: string[]; + id: number; + table: string; + target: 'tables' | 'columns'; + } +} + +/** + * The tables list. + */ +class TablesList extends ReactWidget { + /** + * Constructor of the tables list. + * @param options - contains the database_id. + */ + constructor(options: { database_id: number }) { + super(); + this._database_id = options.database_id; + } + + updateTables(tables: string[]) { + this._update.emit(tables); + } + + render(): JSX.Element { + const database_id = this._database_id; + return ( +
    + + {(_, tables) => { + return tables?.map(table => ( + + )); + }} + + + ); + } + + private _database_id: number; + private _update: Signal = new Signal(this); +} + +/** + * The table item. + * + * @param database_id - the id of the database to which this table belongs. + * @param name - the name of the table. + */ +const Table = ({ + database_id, + name +}: { + database_id: number; + name: string; +}): JSX.Element => { + const expandedClass = 'lm-mod-expanded'; + const [columns, updateColumns] = React.useState([]); + const [expanded, expand] = React.useState(false); + + /** + * Handle the click on the table name. + * + * @param event - the mouse event + */ + const handleClick = (event: React.MouseEvent) => { + const target = event.target as HTMLDivElement; + if (expanded) { + target.classList.remove(expandedClass); + updateColumns([]); + expand(false); + } else { + target.classList.add(expandedClass); + const searchParams = new URLSearchParams({ + id: database_id.toString(), + table: name, + target: 'columns' + }); + requestAPI(`schema?${searchParams.toString()}`) + .then(response => { + updateColumns((response as DatabaseSection.IDatabaseSchema).data); + expand(true); + }) + .catch(reason => { + console.error(reason); + }); + } + }; + + return ( +
  • +
    + {expanded ? ( + + ) : ( + + )} + + {name} +
    + +
  • + ); +}; + +/** + * The columns list. + * + * @param columns - the list of columns name. + * @returns + */ +const ColumnsList = ({ columns }: { columns: string[] }): JSX.Element => { + return ( +
      + {columns.map(column => ( +
    • {column}
    • + ))} +
    + ); +}; + +namespace Private { + export function databaseMatch( + db1: Databases.IDatabase, + db2: Databases.IDatabase + ): boolean { + if (!db1 || !db2) { + return false; + } + + const keys1 = Object.keys(db1); + const keys2 = Object.keys(db2); + + if (keys1.length !== keys2.length) { + return false; + } + + for (const key of keys1) { + if ( + db1[key as keyof Databases.IDatabase] !== + db2[key as keyof Databases.IDatabase] + ) { + return false; + } + } + + return true; + } +} diff --git a/src/svg.d.ts b/src/svg.d.ts new file mode 100644 index 0000000..d5419ae --- /dev/null +++ b/src/svg.d.ts @@ -0,0 +1,13 @@ +// Copyright (c) Jupyter Development Team. +// Distributed under the terms of the Modified BSD License. + +// including this file in a package allows for the use of import statements +// with svg files. Example: `import xSvg from 'path/xSvg.svg'` + +// for use with raw-loader in Webpack. +// The svg will be imported as a raw string + +declare module '*.svg' { + const value: string; // @ts-ignore + export default value; +} diff --git a/src/widget.tsx b/src/widget.tsx index ea40da1..e1397f0 100644 --- a/src/widget.tsx +++ b/src/widget.tsx @@ -67,7 +67,7 @@ export class SqlWidget extends ReactWidget { {() => (
    SQL cell diff --git a/style/base.css b/style/base.css index dbc0f07..2da6af1 100644 --- a/style/base.css +++ b/style/base.css @@ -4,17 +4,61 @@ https://jupyterlab.readthedocs.io/en/stable/developer/css.html */ -.sql-cell-widget { +/* SIDE PANEL */ + +.jp-sqlcell-database-body .lm-mod-hidden { + display: none; +} + +ul.jp-sqlcell-table-items { + display: block; + margin: 0; + padding-inline-start: 1em; + list-style-type: none; +} + +.jp-sqlcell-table-title { + cursor: pointer; +} + +.jp-sqlcell-table-title:hover { + background-color: var(--jp-layout-color2); +} + +.jp-sqlcell-table-title > span { + vertical-align: middle; +} + +ul.jp-sqlcell-column-items { + padding-inline-start: 2em; + list-style-type: disc; +} + +.jp-sqlcell-database-selectbutton[aria-pressed='true'] { + background-color: var(--jp-inverse-layout-color3); +} + +.jp-sqlcell-database-selectbutton[aria-pressed='true']:hover { + background-color: var(--jp-inverse-layout-color4); +} + +.jp-sqlcell-database-selectbutton[aria-pressed='true'] span { + color: var(--jp-ui-inverse-font-color1); +} + +/* NOTEBOOK TOOLBAR WIDGET */ + +.jp-sqlcell-widget { display: flex; border: solid 1px; align-items: center; } -.sql-cell-widget[aria-disabled='true'] span:first-child { +.jp-sqlcell-widget[aria-disabled='true'] span:first-child { color: var(--jp-ui-font-color3); } -.sql-cell-widget .switch { +.jp-sqlcell-widget .switch { display: flex; justify-content: center; align-content: center; @@ -23,14 +67,14 @@ } /* Hide default HTML checkbox */ -.sql-cell-widget .switch input { +.jp-sqlcell-widget .switch input { opacity: 0; width: 0; height: 0; } /* The slider */ -.sql-cell-widget .slider { +.jp-sqlcell-widget .slider { cursor: pointer; height: 15px; width: 30px; @@ -40,7 +84,7 @@ border-radius: 15px; } -.sql-cell-widget .slider::before { +.jp-sqlcell-widget .slider::before { display: block; content: ''; height: 15px; @@ -51,30 +95,32 @@ border-radius: 50%; } -.sql-cell-widget input:checked + .slider { +.jp-sqlcell-widget input:checked + .slider { background-color: var(--md-blue-500); } -.sql-cell-widget input:focus + .slider { +.jp-sqlcell-widget input:focus + .slider { box-shadow: 0 0 1px var(--md-blue-500); } -.sql-cell-widget input:disabled + .slider { +.jp-sqlcell-widget input:disabled + .slider { background-color: var(--jp-layout-color2); cursor: not-allowed; } -.sql-cell-widget input:checked + .slider::before { +.jp-sqlcell-widget input:checked + .slider::before { -webkit-transform: translateX(15px); -ms-transform: translateX(15px); transform: translateX(15px); } /* stylelint-disable-next-line selector-class-pattern */ -.sql-cell-widget .jp-ToolbarButtonComponent { +.jp-sqlcell-widget .jp-ToolbarButtonComponent { height: 22px; } +/* CELL TOOLBAR */ + .jp-cell-toolbar .lm-mod-hidden { display: none; } diff --git a/style/icons/database.svg b/style/icons/database.svg new file mode 100644 index 0000000..0d4276c --- /dev/null +++ b/style/icons/database.svg @@ -0,0 +1,5 @@ + + + \ No newline at end of file diff --git a/ui-tests/jupyter_server_test_config.py b/ui-tests/jupyter_server_test_config.py index f2a9478..988ff44 100644 --- a/ui-tests/jupyter_server_test_config.py +++ b/ui-tests/jupyter_server_test_config.py @@ -4,9 +4,27 @@ opens the server to the world and provide access to JupyterLab JavaScript objects through the global window variable. """ +import os +from pathlib import Path from jupyterlab.galata import configure_jupyter_server configure_jupyter_server(c) +c.FileContentsManager.delete_to_trash = False + # Uncomment to set server log level to debug level # c.ServerApp.log_level = "DEBUG" + +# Link two test databases in the application +test_db_path = Path(__file__).parent.parent / "jupyter_sql_cell" / "tests" / "data" +databases = [] +for file in os.listdir(test_db_path): + file_path = test_db_path / file + databases.append({ + "database":str(test_db_path / file), + "dbms":"sqlite", + "driver":"aiosqlite", + "alias":file_path.stem + }) + +c.JupyterSqlCell.databases = databases diff --git a/ui-tests/tests/jupyter_sql_cell.spec.ts b/ui-tests/tests/jupyter_sql_cell.spec.ts index 1ef8b62..5a8d62b 100644 --- a/ui-tests/tests/jupyter_sql_cell.spec.ts +++ b/ui-tests/tests/jupyter_sql_cell.spec.ts @@ -1,8 +1,36 @@ -import { expect, galata, test } from '@jupyterlab/galata'; +import { + IJupyterLabPageFixture, + expect, + galata, + test +} from '@jupyterlab/galata'; +import { Locator } from '@playwright/test'; import * as path from 'path'; const fileName = 'simple.ipynb'; +async function openSidePanel(page: IJupyterLabPageFixture): Promise { + const tabBar = page.locator('.jp-SideBar.jp-mod-left'); + const button = tabBar?.locator('li[title="Databases"]'); + await button.click(); + const content = page.locator( + '#jp-left-stack .lm-StackedPanel-child:not(.lm-mod-hidden)' + ); + await expect(content).toHaveClass(/jp-sqlcell-databases-panel/); + return content; +} + +async function switchCellToSql( + page: IJupyterLabPageFixture, + index: number +): Promise { + await page.notebook.setCellType(index, 'raw'); + await (await page.notebook.getCellInput(index))?.click(); + await page + .locator('.jp-cell-toolbar [data-command="jupyter-sql-cell:switch"]') + .click(); +} + test.describe('cell toolbar', () => { test.beforeEach(async ({ page, request, tmpPath }) => { const contents = galata.newContentsHelper(request); @@ -95,10 +123,7 @@ test.describe('cell factory', () => { await expect(cells.nth(1)).not.toHaveClass(/jp-SqlCell/); await expect(cells.nth(2)).not.toHaveClass(/jp-SqlCell/); - await (await page.notebook.getCellInput(2))?.click(); - await page - .locator('.jp-cell-toolbar [data-command="jupyter-sql-cell:switch"]') - .click(); + await switchCellToSql(page, 2); await expect(cells.nth(2)).toHaveClass(/jp-SqlCell/); }); @@ -129,3 +154,195 @@ test.describe('cell factory', () => { ).toBe(null); }); }); + +test.describe('sidebar', () => { + test('There should be a database button on side panel', async ({ page }) => { + const tabBar = await page.sidebar.getTabBar('left'); + const button = await tabBar?.$('li[title="Databases"]'); + expect(button).not.toBeNull(); + expect(await button?.screenshot()).toMatchSnapshot('sidebar_icon.png'); + }); + + test('Side panel should have two database', async ({ page }) => { + const sidepanel = await openSidePanel(page); + const titles = sidepanel.locator('.jp-AccordionPanel-title'); + expect(titles).toHaveCount(2); + }); + + test('Should display tables list', async ({ page }) => { + const sidepanel = await openSidePanel(page); + const title = sidepanel.locator( + '.jp-AccordionPanel-title[aria-label="world Section"]' + ); + await title.locator('.lm-AccordionPanel-titleLabel').click(); + expect(await title.getAttribute('aria-expanded')).toBe('true'); + const tables = sidepanel.locator('.jp-sqlcell-table-title'); + expect(tables).toHaveCount(1); + expect(await tables.first().textContent()).toBe('world'); + }); + + test('Should display columns list', async ({ page }) => { + const sidepanel = await openSidePanel(page); + await sidepanel + .locator('.jp-AccordionPanel-title[aria-label="world Section"]') + .locator('.lm-AccordionPanel-titleLabel') + .click(); + const table = sidepanel.locator('.jp-sqlcell-table-title'); + await table.click(); + await expect(table).toHaveAttribute('aria-expanded', 'true'); + const columns = sidepanel.locator('.jp-sqlcell-column-items li'); + expect(columns).toHaveCount(35); + expect(columns.first()).toContainText('Abbreviation'); + }); +}); + +test.describe('connect database to cell', () => { + test.beforeEach(async ({ page, request, tmpPath }) => { + const contents = galata.newContentsHelper(request); + await contents.uploadFile( + path.resolve(__dirname, `./notebooks/${fileName}`), + `${tmpPath}/${fileName}` + ); + await page.notebook.openByPath(`${tmpPath}/${fileName}`); + await page.notebook.activate(fileName); + }); + + test.afterEach(async ({ request, tmpPath }) => { + const contents = galata.newContentsHelper(request); + await contents.deleteDirectory(tmpPath); + }); + + test('Connect button should be enable for SQL cell only', async ({ + page + }) => { + const sidepanel = await openSidePanel(page); + const button = sidepanel + .locator('.jp-AccordionPanel-title[aria-label="world Section"]') + .locator('.jp-sqlcell-database-selectbutton'); + + expect(button).toHaveAttribute('aria-disabled', 'true'); + expect(button).toHaveAttribute('aria-pressed', 'false'); + + expect(await button?.screenshot()).toMatchSnapshot( + 'connect_button_disabled.png' + ); + await (await page.notebook.getCell(1))?.click(); + expect(button).toHaveAttribute('aria-disabled', 'true'); + expect(button).toHaveAttribute('aria-pressed', 'false'); + + await (await page.notebook.getCell(2))?.click(); + expect(button).toHaveAttribute('aria-disabled', 'true'); + expect(button).toHaveAttribute('aria-pressed', 'false'); + + await switchCellToSql(page, 2); + expect(button).toHaveAttribute('aria-disabled', 'false'); + expect(button).toHaveAttribute('aria-pressed', 'false'); + expect(await button?.screenshot()).toMatchSnapshot( + 'connect_button_enabled.png' + ); + }); + + test('Connect button should be pressed on click', async ({ page }) => { + const sidepanel = await openSidePanel(page); + const button = sidepanel + .locator('.jp-AccordionPanel-title[aria-label="world Section"]') + .locator('.jp-sqlcell-database-selectbutton'); + + await switchCellToSql(page, 2); + expect(button).toHaveAttribute('aria-pressed', 'false'); + + await button.click(); + expect(button).toHaveAttribute('aria-pressed', 'true'); + expect(await button?.screenshot()).toMatchSnapshot( + 'connect_button_pressed.png' + ); + }); + + test('Should connect a database to a cell', async ({ page, tmpPath }) => { + const sidepanel = await openSidePanel(page); + const button = sidepanel + .locator('.jp-AccordionPanel-title[aria-label="world Section"]') + .locator('.jp-sqlcell-database-selectbutton'); + + await page.notebook.setCell(2, 'raw', 'SELECT * FROM world'); + await switchCellToSql(page, 2); + await button.click(); + + const execute = page.locator( + '.jp-cell-toolbar [data-command="jupyter-sql-cell:execute"]' + ); + await execute.click(); + + await page.sidebar.openTab('filebrowser'); + const files = page.locator('li.jp-DirListing-item'); + await page.filebrowser.openDirectory(`${tmpPath}/_sql_output`); + + await expect(files).toHaveCount(1); + expect(files.first()).toHaveAttribute('data-file-type', 'csv'); + }); + + test('Should not create an output file with wrong query', async ({ + page, + tmpPath + }) => { + const sidepanel = await openSidePanel(page); + const button = sidepanel + .locator('.jp-AccordionPanel-title[aria-label="world Section"]') + .locator('.jp-sqlcell-database-selectbutton'); + + await page.notebook.setCell(2, 'raw', 'SELECT * FROM albums'); + await switchCellToSql(page, 2); + await button.click(); + + const execute = page.locator( + '.jp-cell-toolbar [data-command="jupyter-sql-cell:execute"]' + ); + await execute.click(); + + await page.sidebar.openTab('filebrowser'); + const opening = await page.filebrowser.openDirectory( + `${tmpPath}/_sql_output` + ); + expect(opening).toBeFalsy(); + }); + + test('Should connect several databases to several cells', async ({ + page, + tmpPath + }) => { + const sidepanel = await openSidePanel(page); + const button1 = sidepanel + .locator('.jp-AccordionPanel-title[aria-label="world Section"]') + .locator('.jp-sqlcell-database-selectbutton'); + const button2 = sidepanel + .locator('.jp-AccordionPanel-title[aria-label="chinook Section"]') + .locator('.jp-sqlcell-database-selectbutton'); + + await page.notebook.setCell(1, 'raw', 'SELECT * FROM world'); + await switchCellToSql(page, 1); + await button1.click(); + + await page.notebook.setCell(2, 'raw', 'SELECT * FROM albums'); + await switchCellToSql(page, 2); + await button2.click(); + + const execute = page.locator( + '.jp-cell-toolbar [data-command="jupyter-sql-cell:execute"]' + ); + await (await page.notebook.getCellInput(1))?.click(); + await execute.click(); + + await page.sidebar.openTab('filebrowser'); + const files = page.locator('li.jp-DirListing-item'); + + await page.filebrowser.openDirectory(`${tmpPath}/_sql_output`); + await expect(files).toHaveCount(1); + + await page.filebrowser.openDirectory(tmpPath); + + await (await page.notebook.getCellInput(2))?.click(); + await execute.click(); + await page.filebrowser.openDirectory(`${tmpPath}/_sql_output`); + await expect(files).toHaveCount(2); + }); +}); diff --git a/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-disabled-linux.png b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-disabled-linux.png new file mode 100644 index 0000000..efcb050 Binary files /dev/null and b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-disabled-linux.png differ diff --git a/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-enabled-linux.png b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-enabled-linux.png new file mode 100644 index 0000000..65c331d Binary files /dev/null and b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-enabled-linux.png differ diff --git a/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-pressed-linux.png b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-pressed-linux.png new file mode 100644 index 0000000..0031acd Binary files /dev/null and b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/connect-button-pressed-linux.png differ diff --git a/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/sidebar-icon-linux.png b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/sidebar-icon-linux.png new file mode 100644 index 0000000..d4ba2b8 Binary files /dev/null and b/ui-tests/tests/jupyter_sql_cell.spec.ts-snapshots/sidebar-icon-linux.png differ