
BiFunction with ToolContext Fails Due to Incorrect Type Resolution in Spring Cloud Function #1576

Open
sfcodes opened this issue Oct 21, 2024 · 0 comments


Relevant versions:

  • Spring AI version 1.0.0-SNAPSHOT
  • org.springframework.cloud:spring-cloud-function-context 4.1.3
  • org.springframework.boot:spring-boot 3.3.4
  • io.spring.dependency-management 1.1.6

(Basically using the latest Spring GA release train without any unusual dependency overrides)

Problem Description:

Using a BiFunction with ToolContext, as described in the Spring AI documentation, does not work as expected. When attempting to call a model that utilizes a BiFunction, the following exception is thrown:

org.springframework.ai.retry.NonTransientAiException: 400 - {
  "error": {
    "message": "Invalid schema for function 'getWeatherFunction': schema must be a JSON Schema of 'type: \"object\"', got 'type: \"None\"'.",
    "type": "invalid_request_error",
    "param": "tools[1].function.parameters",
    "code": "invalid_function_parameters"
  }
}
	at org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration$2.handleError(SpringAiRetryAutoConfiguration.java:95) ~[spring-ai-spring-boot-autoconfigure-1.0.0-SNAPSHOT.jar:1.0.0-SNAPSHOT]
	at org.springframework.web.client.ResponseErrorHandler.handleError(ResponseErrorHandler.java:63) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.web.client.StatusHandler.lambda$fromErrorHandler$1(StatusHandler.java:71) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.web.client.StatusHandler.handle(StatusHandler.java:146) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.web.client.DefaultRestClient$DefaultResponseSpec.applyStatusHandlers(DefaultRestClient.java:707) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.web.client.DefaultRestClient.readWithMessageConverters(DefaultRestClient.java:200) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.web.client.DefaultRestClient$DefaultResponseSpec.readBody(DefaultRestClient.java:694) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.web.client.DefaultRestClient$DefaultResponseSpec.toEntityInternal(DefaultRestClient.java:664) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.web.client.DefaultRestClient$DefaultResponseSpec.toEntity(DefaultRestClient.java:653) ~[spring-web-6.1.13.jar:6.1.13]
	at org.springframework.ai.openai.api.OpenAiApi.chatCompletionEntity(OpenAiApi.java:1040) ~[spring-ai-openai-1.0.0-SNAPSHOT.jar:1.0.0-SNAPSHOT]
	at org.springframework.ai.openai.OpenAiChatModel.lambda$call$1(OpenAiChatModel.java:227) ~[spring-ai-openai-1.0.0-SNAPSHOT.jar:1.0.0-SNAPSHOT]
	at org.springframework.retry.support.RetryTemplate.doExecute(RetryTemplate.java:344) ~[spring-retry-2.0.9.jar:?]
	at org.springframework.retry.support.RetryTemplate.execute(RetryTemplate.java:217) ~[spring-retry-2.0.9.jar:?]
	at org.springframework.ai.openai.OpenAiChatModel.lambda$call$3(OpenAiChatModel.java:227) ~[spring-ai-openai-1.0.0-SNAPSHOT.jar:1.0.0-SNAPSHOT]
	at io.micrometer.observation.Observation.observe(Observation.java:565) ~[micrometer-observation-1.13.4.jar:1.13.4]
	at org.springframework.ai.openai.OpenAiChatModel.call(OpenAiChatModel.java:224) ~[spring-ai-openai-1.0.0-SNAPSHOT.jar:1.0.0-SNAPSHOT]
	
(The rest of the stack is omitted for brevity; the call originates from openAiChatModel.call(prompt).)
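The error message states the requirement: each function's `parameters` must be a JSON Schema whose top-level type is `"object"`. As a rough illustration of why the resolved input type matters (this is not Spring AI's schema generator, and `minimalSchema` and `jsonType` are hypothetical helpers), the JDK-only sketch below builds a minimal object schema from a record's components. A record input yields `"type":"object"`, while `java.lang.Object`, which carries no component information, yields an empty schema with no `type` at all, similar in spirit to the rejected `type: "None"`.

```java
import java.lang.reflect.RecordComponent;
import java.util.StringJoiner;

class MinimalSchemaDemo {
    enum Unit { C, F }

    // Same shape as GetWeatherFunction.Request in this report.
    record Request(String location, Unit unit) {}

    // Hypothetical mapping from a few Java types to JSON Schema type names.
    static String jsonType(Class<?> type) {
        if (type == String.class || type.isEnum()) return "string";
        if (type == int.class || type == long.class || type == Integer.class || type == Long.class) return "integer";
        if (type == double.class || type == float.class || type == Double.class || type == Float.class) return "number";
        if (type == boolean.class || type == Boolean.class) return "boolean";
        return "object";
    }

    // Builds {"type":"object","properties":{...}} for records; returns "{}" (no "type") otherwise.
    static String minimalSchema(Class<?> inputType) {
        if (!inputType.isRecord()) {
            return "{}";
        }
        StringJoiner props = new StringJoiner(",");
        for (RecordComponent c : inputType.getRecordComponents()) {
            props.add("\"" + c.getName() + "\":{\"type\":\"" + jsonType(c.getType()) + "\"}");
        }
        return "{\"type\":\"object\",\"properties\":{" + props + "}}";
    }

    public static void main(String[] args) {
        System.out.println(minimalSchema(Request.class)); // input type resolved to the record
        System.out.println(minimalSchema(Object.class));  // input type resolved to Object
    }
}
```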

Reproduction Steps:

1. Define a Function that works:

import java.util.function.Function;

import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;

@Component
@Description("Get the population of a city")
public class GetPopulationFunction implements Function<GetPopulationFunction.Request, GetPopulationFunction.Response> {

    public record Request(String city) {}
    public record Response(int population) {}

    @Override
    public Response apply(Request request) {
        // Use the record accessor rather than the private component field
        return switch (request.city()) {
            case "San Francisco" -> new Response(788_478);
            case "Tokyo" -> new Response(14_187_176);
            case "Paris" -> new Response(2_102_650);
            default -> null; // unknown city
        };
    }
}

This function works as expected when called.

2. Define a BiFunction with ToolContext (this is the one that fails):

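import java.util.function.BiFunction;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;

@Component
@Description("Get the weather in a location")
public class GetWeatherFunction implements BiFunction<GetWeatherFunction.Request, ToolContext, GetWeatherFunction.Response> {

    public enum Unit { C, F }

    public record Request(String location, Unit unit) {}
    public record Response(double temp, Unit unit) {}

    @Override
    public Response apply(Request request, ToolContext toolContext) {
        // toolContext is unused here; the failure occurs before apply() is ever called
        return switch (request.location()) {
            case "San Francisco" -> new Response(55, Unit.C);
            case "Tokyo" -> new Response(65, Unit.C);
            case "Paris" -> new Response(67, Unit.C);
            default -> null; // unknown location
        };
    }
}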
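The declared input type of this BiFunction is not lost to erasure: the type arguments of a directly implemented generic interface are recorded in the class file and can be read via reflection. The JDK-only sketch below reads them from the BiFunction signature. Here `ToolContext` is a hypothetical empty stand-in for Spring AI's class, the Spring annotations are dropped so it compiles without Spring, and `biFunctionTypeArguments` is a hypothetical helper. Code that only looks for a `Function` interface will not match this class, which is consistent with the java.lang.Object input type reported below.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.function.BiFunction;

// Hypothetical stand-in for Spring AI's ToolContext so this compiles without Spring.
class ToolContext {}

// Same shape as GetWeatherFunction above, minus the Spring annotations.
class GetWeatherFunction implements BiFunction<GetWeatherFunction.Request, ToolContext, GetWeatherFunction.Response> {
    public enum Unit { C, F }
    public record Request(String location, Unit unit) {}
    public record Response(double temp, Unit unit) {}

    @Override
    public Response apply(Request request, ToolContext toolContext) {
        return "Paris".equals(request.location()) ? new Response(67, Unit.C) : null;
    }
}

class GenericSignatureDemo {
    // Returns the type arguments declared on the class's BiFunction interface,
    // or an empty array if the class does not directly implement BiFunction.
    static Type[] biFunctionTypeArguments(Class<?> clazz) {
        for (Type iface : clazz.getGenericInterfaces()) {
            if (iface instanceof ParameterizedType pt && pt.getRawType() == BiFunction.class) {
                return pt.getActualTypeArguments();
            }
        }
        return new Type[0];
    }

    public static void main(String[] args) {
        // Input, context, and output types are all still present in the class signature.
        System.out.println(Arrays.toString(biFunctionTypeArguments(GetWeatherFunction.class)));
    }
}
```

Running `main` prints the three type arguments in declaration order: the `Request` record, `ToolContext`, and the `Response` record.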

3. Attempt to invoke a model with this BiFunction:

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
OpenAiChatOptions chatOptions = OpenAiChatOptions.builder()
        .withModel(OpenAiApi.ChatModel.GPT_4)
        .withFunctions(Set.of(
                "getPopulationFunction",
                "getWeatherFunction"
        ))
        .build();

Prompt prompt = new Prompt(userMessage, chatOptions);

ChatResponse response = openAiChatModel.call(prompt);

4. Observe the exception:

The call fails with the NonTransientAiException shown above.

Possible Solution:

It appears the issue might not be in your code at all, but rather in Spring Cloud Function Context: org.springframework.cloud.function.context.catalog.FunctionTypeUtils.discoverFunctionTypeFromClass(Class<?> functionalClass) seems to reify BiFunctions incorrectly. I think that method is missing a branch like the following at its end to handle BiFunctions properly:

else if (BiFunction.class.isAssignableFrom(functionalClass)) {
    return TypeResolver.reify(BiFunction.class, (Class<BiFunction<?, ?, ?>>) functionalClass);
}
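The proposed branch amounts to also recognizing BiFunction and taking its first type argument as the input type. A JDK-only sketch of that idea follows. It uses plain reflection instead of `TypeResolver`, only inspects interfaces declared directly on the class, and `InputTypeResolver`, `inputTypeOf`, the `ToolContext` stand-in, and the example classes are all hypothetical, not Spring Cloud Function code.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical stand-in for Spring AI's ToolContext.
class ToolContext {}

// Hypothetical input/output types for the sketch.
record WeatherRequest(String location) {}
record WeatherResponse(double temp) {}

class WeatherBiFunction implements BiFunction<WeatherRequest, ToolContext, WeatherResponse> {
    @Override
    public WeatherResponse apply(WeatherRequest request, ToolContext toolContext) {
        return new WeatherResponse(67);
    }
}

class CityLengthFunction implements Function<String, Integer> {
    @Override
    public Integer apply(String city) {
        return city.length();
    }
}

class InputTypeResolver {
    // Returns the first type argument for classes that directly implement
    // Function<I, O> or BiFunction<I, C, O>; falls back to Object.class otherwise.
    static Type inputTypeOf(Class<?> functionalClass) {
        for (Type iface : functionalClass.getGenericInterfaces()) {
            if (iface instanceof ParameterizedType pt) {
                Type raw = pt.getRawType();
                // The BiFunction check is the analogue of the branch proposed above.
                if (raw == Function.class || raw == BiFunction.class) {
                    return pt.getActualTypeArguments()[0];
                }
            }
        }
        return Object.class;
    }

    public static void main(String[] args) {
        System.out.println(inputTypeOf(CityLengthFunction.class));
        System.out.println(inputTypeOf(WeatherBiFunction.class));
    }
}
```

Unlike `TypeResolver`, this sketch does not walk superclasses or resolve type variables, so a class that inherits BiFunction through an abstract base class would still fall back to Object.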

Because of this, your code extracts the wrong input type in org.springframework.ai.model.function.FunctionCallbackContext.getFunctionCallback(String beanName, String defaultDescription), at line 89.

You can see this by setting a breakpoint there and inspecting functionInputClass: it should be the input type declared by the BiFunction (here, GetWeatherFunction.Request), but instead it is java.lang.Object.

So I believe this is a bug in Spring Cloud Function Context's BiFunction type resolution. However, before reaching out to that team, I wanted to run this by you and see what you think. Thanks!
