-
-
Notifications
You must be signed in to change notification settings - Fork 200
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1468 from schemacrawler/embeddings
Allow SchemaCrawler ChatGPT plugin to do Retrieval Augmented Generation (RAG) for optimal help with SQL queries
- Loading branch information
Showing
48 changed files
with
1,376 additions
and
423 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,58 @@ | ||
/* | ||
======================================================================== | ||
SchemaCrawler | ||
http://www.schemacrawler.com | ||
Copyright (c) 2000-2024, Sualeh Fatehi <[email protected]>. | ||
All rights reserved. | ||
------------------------------------------------------------------------ | ||
SchemaCrawler is distributed in the hope that it will be useful, but | ||
WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. | ||
SchemaCrawler and the accompanying materials are made available under | ||
the terms of the Eclipse Public License v1.0, GNU General Public License | ||
v3 or GNU Lesser General Public License v3. | ||
You may elect to redistribute this code under any of these licenses. | ||
The Eclipse Public License is available at: | ||
http://www.eclipse.org/legal/epl-v10.html | ||
The GNU General Public License v3 and the GNU Lesser General Public | ||
License v3 are available at: | ||
http://www.gnu.org/licenses/ | ||
======================================================================== | ||
*/ | ||
|
||
package schemacrawler.tools.command.chatgpt; | ||
|
||
import static com.theokanning.openai.completion.chat.ChatMessageRole.FUNCTION; | ||
import static com.theokanning.openai.completion.chat.ChatMessageRole.USER; | ||
import static java.util.Objects.requireNonNull; | ||
import static schemacrawler.tools.command.chatgpt.utility.ChatGPTUtility.isExitCondition; | ||
import static schemacrawler.tools.command.chatgpt.utility.ChatGPTUtility.printResponse; | ||
|
||
import java.sql.Connection; | ||
import java.time.Duration; | ||
import java.util.ArrayList; | ||
import java.util.Collection; | ||
import java.util.List; | ||
import java.util.Scanner; | ||
import java.util.logging.Level; | ||
import java.util.logging.Logger; | ||
|
||
import com.theokanning.openai.completion.chat.ChatCompletionRequest; | ||
import com.theokanning.openai.completion.chat.ChatCompletionRequest.ChatCompletionRequestFunctionCall; | ||
import com.theokanning.openai.completion.chat.ChatCompletionResult; | ||
import com.theokanning.openai.completion.chat.ChatFunctionCall; | ||
import com.theokanning.openai.completion.chat.ChatMessage; | ||
import com.theokanning.openai.service.FunctionExecutor; | ||
import com.theokanning.openai.service.OpenAiService; | ||
|
||
import static java.util.Objects.requireNonNull; | ||
import schemacrawler.schema.Catalog; | ||
import schemacrawler.tools.command.chatgpt.embeddings.QueryService; | ||
import schemacrawler.tools.command.chatgpt.options.ChatGPTCommandOptions; | ||
import schemacrawler.tools.command.chatgpt.utility.ChatGPTUtility; | ||
import schemacrawler.tools.command.chatgpt.utility.ChatHistory; | ||
import us.fatehi.utility.string.StringFormat; | ||
|
||
public final class ChatGPTConsole implements AutoCloseable { | ||
|
@@ -36,7 +64,9 @@ public final class ChatGPTConsole implements AutoCloseable { | |
private final ChatGPTCommandOptions commandOptions; | ||
private final FunctionExecutor functionExecutor; | ||
private final OpenAiService service; | ||
private final QueryService queryService; | ||
private final ChatHistory chatHistory; | ||
private final boolean useMetadata; | ||
|
||
public ChatGPTConsole( | ||
final ChatGPTCommandOptions commandOptions, | ||
|
@@ -51,13 +81,11 @@ public ChatGPTConsole( | |
final Duration timeout = Duration.ofSeconds(commandOptions.getTimeout()); | ||
service = new OpenAiService(commandOptions.getApiKey(), timeout); | ||
|
||
final List<ChatMessage> systemMessages; | ||
if (commandOptions.isUseMetadata()) { | ||
systemMessages = ChatGPTUtility.systemMessages(catalog, connection); | ||
} else { | ||
systemMessages = new ArrayList<>(); | ||
} | ||
chatHistory = new ChatHistory(commandOptions.getContext(), systemMessages); | ||
queryService = new QueryService(service); | ||
queryService.addTables(catalog.getTables()); | ||
|
||
useMetadata = commandOptions.isUseMetadata(); | ||
chatHistory = new ChatHistory(commandOptions.getContext(), new ArrayList<>()); | ||
} | ||
|
||
@Override | ||
|
@@ -90,11 +118,19 @@ private List<ChatMessage> complete(final String prompt) { | |
final List<ChatMessage> completions = new ArrayList<>(); | ||
|
||
try { | ||
|
||
final ChatMessage userMessage = new ChatMessage(USER.value(), prompt); | ||
chatHistory.add(userMessage); | ||
|
||
final List<ChatMessage> messages = chatHistory.toList(); | ||
LOGGER.log(Level.CONFIG, new StringFormat("ChatGPT request:%n%s", messages)); | ||
|
||
if (useMetadata) { | ||
final Collection<ChatMessage> chatMessages = queryService.query(prompt); | ||
for (final ChatMessage chatMessage : chatMessages) { | ||
messages.add(chatMessage); | ||
} | ||
} | ||
|
||
final ChatCompletionRequest completionRequest = | ||
ChatCompletionRequest.builder() | ||
.messages(messages) | ||
|
@@ -103,6 +139,7 @@ private List<ChatMessage> complete(final String prompt) { | |
.model(commandOptions.getModel()) | ||
.n(1) | ||
.build(); | ||
logChatRequest(completionRequest.getMessages(), completionRequest.getFunctions()); | ||
|
||
final ChatCompletionResult chatCompletion = service.createChatCompletion(completionRequest); | ||
LOGGER.log(Level.INFO, new StringFormat("Token usage: %s", chatCompletion.getUsage())); | ||
|
@@ -127,4 +164,24 @@ private List<ChatMessage> complete(final String prompt) { | |
|
||
return completions; | ||
} | ||
|
||
private void logChatRequest(final List<ChatMessage> messages, final List<?> functions) { | ||
final Level level = Level.CONFIG; | ||
if (!LOGGER.isLoggable(level)) { | ||
return; | ||
} | ||
final StringBuilder buffer = new StringBuilder(); | ||
buffer.append("ChatGPT request:").append(System.lineSeparator()); | ||
if (messages != null) { | ||
for (final ChatMessage message : messages) { | ||
buffer.append(message).append(System.lineSeparator()); | ||
} | ||
} | ||
if (functions != null) { | ||
for (final Object function : functions) { | ||
buffer.append(function).append(System.lineSeparator()); | ||
} | ||
} | ||
LOGGER.log(level, buffer.toString()); | ||
} | ||
} |
96 changes: 96 additions & 0 deletions
96
...r-chatgpt/src/main/java/schemacrawler/tools/command/chatgpt/embeddings/EmbeddedTable.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* | ||
======================================================================== | ||
SchemaCrawler | ||
http://www.schemacrawler.com | ||
Copyright (c) 2000-2024, Sualeh Fatehi <[email protected]>. | ||
All rights reserved. | ||
------------------------------------------------------------------------ | ||
SchemaCrawler is distributed in the hope that it will be useful, but | ||
WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. | ||
SchemaCrawler and the accompanying materials are made available under | ||
the terms of the Eclipse Public License v1.0, GNU General Public License | ||
v3 or GNU Lesser General Public License v3. | ||
You may elect to redistribute this code under any of these licenses. | ||
The Eclipse Public License is available at: | ||
http://www.eclipse.org/legal/epl-v10.html | ||
The GNU General Public License v3 and the GNU Lesser General Public | ||
License v3 are available at: | ||
http://www.gnu.org/licenses/ | ||
======================================================================== | ||
*/ | ||
|
||
package schemacrawler.tools.command.chatgpt.embeddings; | ||
|
||
import static java.util.Objects.requireNonNull; | ||
import schemacrawler.schema.NamedObject; | ||
import schemacrawler.schema.NamedObjectKey; | ||
import schemacrawler.schema.Schema; | ||
import schemacrawler.schema.Table; | ||
import schemacrawler.tools.command.serialize.model.CompactCatalogUtility; | ||
import schemacrawler.tools.command.serialize.model.TableDocument; | ||
|
||
public final class EmbeddedTable implements NamedObject { | ||
|
||
private static final long serialVersionUID = 5216101777323983303L; | ||
|
||
private final Table table; | ||
private final TableDocument tableDocument; | ||
private TextEmbedding embedding; | ||
|
||
EmbeddedTable(final Table table) { | ||
this.table = requireNonNull(table, "No table provided"); | ||
tableDocument = CompactCatalogUtility.getTableDocument(table, false); | ||
} | ||
|
||
@Override | ||
public int compareTo(final NamedObject object) { | ||
return table.compareTo(object); | ||
} | ||
|
||
public TextEmbedding getEmbedding() { | ||
return embedding; | ||
} | ||
|
||
@Override | ||
public String getFullName() { | ||
return table.getFullName(); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return table.getName(); | ||
} | ||
|
||
public Schema getSchema() { | ||
return table.getSchema(); | ||
} | ||
|
||
public boolean hasEmbedding() { | ||
return embedding != null; | ||
} | ||
|
||
@Override | ||
public NamedObjectKey key() { | ||
return table.key(); | ||
} | ||
|
||
public String toJson() { | ||
return tableDocument.toJson(); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return getFullName(); | ||
} | ||
|
||
void setEmbedding(final TextEmbedding providedEmmedding) { | ||
embedding = providedEmmedding; | ||
} | ||
} |
Oops, something went wrong.