-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils_sql.py
247 lines (210 loc) · 8.83 KB
/
utils_sql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import sqlparse
import pandas as pd
import json
import numpy as np
import ast
import re
def sql_to_list(sql_buffer):
    """
    Break a buffer holding several SQL statements into separate commands.

    Splitting is delegated to ``sqlparse.split``. Each resulting statement
    is stripped of surrounding whitespace, and statements that are empty
    after stripping are discarded.

    Args:
        sql_buffer (str): Text containing zero or more SQL statements.

    Returns:
        list: The individual statements, in order. An empty list is
        returned (after printing a message) when ``sql_buffer`` is not a
        string or when a ValueError/RuntimeError is raised while splitting.
    """
    try:
        if not isinstance(sql_buffer, str):
            raise ValueError("Input must be a string")

        # sqlparse does the statement splitting; the generator strips each
        # piece once, and blank pieces are filtered out below.
        stripped = (piece.strip() for piece in sqlparse.split(sql_buffer))
        return [statement for statement in stripped if statement]
    except ValueError as e:
        print(f"ValueError splitting SQL commands: {e}")
        return []
    except RuntimeError as e:
        print(f"Unexpected error splitting SQL commands: {e}")
        return []
def _vector_length(value):
    """
    Return the length of ``value`` when it is a non-empty, flat numeric vector.

    Accepted forms are a Python list of numbers, a one-dimensional NumPy
    array of numbers, or a string that ``ast.literal_eval`` parses into a
    list of numbers. Booleans are not treated as numbers.

    Args:
        value: The cell value to inspect.

    Returns:
        int or None: The vector length, or None when ``value`` is not a vector.
    """
    if isinstance(value, str):
        try:
            value = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval raises any of these on text that is not a literal
            # (e.g. TypeError for "{[1]: 2}"); such text is simply not a vector.
            return None
        if not isinstance(value, list):
            return None

    if isinstance(value, np.ndarray):
        if value.ndim != 1:
            return None
        if value.dtype.kind in "iuf":
            # Native integer/float arrays: no per-element check needed.
            return len(value) or None
        if value.dtype.kind != "O":
            return None
        value = value.tolist()
    elif not isinstance(value, list):
        return None

    if not value:
        # An empty sequence would produce a meaningless zero-length vector.
        return None

    for item in value:
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(item, (bool, np.bool_)):
            return None
        if not isinstance(item, (int, float, np.integer, np.floating)):
            return None
    return len(value)


def identify_vector_columns(df, sample_size=1000):
    """
    Identify columns in the DataFrame that contain vectors (lists or arrays of numbers).

    Only object and string dtype columns are inspected. A column is
    reported when every one of its first ``sample_size`` non-null values is
    a numeric vector and all of those vectors have the same length. Vectors
    may be Python lists, one-dimensional NumPy arrays (including NumPy
    scalar element types such as float32), or strings holding a list literal.

    Args:
        df (pandas.DataFrame): The DataFrame to inspect.
        sample_size (int): The number of non-null rows to check per column.

    Returns:
        dict: A dictionary where the keys are column names and the values
        are the length of the vectors.
    """
    vector_columns = {}
    for column in df.columns:
        series = df[column]
        dtype = series.dtype
        holds_python_objects = (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
        )
        if not holds_python_objects:
            continue

        sample = series.dropna().head(sample_size)
        if sample.empty:
            continue

        # The first value decides the expected length; it is checked first
        # so ordinary text columns are dismissed after a single parse.
        expected_length = _vector_length(sample.iloc[0])
        if expected_length is None:
            continue

        if all(_vector_length(value) == expected_length for value in sample.iloc[1:]):
            vector_columns[column] = expected_length
    return vector_columns
def is_json(value):
    """
    Report whether ``value`` is a string that parses as JSON.

    Args:
        value: The candidate value. Anything that is not a ``str`` is
            rejected without attempting to parse it.

    Returns:
        bool: True when ``json.loads`` accepts the string, False when the
        value is not a string or decoding raises ``json.JSONDecodeError``.
    """
    if isinstance(value, str):
        try:
            json.loads(value)
        except json.JSONDecodeError:
            return False
        else:
            return True
    return False
