Skip to content

Commit

Permalink
feat: handle gpt o1-preview, o1-mini models
Browse files Browse the repository at this point in the history
  • Loading branch information
qtnx authored and Robitx committed Sep 17, 2024
1 parent f4cbbf4 commit af1a409
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions lua/gp/dispatcher.lua
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,29 @@ D.prepare_payload = function(messages, model, provider)
model.model = "gpt-4o-2024-05-13"
end

return {
local output = {
model = model.model,
stream = true,
messages = messages,
max_tokens = model.max_tokens or 4096,
temperature = math.max(0, math.min(2, model.temperature or 1)),
top_p = math.max(0, math.min(1, model.top_p or 1)),
}

if provider == "openai" and model.model:sub(1, 2) == "o1" then
for i = #messages, 1, -1 do
if messages[i].role == "system" then
table.remove(messages, i)
end
end
-- remove max_tokens, top_p, temperature for o1 models. https://platform.openai.com/docs/guides/reasoning/beta-limitations
output.max_tokens = nil
output.temperature = nil
output.top_p = nil
output.stream = false
end

return output
end

-- gpt query
Expand Down Expand Up @@ -249,6 +264,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if content and type(content) == "string" then
qt.response = qt.response .. content
handler(qid, content)
Expand Down Expand Up @@ -282,6 +298,19 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
if #buffer > 0 then
process_lines(buffer)
end
local raw_response = qt.raw_response
local content = qt.response
if qt.provider == 'openai' and content == "" and raw_response:match('choices') and raw_response:match("content") then
local response = vim.json.decode(raw_response)
if response.choices and response.choices[1] and response.choices[1].message and response.choices[1].message.content then
content = response.choices[1].message.content
end
if content and type(content) == "string" then
qt.response = qt.response .. content
handler(qid, content)
end
end


if qt.response == "" then
logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
Expand Down Expand Up @@ -363,7 +392,8 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
}
end

local temp_file = D.query_dir .. "/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
local temp_file = D.query_dir ..
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
helpers.table_to_file(payload, temp_file)

local curl_params = vim.deepcopy(D.config.curl_params or {})
Expand Down

0 comments on commit af1a409

Please sign in to comment.