-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
Co-authored-by: aries_ckt <[email protected]>
- Loading branch information
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
"""TuGraph Connector.""" | ||
import json | ||
from typing import Any, Dict, List, cast | ||
|
||
from .base import BaseConnector | ||
|
||
|
||
class TuGraphConnector(BaseConnector): | ||
"""TuGraph connector.""" | ||
|
||
db_type: str = "tugraph" | ||
driver: str = "bolt" | ||
dialect: str = "tugraph" | ||
|
||
def __init__(self, session): | ||
"""Initialize the connector with a Neo4j driver.""" | ||
self._session = session | ||
self._schema = None | ||
|
||
@classmethod | ||
def from_uri_db( | ||
cls, host: str, port: int, user: str, pwd: str, db_name: str, **kwargs: Any | ||
) -> "TuGraphConnector": | ||
"""Create a new TuGraphConnector from host, port, user, pwd, db_name.""" | ||
try: | ||
from neo4j import GraphDatabase | ||
|
||
db_url = f"{cls.driver}://{host}:{str(port)}" | ||
with GraphDatabase.driver(db_url, auth=(user, pwd)) as client: | ||
client.verify_connectivity() | ||
session = client.session(database=db_name) | ||
return cast(TuGraphConnector, cls(session=session)) | ||
except ImportError as err: | ||
raise ImportError("requests package is not installed") from err | ||
|
||
def get_table_names(self) -> Dict[str, List[str]]: | ||
"""Get all table names from the TuGraph database using the Neo4j driver.""" | ||
# Run the query to get vertex labels | ||
v_result = self._session.run("CALL db.vertexLabels()").data() | ||
v_data = [table_name["label"] for table_name in v_result] | ||
|
||
# Run the query to get edge labels | ||
e_result = self._session.run("CALL db.edgeLabels()").data() | ||
e_data = [table_name["label"] for table_name in e_result] | ||
return {"vertex_tables": v_data, "edge_tables": e_data} | ||
|
||
def get_grants(self): | ||
"""Get grants.""" | ||
return [] | ||
|
||
def get_collation(self): | ||
"""Get collation.""" | ||
return "UTF-8" | ||
|
||
def get_charset(self): | ||
"""Get character_set of current database.""" | ||
return "UTF-8" | ||
|
||
def table_simple_info(self): | ||
"""Get table simple info.""" | ||
return [] | ||
|
||
def close(self): | ||
"""Close the Neo4j driver.""" | ||
self._session.close() | ||
|
||
def run(self): | ||
"""Run GQL.""" | ||
return [] | ||
|
||
def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]: | ||
"""Get fields about specified graph. | ||
Args: | ||
table_name (str): table name (graph name) | ||
table_type (str): table type (vertex or edge) | ||
Returns: | ||
columns: List[Dict], which contains name: str, type: str, | ||
default_expression: str, is_in_primary_key: bool, comment: str | ||
eg:[{'name': 'id', 'type': 'int', 'default_expression': '', | ||
'is_in_primary_key': True, 'comment': 'id'}, ...] | ||
""" | ||
data = [] | ||
result = None | ||
if table_type == "vertex": | ||
result = self._session.run( | ||
f"CALL db.getVertexSchema('{table_name}')" | ||
).data() | ||
else: | ||
result = self._session.run(f"CALL db.getEdgeSchema('{table_name}')").data() | ||
schema_info = json.loads(result[0]["schema"]) | ||
for prop in schema_info.get("properties", []): | ||
prop_dict = { | ||
"name": prop["name"], | ||
"type": prop["type"], | ||
"default_expression": "", | ||
"is_in_primary_key": bool( | ||
"primary" in schema_info and prop["name"] == schema_info["primary"] | ||
), | ||
"comment": prop["name"], | ||
} | ||
data.append(prop_dict) | ||
return data | ||
|
||
def get_indexes(self, table_name: str, table_type: str = "vertex") -> List[Dict]: | ||
"""Get table indexes about specified table. | ||
Args: | ||
table_name:(str) table name | ||
table_type:(str)'vertex' | 'edge' | ||
Returns: | ||
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}] | ||
""" | ||
# [{'name':'id','column_names':['id']}] | ||
result = self._session.run( | ||
f"CALL db.listLabelIndexes('{table_name}','{table_type}')" | ||
).data() | ||
transformed_data = [] | ||
for item in result: | ||
new_dict = {"name": item["field"], "column_names": [item["field"]]} | ||
transformed_data.append(new_dict) | ||
return transformed_data | ||
|
||
@classmethod | ||
def is_graph_type(cls) -> bool: | ||
"""Return whether the connector is a graph database connector.""" | ||
return True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
"""Summary for rdbms database.""" | ||
|
||
from typing import TYPE_CHECKING, Dict, List, Optional | ||
|
||
from dbgpt._private.config import Config | ||
from dbgpt.datasource import BaseConnector | ||
from dbgpt.datasource.conn_tugraph import TuGraphConnector | ||
from dbgpt.rag.summary.db_summary import DBSummary | ||
|
||
if TYPE_CHECKING: | ||
from dbgpt.datasource.manages import ConnectorManager | ||
|
||
CFG = Config() | ||
|
||
|
||
class GdbmsSummary(DBSummary): | ||
"""Get graph db table summary template.""" | ||
|
||
def __init__( | ||
self, name: str, type: str, manager: Optional["ConnectorManager"] = None | ||
): | ||
"""Create a new RdbmsSummary.""" | ||
self.name = name | ||
self.type = type | ||
self.summary_template = "{table_name}({columns})" | ||
# self.v_summary_template = "{table_name}({columns})" | ||
self.tables = {} | ||
# self.tables_info = [] | ||
# self.vector_tables_info = [] | ||
|
||
# TODO: Don't use the global variable. | ||
db_manager = manager or CFG.local_db_manager | ||
if not db_manager: | ||
raise ValueError("Local db manage is not initialized.") | ||
self.db = db_manager.get_connector(name) | ||
|
||
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, | ||
collation:{collation}""".format( | ||
users=self.db.get_users(), | ||
grant=self.db.get_grants(), | ||
charset=self.db.get_charset(), | ||
collation=self.db.get_collation(), | ||
) | ||
tables = self.db.get_table_names() | ||
self.table_info_summaries = { | ||
"vertex_tables": [ | ||
self.get_table_summary(table_name, "vertex") | ||
for table_name in tables["vertex_tables"] | ||
], | ||
"edge_tables": [ | ||
self.get_table_summary(table_name, "edge") | ||
for table_name in tables["edge_tables"] | ||
], | ||
} | ||
|
||
def get_table_summary(self, table_name, table_type): | ||
"""Get table summary for table. | ||
example: | ||
table_name(column1(column1 comment),column2(column2 comment), | ||
column3(column3 comment) and index keys, and table comment: {table_comment}) | ||
""" | ||
return _parse_table_summary( | ||
self.db, self.summary_template, table_name, table_type | ||
) | ||
|
||
def table_summaries(self): | ||
"""Get table summaries.""" | ||
return self.table_info_summaries | ||
|
||
|
||
def _parse_db_summary( | ||
conn: BaseConnector, summary_template: str = "{table_name}({columns})" | ||
) -> List[str]: | ||
"""Get db summary for database.""" | ||
table_info_summaries = None | ||
if isinstance(conn, TuGraphConnector): | ||
table_names = conn.get_table_names() | ||
v_tables = table_names.get("vertex_tables", []) | ||
e_tables = table_names.get("edge_tables", []) | ||
table_info_summaries = [ | ||
_parse_table_summary(conn, summary_template, table_name, "vertex") | ||
for table_name in v_tables | ||
] + [ | ||
_parse_table_summary(conn, summary_template, table_name, "edge") | ||
for table_name in e_tables | ||
] | ||
else: | ||
table_info_summaries = [] | ||
|
||
return table_info_summaries | ||
|
||
|
||
def _format_column(column: Dict) -> str: | ||
"""Format a single column's summary.""" | ||
comment = column.get("comment", "") | ||
if column.get("is_in_primary_key"): | ||
comment += " Primary Key" if comment else "Primary Key" | ||
return f"{column['name']} ({comment})" if comment else column["name"] | ||
|
||
|
||
def _format_indexes(indexes: List[Dict]) -> str: | ||
"""Format index keys for table summary.""" | ||
return ", ".join( | ||
f"{index['name']}(`{', '.join(index['column_names'])}`)" for index in indexes | ||
) | ||
|
||
|
||
def _parse_table_summary( | ||
conn: TuGraphConnector, summary_template: str, table_name: str, table_type: str | ||
) -> str: | ||
"""Enhanced table summary function.""" | ||
columns = [ | ||
_format_column(column) for column in conn.get_columns(table_name, table_type) | ||
] | ||
column_str = ", ".join(columns) | ||
|
||
indexes = conn.get_indexes(table_name, table_type) | ||
index_str = _format_indexes(indexes) if indexes else "" | ||
|
||
table_str = summary_template.format(table_name=table_name, columns=column_str) | ||
if index_str: | ||
table_str += f", and index keys: {index_str}" | ||
try: | ||
comment = conn.get_table_comment(table_name) | ||
except Exception: | ||
comment = dict(text=None) | ||
if comment.get("text"): | ||
table_str += ( | ||
f", and table comment: {comment.get('text')}, this is a {table_type} table" | ||
) | ||
else: | ||
table_str += f", and table comment: this is a {table_type} table" | ||
return table_str |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -606,6 +606,7 @@ def all_datasource_requires(): | |
"pyhive", | ||
"thrift", | ||
"thrift_sasl", | ||
"neo4j", | ||
] | ||
|
||
|
||
|