Skip to content

Commit

Permalink
Merge pull request #1468 from schemacrawler/embeddings
Browse files Browse the repository at this point in the history
Allow SchemaCrawler ChatGPT plugin to do Retrieval Augmented Generation (RAG) for optimal help with SQL queries
  • Loading branch information
sualeh authored Feb 12, 2024
2 parents 4cb83a6 + a822940 commit de2aedd
Show file tree
Hide file tree
Showing 48 changed files with 1,376 additions and 423 deletions.
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@
"GitHub.vscode-pull-request-github"
],
"editor.tabSize": 2,
"scm.showActionButton": false
"scm.showActionButton": false,
"java.compile.nullAnalysis.mode": "automatic",
"java.jdt.ls.vmargs": "-XX:+UseParallelGC -XX:GCTimeRatio=4 -XX:AdaptiveSizePolicyWeight=90 -Dsun.zip.disableMemoryMapping=true -Xmx2G -Xms100m -Xlog:disable"
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import static java.util.Objects.requireNonNull;
import schemacrawler.schema.Column;

class ColumnPointer extends DatabaseObjectReference<Column> {
final class ColumnPointer extends DatabaseObjectReference<Column> {

private static final long serialVersionUID = 122669483681884924L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import static java.util.Objects.requireNonNull;
import schemacrawler.schema.Function;

class FunctionPointer extends DatabaseObjectReference<Function> {
final class FunctionPointer extends DatabaseObjectReference<Function> {

private static final long serialVersionUID = -5166020646865781875L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import static java.util.Objects.requireNonNull;
import schemacrawler.schema.Table;

class TablePointer extends DatabaseObjectReference<Table> {
final class TablePointer extends DatabaseObjectReference<Table> {

private static final long serialVersionUID = 8940800217960888019L;

Expand Down
6 changes: 6 additions & 0 deletions schemacrawler-chatgpt/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
<artifactId>service</artifactId>
<version>0.18.2</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>


<dependency>
<groupId>org.apache.commons</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@

import java.util.logging.Level;
import java.util.logging.Logger;

import com.theokanning.openai.model.Model;
import com.theokanning.openai.service.OpenAiService;

import schemacrawler.tools.command.chatgpt.options.ChatGPTCommandOptions;
import schemacrawler.tools.executable.BaseSchemaCrawlerCommand;

/** SchemaCrawler command plug-in. */
public class ChatGPTCommand extends BaseSchemaCrawlerCommand<ChatGPTCommandOptions> {
public final class ChatGPTCommand extends BaseSchemaCrawlerCommand<ChatGPTCommandOptions> {

private static final Logger LOGGER = Logger.getLogger(ChatGPTCommand.class.getName());

Expand All @@ -51,15 +49,15 @@ protected ChatGPTCommand() {
@Override
public void checkAvailability() throws RuntimeException {
// Check that OpenAI API key works, and the model is available
final OpenAiService service = new OpenAiService(commandOptions.getApiKey());
final Model model = service.getModel(commandOptions.getModel());
final OpenAiService service = new OpenAiService(this.commandOptions.getApiKey());
final Model model = service.getModel(this.commandOptions.getModel());
LOGGER.log(Level.CONFIG, String.format("Using ChatGPT model:%n%s", model));
}

@Override
public void execute() {
try (ChatGPTConsole chatGPTConsole =
new ChatGPTConsole(commandOptions, catalog, connection); ) {
new ChatGPTConsole(this.commandOptions, this.catalog, this.connection); ) {
chatGPTConsole.console();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
package schemacrawler.tools.command.chatgpt;

import static schemacrawler.tools.executable.commandline.PluginCommand.newPluginCommand;

import schemacrawler.schemacrawler.exceptions.ExecutionRuntimeException;
import schemacrawler.tools.command.chatgpt.options.ChatGPTCommandOptions;
import schemacrawler.tools.command.chatgpt.options.ChatGPTCommandOptionsBuilder;
Expand All @@ -40,7 +39,7 @@
import schemacrawler.tools.options.OutputOptions;

/** SchemaCrawler command plug-in for ChatGPT. */
public class ChatGPTCommandProvider extends BaseCommandProvider {
public final class ChatGPTCommandProvider extends BaseCommandProvider {

public static final String DESCRIPTION_HEADER = "SchemaCrawler ChatGPT integration";

Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,58 @@
/*
========================================================================
SchemaCrawler
http://www.schemacrawler.com
Copyright (c) 2000-2024, Sualeh Fatehi <[email protected]>.
All rights reserved.
------------------------------------------------------------------------
SchemaCrawler is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
SchemaCrawler and the accompanying materials are made available under
the terms of the Eclipse Public License v1.0, GNU General Public License
v3 or GNU Lesser General Public License v3.
You may elect to redistribute this code under any of these licenses.
The Eclipse Public License is available at:
http://www.eclipse.org/legal/epl-v10.html
The GNU General Public License v3 and the GNU Lesser General Public
License v3 are available at:
http://www.gnu.org/licenses/
========================================================================
*/

package schemacrawler.tools.command.chatgpt;

import static com.theokanning.openai.completion.chat.ChatMessageRole.FUNCTION;
import static com.theokanning.openai.completion.chat.ChatMessageRole.USER;
import static java.util.Objects.requireNonNull;
import static schemacrawler.tools.command.chatgpt.utility.ChatGPTUtility.isExitCondition;
import static schemacrawler.tools.command.chatgpt.utility.ChatGPTUtility.printResponse;

import java.sql.Connection;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Scanner;
import java.util.logging.Level;
import java.util.logging.Logger;

import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionRequest.ChatCompletionRequestFunctionCall;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.FunctionExecutor;
import com.theokanning.openai.service.OpenAiService;

import static java.util.Objects.requireNonNull;
import schemacrawler.schema.Catalog;
import schemacrawler.tools.command.chatgpt.embeddings.QueryService;
import schemacrawler.tools.command.chatgpt.options.ChatGPTCommandOptions;
import schemacrawler.tools.command.chatgpt.utility.ChatGPTUtility;
import schemacrawler.tools.command.chatgpt.utility.ChatHistory;
import us.fatehi.utility.string.StringFormat;

public final class ChatGPTConsole implements AutoCloseable {
Expand All @@ -36,7 +64,9 @@ public final class ChatGPTConsole implements AutoCloseable {
private final ChatGPTCommandOptions commandOptions;
private final FunctionExecutor functionExecutor;
private final OpenAiService service;
private final QueryService queryService;
private final ChatHistory chatHistory;
private final boolean useMetadata;

public ChatGPTConsole(
final ChatGPTCommandOptions commandOptions,
Expand All @@ -51,13 +81,11 @@ public ChatGPTConsole(
final Duration timeout = Duration.ofSeconds(commandOptions.getTimeout());
service = new OpenAiService(commandOptions.getApiKey(), timeout);

final List<ChatMessage> systemMessages;
if (commandOptions.isUseMetadata()) {
systemMessages = ChatGPTUtility.systemMessages(catalog, connection);
} else {
systemMessages = new ArrayList<>();
}
chatHistory = new ChatHistory(commandOptions.getContext(), systemMessages);
queryService = new QueryService(service);
queryService.addTables(catalog.getTables());

useMetadata = commandOptions.isUseMetadata();
chatHistory = new ChatHistory(commandOptions.getContext(), new ArrayList<>());
}

@Override
Expand Down Expand Up @@ -90,11 +118,19 @@ private List<ChatMessage> complete(final String prompt) {
final List<ChatMessage> completions = new ArrayList<>();

try {

final ChatMessage userMessage = new ChatMessage(USER.value(), prompt);
chatHistory.add(userMessage);

final List<ChatMessage> messages = chatHistory.toList();
LOGGER.log(Level.CONFIG, new StringFormat("ChatGPT request:%n%s", messages));

if (useMetadata) {
final Collection<ChatMessage> chatMessages = queryService.query(prompt);
for (final ChatMessage chatMessage : chatMessages) {
messages.add(chatMessage);
}
}

final ChatCompletionRequest completionRequest =
ChatCompletionRequest.builder()
.messages(messages)
Expand All @@ -103,6 +139,7 @@ private List<ChatMessage> complete(final String prompt) {
.model(commandOptions.getModel())
.n(1)
.build();
logChatRequest(completionRequest.getMessages(), completionRequest.getFunctions());

final ChatCompletionResult chatCompletion = service.createChatCompletion(completionRequest);
LOGGER.log(Level.INFO, new StringFormat("Token usage: %s", chatCompletion.getUsage()));
Expand All @@ -127,4 +164,24 @@ private List<ChatMessage> complete(final String prompt) {

return completions;
}

/**
 * Writes the messages and functions of a chat request to the log, one entry per line, at
 * {@code CONFIG} level.
 *
 * <p>Nothing is built or logged when {@code CONFIG} logging is not enabled. A {@code null}
 * list is skipped.
 *
 * @param messages Chat messages of the request; may be null
 * @param functions Functions of the request; may be null
 */
private void logChatRequest(final List<ChatMessage> messages, final List<?> functions) {
  final Level logLevel = Level.CONFIG;
  // Only pay for building the text when it will actually be written out
  if (LOGGER.isLoggable(logLevel)) {
    final String newline = System.lineSeparator();
    final StringBuilder text = new StringBuilder();
    text.append("ChatGPT request:").append(newline);
    if (messages != null) {
      messages.forEach(message -> text.append(message).append(newline));
    }
    if (functions != null) {
      functions.forEach(function -> text.append(function).append(newline));
    }
    LOGGER.log(logLevel, text.toString());
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
========================================================================
SchemaCrawler
http://www.schemacrawler.com
Copyright (c) 2000-2024, Sualeh Fatehi <[email protected]>.
All rights reserved.
------------------------------------------------------------------------
SchemaCrawler is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
SchemaCrawler and the accompanying materials are made available under
the terms of the Eclipse Public License v1.0, GNU General Public License
v3 or GNU Lesser General Public License v3.
You may elect to redistribute this code under any of these licenses.
The Eclipse Public License is available at:
http://www.eclipse.org/legal/epl-v10.html
The GNU General Public License v3 and the GNU Lesser General Public
License v3 are available at:
http://www.gnu.org/licenses/
========================================================================
*/

package schemacrawler.tools.command.chatgpt.embeddings;

import static java.util.Objects.requireNonNull;
import schemacrawler.schema.NamedObject;
import schemacrawler.schema.NamedObjectKey;
import schemacrawler.schema.Schema;
import schemacrawler.schema.Table;
import schemacrawler.tools.command.serialize.model.CompactCatalogUtility;
import schemacrawler.tools.command.serialize.model.TableDocument;

/**
 * Wraps a {@link Table} together with its {@link TableDocument} and an optional
 * {@link TextEmbedding}.
 *
 * <p>Name, full name, key, schema and ordering are all delegated to the wrapped table. The
 * embedding is not assigned by the constructor, so it is {@code null} until
 * {@code setEmbedding} is called.
 */
public final class EmbeddedTable implements NamedObject {

  private static final long serialVersionUID = 5216101777323983303L;

  private final Table table;
  private final TableDocument tableDocument;
  private TextEmbedding embedding;

  /**
   * Wraps the table, and creates its table document immediately.
   *
   * @param table Table to wrap
   * @throws NullPointerException If no table is provided
   */
  EmbeddedTable(final Table table) {
    this.table = requireNonNull(table, "No table provided");
    // NOTE(review): the meaning of the "false" flag is defined by CompactCatalogUtility - confirm
    this.tableDocument = CompactCatalogUtility.getTableDocument(this.table, false);
  }

  /** Orders this object the same way the wrapped table orders itself. */
  @Override
  public int compareTo(final NamedObject other) {
    return this.table.compareTo(other);
  }

  /**
   * Gets the embedding assigned to this table.
   *
   * @return Embedding, or null if none has been set
   */
  public TextEmbedding getEmbedding() {
    return this.embedding;
  }

  /** Full name of the wrapped table. */
  @Override
  public String getFullName() {
    return this.table.getFullName();
  }

  /** Name of the wrapped table. */
  @Override
  public String getName() {
    return this.table.getName();
  }

  /** Schema of the wrapped table. */
  public Schema getSchema() {
    return this.table.getSchema();
  }

  /**
   * Checks whether an embedding has been assigned.
   *
   * @return True if the embedding is not null
   */
  public boolean hasEmbedding() {
    return this.embedding != null;
  }

  /** Key of the wrapped table. */
  @Override
  public NamedObjectKey key() {
    return this.table.key();
  }

  /** JSON produced by the table document created for the wrapped table. */
  public String toJson() {
    return this.tableDocument.toJson();
  }

  /** Same text as the full name of the wrapped table. */
  @Override
  public String toString() {
    return this.table.getFullName();
  }

  /**
   * Assigns the embedding for this table, replacing any previous value.
   *
   * @param providedEmbedding Embedding to store; stored as given, without validation
   */
  void setEmbedding(final TextEmbedding providedEmbedding) {
    this.embedding = providedEmbedding;
  }
}
Loading

0 comments on commit de2aedd

Please sign in to comment.