def sanitize_value(value):
    """
    Escape a value so it can be embedded inside a single-quoted SQL literal.

    The value is first converted to text: strings are used as they are,
    dicts are serialised with ``json.dumps``, and everything else goes
    through ``str()``. Single quotes are then doubled and backslashes are
    doubled.

    NOTE(review): doubling backslashes is only correct for SQL dialects
    that treat the backslash as an escape character inside string
    literals - confirm this matches the target database.

    Args:
        value: The value to sanitize (can be str, dict, or other types).

    Returns:
        str: The escaped text, without surrounding quotes.
    """
    if isinstance(value, dict):
        text = json.dumps(value)
    elif isinstance(value, str):
        text = value
    else:
        text = str(value)

    # The two replacements touch different characters, so they are independent.
    quote_escaped = text.replace("'", "''")
    return quote_escaped.replace("\\", "\\\\")
def generate_create_table_sql(df, table_name):
    """
    Generate a SQL CREATE TABLE statement based on the DataFrame structure.

    Column names have spaces and dots replaced with underscores. The
    DataFrame passed in is not modified; the renaming is applied to a
    shallow copy.

    Column types are chosen as follows:
        * columns reported by ``identify_vector_columns`` -> ``F32_BLOB(n)``
        * object columns whose non-null values are all bytes -> ``BLOB``
        * other object columns -> ``TEXT``
        * boolean and integer dtypes (any width, nullable or not) -> ``INTEGER``
        * float dtypes -> ``REAL``
        * datetime64 dtypes (naive or timezone-aware) -> ``TIMESTAMP``
        * anything else -> ``TEXT``

    Table and column names are inserted into the statement verbatim,
    without quoting.

    Args:
        df (pandas.DataFrame): The DataFrame object from which to generate the SQL statement.
        table_name (str): The name of the table to be created.

    Returns:
        str: A SQL CREATE TABLE statement.
    """
    # Work on a shallow copy so the caller's column labels stay untouched.
    # Plain str.replace avoids any dependence on pandas' regex default for ".".
    frame = df.copy(deep=False)
    frame.columns = [
        str(name).replace(" ", "_").replace(".", "_") for name in df.columns
    ]

    vector_columns = identify_vector_columns(frame)
    dtype_checks = pd.api.types

    columns = []
    for position, (column_name, dtype) in enumerate(frame.dtypes.items()):
        if column_name in vector_columns:
            sql_type = f"F32_BLOB({vector_columns[column_name]})"
        elif dtype_checks.is_object_dtype(dtype):
            # Positional access keeps this working even if sanitising made
            # two labels identical. NULLs are ignored when deciding whether
            # the column is binary, and an empty column is not called BLOB.
            non_null = frame.iloc[:, position].dropna()
            is_binary = len(non_null) > 0 and all(
                isinstance(cell, (bytes, bytearray)) for cell in non_null
            )
            sql_type = "BLOB" if is_binary else "TEXT"
        elif dtype_checks.is_bool_dtype(dtype) or dtype_checks.is_integer_dtype(dtype):
            sql_type = "INTEGER"
        elif dtype_checks.is_float_dtype(dtype):
            sql_type = "REAL"
        elif dtype_checks.is_datetime64_any_dtype(dtype):
            sql_type = "TIMESTAMP"
        else:
            sql_type = "TEXT"
        columns.append(f"{column_name} {sql_type}")

    sql = f"CREATE TABLE {table_name} (\n"
    sql += ",\n".join(columns)
    sql += "\n);"
    return sql
def _sanitize_column_name(name):
    """Return the column label as text with spaces and dots turned into underscores."""
    return str(name).replace(" ", "_").replace(".", "_")


def _sql_literal(value, is_raw):
    """
    Render a single cell as the text that goes into a VALUES list.

    Args:
        value: The cell value.
        is_raw (bool): True when the value is a raw SQL expression that
            must be emitted unquoted and unmodified.

    Returns:
        str: The SQL text for the value.
    """
    # pd.isnull() returns an array for list-like input, and an array has no
    # single truth value, so only scalars are tested for NULL.
    if pd.api.types.is_scalar(value) and pd.isnull(value):
        return "NULL"
    if is_raw:
        return f"{value}"
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, (bool, np.bool_)):
        return "1" if value else "0"
    if isinstance(value, (int, float, np.integer, np.floating)):
        return str(value)
    if isinstance(value, pd.Timestamp):
        return f"'{value}'"
    if isinstance(value, (bytes, bytearray)):
        # Hexadecimal BLOB literal.
        return f"X'{bytes(value).hex()}'"
    if isinstance(value, np.ndarray):
        # str() of an ndarray drops commas and abbreviates long arrays;
        # a list renders every element.
        value = value.tolist()
    return f"'{sanitize_value(value)}'"


def generate_insert_sql(df, table_name, raw_columns=None, exclude_columns=None):
    """
    Generate SQL INSERT statements based on the DataFrame data.

    One statement is produced per row. Column names have spaces and dots
    replaced with underscores; the DataFrame passed in is not modified.

    Values are rendered as follows: missing scalars become ``NULL``;
    booleans become ``1``/``0``; integers and floats are written as
    numbers; bytes become ``X'..'`` hexadecimal literals; timestamps,
    strings, dicts, lists, arrays and everything else become single-quoted
    text escaped with ``sanitize_value``.

    Args:
        df (pandas.DataFrame): The DataFrame containing data to be inserted.
        table_name (str): The name of the table into which data will be inserted.
        raw_columns (list, optional): List of column names whose values are raw SQL expressions.
            These values will not be quoted or modified. A column matches by
            either its original or its sanitised name. Defaults to None.
        exclude_columns (list, optional): List of column names to exclude from the INSERT.
            A column matches by either its original or its sanitised name;
            names that match nothing are ignored. Defaults to None.

    Returns:
        list: A list of SQL INSERT statements.
    """
    if raw_columns is None:
        raw_columns = []
    if exclude_columns is None:
        exclude_columns = []

    excluded = set(exclude_columns)
    raw_names = set(raw_columns) | {_sanitize_column_name(name) for name in raw_columns}

    # Select the surviving columns by position; iloc returns a new frame, so
    # the caller's DataFrame is never touched.
    kept_positions = []
    column_names = []
    for position, label in enumerate(df.columns):
        clean_name = _sanitize_column_name(label)
        if label in excluded or clean_name in excluded:
            continue
        kept_positions.append(position)
        column_names.append(clean_name)
    df = df.iloc[:, kept_positions]

    raw_flags = [name in raw_names for name in column_names]
    columns = ", ".join(column_names)

    sql_commands = []
    for row in df.itertuples(index=False, name=None):
        values_str = ", ".join(
            _sql_literal(value, is_raw) for value, is_raw in zip(row, raw_flags)
        )
        sql_commands.append(
            f"INSERT INTO {table_name} ({columns}) VALUES ({values_str});"
        )
    return sql_commands
def format_sql(statement, reindent=True, indent_width=4, keyword_case='upper'):
    """
    Pretty-print a SQL statement with ``sqlparse.format``.

    Aligned re-indentation and column indentation are always requested;
    the remaining formatting options come from the arguments.

    Args:
        statement (str): The SQL statement to be formatted.
        reindent (bool, optional): Whether to reindent the SQL statement. Defaults to True.
        indent_width (int, optional): Indentation width handed to sqlparse. Defaults to 4.
        keyword_case (str, optional): Case of keywords after formatting. Defaults to 'upper'.

    Returns:
        str: The formatted SQL statement.

    Raises:
        ValueError: If ``statement`` is not a string, or if any exception is
            raised while formatting. The original exception is attached as
            the cause.
    """
    try:
        if not isinstance(statement, str):
            raise ValueError("Statement must be a string")

        format_options = {
            "reindent": reindent,
            "reindent_aligned": True,
            "indent_columns": True,
            "indent_width": indent_width,
            "keyword_case": keyword_case,
        }
        return sqlparse.format(statement, **format_options)
    except ValueError as e:
        raise ValueError(f"ValueError formatting SQL: {e}") from e
    except Exception as e:
        raise ValueError(f"Unexpected error formatting SQL: {e}") from e
def get_autoincrement_columns(create_table_sql):
    """
    Find the AUTOINCREMENT columns declared in a CREATE TABLE statement.

    The search is a case-insensitive regular expression that looks for the
    exact shape ``column_name TYPE PRIMARY KEY AUTOINCREMENT``, where both
    the column name and the type are single runs of word characters.

    Args:
        create_table_sql (str): The CREATE TABLE SQL statement.

    Returns:
        list: Column names that match the pattern, in order of appearance.
    """
    # Group 1 is the column name; the single word after it is the type.
    declaration = re.compile(
        r'\b(\w+)\s+\w+\s+PRIMARY\s+KEY\s+AUTOINCREMENT\b',
        re.IGNORECASE,
    )
    return [found.group(1) for found in declaration.finditer(create_table_sql)]