From 2d9d5970a0d52df89b62f6a63cf6cd081e343fec Mon Sep 17 00:00:00 2001 From: Baptiste Adrien Date: Sat, 7 Dec 2024 18:17:37 +0100 Subject: [PATCH] feat: switch to flux (#63) --- package.json | 3 +- .../migration.sql | 5 + prisma/schema.prisma | 6 + src/app/(auth)/studio/[id]/page.tsx | 16 +-- .../predictions/[predictionId]/hd/route.ts | 29 ++-- .../[id]/predictions/[predictionId]/route.ts | 8 +- .../api/projects/[id]/predictions/route.ts | 55 +++++--- src/app/api/projects/[id]/prompter/route.ts | 27 ++-- src/app/api/projects/[id]/train/route.ts | 43 +++--- src/app/api/projects/route.ts | 26 ++-- src/components/home/Pricing.tsx | 8 +- src/components/pages/StudioPage.tsx | 4 +- src/components/projects/FormPayment.tsx | 4 +- src/components/projects/PromptPanel.tsx | 20 +-- src/components/projects/PromptsDrawer.tsx | 6 +- .../projects/shot/BuyShotButton.tsx | 3 - src/components/projects/shot/ShotCard.tsx | 20 --- src/contexts/project-context.tsx | 4 +- src/core/clients/openai.ts | 8 +- src/core/clients/replicate.ts | 5 + src/core/utils/assets.ts | 4 +- src/core/utils/predictions.ts | 2 +- yarn.lock | 126 +++++++++++++++--- 23 files changed, 262 insertions(+), 170 deletions(-) create mode 100644 prisma/migrations/20241204124349_add_version_project/migration.sql diff --git a/package.json b/package.json index 8cb5de7..b6f7318 100644 --- a/package.json +++ b/package.json @@ -41,7 +41,7 @@ "next-auth": "^4.24.3", "next-s3-upload": "^0.3.3", "nodemailer": "^6.9.0", - "openai": "^3.1.0", + "openai": "^4.75.0", "plaiceholder": "^2.5.0", "react": "18.2.0", "react-advanced-cropper": "^0.17.0", @@ -51,6 +51,7 @@ "react-medium-image-zoom": "^5.1.8", "react-parallax-tilt": "^1.7.77", "react-query": "^3.39.2", + "replicate": "^1.0.1", "sharp": "^0.31.2", "smartcrop-sharp": "^2.0.6", "stripe": "^11.1.0", diff --git a/prisma/migrations/20241204124349_add_version_project/migration.sql b/prisma/migrations/20241204124349_add_version_project/migration.sql new file mode 100644 index 0000000..f5b5dc6 --- /dev/null +++ b/prisma/migrations/20241204124349_add_version_project/migration.sql @@ -0,0 +1,5 @@ +-- CreateEnum +CREATE TYPE "ProjectVersion" AS ENUM ('V1', 'V2'); + +-- AlterTable +ALTER TABLE "Project" ADD COLUMN "version" "ProjectVersion" NOT NULL DEFAULT 'V1'; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 386502b..d4a5c34 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -84,6 +84,7 @@ model Project { credits Int @default(100) promptWizardCredits Int @default(20) Payment Payment[] + version ProjectVersion @default(V1) } model Shot { @@ -115,3 +116,8 @@ model Payment { @@map("payments") } + +enum ProjectVersion { + V1 + V2 +} \ No newline at end of file diff --git a/src/app/(auth)/studio/[id]/page.tsx b/src/app/(auth)/studio/[id]/page.tsx index 7dc310b..25d976c 100644 --- a/src/app/(auth)/studio/[id]/page.tsx +++ b/src/app/(auth)/studio/[id]/page.tsx @@ -1,5 +1,4 @@ import StudioPage from "@/components/pages/StudioPage"; -import replicateClient from "@/core/clients/replicate"; import db from "@/core/db"; import { getCurrentSessionRedirect } from "@/lib/sessions"; import { Metadata } from "next"; @@ -38,20 +37,7 @@ const Studio = async ({ params }: { params: { id: string } }) => { notFound(); } - const { data: model } = await replicateClient.get( - `https://api.replicate.com/v1/models/${process.env.REPLICATE_USERNAME}/${project.id}/versions/${project.modelVersionId}` - ); - - const hasImageInputAvailable = Boolean( - model.openapi_schema?.components?.schemas?.Input?.properties?.image?.title - ); - - return ( - - ); + return ; }; export default Studio; diff --git a/src/app/api/projects/[id]/predictions/[predictionId]/hd/route.ts b/src/app/api/projects/[id]/predictions/[predictionId]/hd/route.ts index 7345e66..8a5eebd 100644 --- a/src/app/api/projects/[id]/predictions/[predictionId]/hd/route.ts +++ b/src/app/api/projects/[id]/predictions/[predictionId]/hd/route.ts @@ -1,5 +1,5 @@ import { authOptions } from "@/app/api/auth/[...nextauth]/route"; -import replicateClient from "@/core/clients/replicate"; +import { replicate } from "@/core/clients/replicate"; import db from "@/core/db"; import { getServerSession } from "next-auth"; import { NextResponse } from "next/server"; @@ -32,9 +32,7 @@ export async function GET( ); } - const { data: prediction } = await replicateClient.get( - `https://api.replicate.com/v1/predictions/${shot.hdPredictionId}` - ); + const prediction = await replicate.predictions.get(shot.hdPredictionId!); if (prediction.output) { shot = await db.shot.update({ @@ -77,22 +75,19 @@ export async function POST( ); } - const { data } = await replicateClient.post( - `https://api.replicate.com/v1/predictions`, - { - input: { - image: shot.outputUrl, - upscale: 8, - face_upsample: true, - codeformer_fidelity: 1, - }, - version: process.env.REPLICATE_HD_VERSION_MODEL_ID, - } - ); + const prediction = await replicate.predictions.create({ + version: process.env.REPLICATE_HD_VERSION_MODEL_ID!, + input: { + image: shot.outputUrl, + upscale: 8, + face_upsample: true, + codeformer_fidelity: 1, + }, + }); shot = await db.shot.update({ where: { id: shot.id }, - data: { hdStatus: "PENDING", hdPredictionId: data.id }, + data: { hdStatus: "PENDING", hdPredictionId: prediction.id }, }); return NextResponse.json({ shot }); diff --git a/src/app/api/projects/[id]/predictions/[predictionId]/route.ts b/src/app/api/projects/[id]/predictions/[predictionId]/route.ts index ebcf3a6..c00cd5b 100644 --- a/src/app/api/projects/[id]/predictions/[predictionId]/route.ts +++ b/src/app/api/projects/[id]/predictions/[predictionId]/route.ts @@ -1,5 +1,5 @@ import { authOptions } from "@/app/api/auth/[...nextauth]/route"; -import replicateClient from "@/core/clients/replicate"; +import { replicate } from "@/core/clients/replicate"; import db from "@/core/db"; import { extractSeedFromLogs } from "@/core/utils/predictions"; import { getServerSession } from "next-auth"; @@ -27,9 +27,7 @@ export async function GET( where: { projectId: project.id, id: predictionId }, }); - const { data: prediction } = await replicateClient.get( - `https://api.replicate.com/v1/predictions/${shot.replicateId}` - ); + const prediction = await replicate.predictions.get(shot.replicateId); const outputUrl = prediction.output?.[0]; let blurhash = null; @@ -39,7 +37,7 @@ export async function GET( blurhash = base64; } - const seedNumber = extractSeedFromLogs(prediction.logs); + const seedNumber = extractSeedFromLogs(prediction.logs!); shot = await db.shot.update({ where: { id: shot.id }, diff --git a/src/app/api/projects/[id]/predictions/route.ts b/src/app/api/projects/[id]/predictions/route.ts index 5a5adeb..3f86886 100644 --- a/src/app/api/projects/[id]/predictions/route.ts +++ b/src/app/api/projects/[id]/predictions/route.ts @@ -1,7 +1,9 @@ import { authOptions } from "@/app/api/auth/[...nextauth]/route"; -import replicateClient from "@/core/clients/replicate"; +import openai from "@/core/clients/openai"; +import { replicate } from "@/core/clients/replicate"; import db from "@/core/db"; import { replacePromptToken } from "@/core/utils/predictions"; +import { prompts } from "@/core/utils/prompts"; import { getServerSession } from "next-auth"; import { NextResponse } from "next/server"; @@ -27,25 +29,46 @@ export async function POST( return NextResponse.json({ message: "No credit" }, { status: 400 }); } - const { data } = await replicateClient.post( - `https://api.replicate.com/v1/predictions`, - { - input: { - prompt: replacePromptToken(prompt, project), - negative_prompt: - process.env.REPLICATE_NEGATIVE_PROMPT || - "cropped face, cover face, cover visage, mutated hands", - ...(image && { image }), - ...(seed && { seed }), - }, - version: project.modelVersionId, - } - ); + const instruction = `${process.env.OPENAI_API_SEED_PROMPT} + +${prompts.slice(0, 5).map( + (style) => `${style.label}: ${style.prompt} + +` +)} + +Keyword: ${prompt} +`; + + const chatCompletion = await openai.chat.completions.create({ + messages: [{ role: "user", content: instruction }], + model: "gpt-4-turbo", + temperature: 0.5, + max_tokens: 200, + }); + + let refinedPrompt = chatCompletion.choices?.[0]?.message?.content?.trim(); + + const prediction = await replicate.predictions.create({ + model: `${process.env.REPLICATE_USERNAME}/${project.id}`, + version: project.modelVersionId!, + input: { + prompt: `${replacePromptToken( + `${refinedPrompt}. This a portrait of ${project.instanceName} @me and not another person.`, + project + )}`, + negative_prompt: + process.env.REPLICATE_NEGATIVE_PROMPT || + "cropped face, cover face, cover visage, mutated hands", + ...(image && { image }), + ...(seed && { seed }), + }, + }); const shot = await db.shot.create({ data: { prompt, - replicateId: data.id, + replicateId: prediction.id, status: "starting", projectId: project.id, }, diff --git a/src/app/api/projects/[id]/prompter/route.ts b/src/app/api/projects/[id]/prompter/route.ts index 9c5ece9..5b1d033 100644 --- a/src/app/api/projects/[id]/prompter/route.ts +++ b/src/app/api/projects/[id]/prompter/route.ts @@ -1,6 +1,7 @@ import { authOptions } from "@/app/api/auth/[...nextauth]/route"; import openai from "@/core/clients/openai"; import db from "@/core/db"; +import { prompts } from "@/core/utils/prompts"; import { getServerSession } from "next-auth"; import { NextResponse } from "next/server"; @@ -30,17 +31,24 @@ export async function POST( } try { - const completion = await openai.createCompletion({ - model: "text-davinci-003", - temperature: 0.7, - max_tokens: 256, - top_p: 1, - prompt: `${process.env.OPENAI_API_SEED_PROMPT} - -${keyword}:`, + const instruction = `${process.env.OPENAI_API_SEED_PROMPT} + +${prompts.map( + (style) => `${style.label}: ${style.prompt} + +` +)} + +Keyword: ${keyword} +`; + + const chatCompletion = await openai.chat.completions.create({ + messages: [{ role: "user", content: instruction }], + model: "gpt-4", + temperature: 0.5, }); - const prompt = completion.data.choices?.[0].text!.trim(); + const prompt = chatCompletion.choices?.[0]?.message?.content?.trim(); if (prompt) { project = await db.project.update({ @@ -56,6 +64,7 @@ ${keyword}:`, promptWizardCredits: project.promptWizardCredits, }); } catch (e) { + console.log({ e }); return NextResponse.json({ success: false }, { status: 400 }); } } diff --git a/src/app/api/projects/[id]/train/route.ts b/src/app/api/projects/[id]/train/route.ts index cb3f4f6..e5ef77e 100644 --- a/src/app/api/projects/[id]/train/route.ts +++ b/src/app/api/projects/[id]/train/route.ts @@ -1,6 +1,5 @@ -import replicateClient from "@/core/clients/replicate"; +import { replicate } from "@/core/clients/replicate"; import db from "@/core/db"; -import { getRefinedInstanceClass } from "@/core/utils/predictions"; import { getServerSession } from "next-auth"; import { NextResponse } from "next/server"; import { authOptions } from "../../../auth/[...nextauth]/route"; @@ -25,35 +24,35 @@ export async function POST( }, }); - const instanceClass = getRefinedInstanceClass(project.instanceClass); + await replicate.models.create(process.env.REPLICATE_USERNAME!, project.id, { + description: project.id, + visibility: "private", + hardware: "gpu-t4", + }); - const responseReplicate = await replicateClient.post( - "/v1/trainings", + const training = await replicate.trainings.create( + "ostris", + "flux-dev-lora-trainer", + "e440909d3512c31646ee2e0c7d6f6f4923224863a6a10c494606e79fb5844497", { + destination: `${process.env.REPLICATE_USERNAME}/${project.id}`, input: { - instance_prompt: `a photo of a ${process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN} ${instanceClass}`, - class_prompt: `a photo of a ${instanceClass}`, - instance_data: `https://${process.env.S3_UPLOAD_BUCKET}.s3.amazonaws.com/${project.id}.zip`, - max_train_steps: Number(process.env.REPLICATE_MAX_TRAIN_STEPS || 3000), - num_class_images: 200, - learning_rate: 1e-6, - }, - model: `${process.env.REPLICATE_USERNAME}/${project.id}`, - webhook_completed: `${process.env.NEXTAUTH_URL}/api/webhooks/completed`, - }, - { - headers: { - Authorization: `Token ${process.env.REPLICATE_API_TOKEN}`, - "Content-Type": "application/json", + trigger_word: process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN, + input_images: `https://${process.env.S3_UPLOAD_BUCKET}.s3.amazonaws.com/${project.id}.zip`, + //max_train_steps: Number(process.env.REPLICATE_MAX_TRAIN_STEPS || 3000), + //num_class_images: 200, + //learning_rate: 1e-6, + webhook: `${process.env.NEXTAUTH_URL}/api/webhooks/completed`, }, } ); - const replicateModelId = responseReplicate.data.id as string; - project = await db.project.update({ where: { id: project.id }, - data: { replicateModelId: replicateModelId, modelStatus: "processing" }, + data: { + replicateModelId: training.id, + modelStatus: "processing", + }, }); return NextResponse.json({ project }); diff --git a/src/app/api/projects/route.ts b/src/app/api/projects/route.ts index 04d60bf..b6bbb1a 100644 --- a/src/app/api/projects/route.ts +++ b/src/app/api/projects/route.ts @@ -1,4 +1,4 @@ -import replicateClient from "@/core/clients/replicate"; +import { replicate } from "@/core/clients/replicate"; import s3Client from "@/core/clients/s3"; import db from "@/core/db"; import { createZipFolder } from "@/core/utils/assets"; @@ -22,14 +22,23 @@ export async function GET() { for (const project of projects) { if (project?.replicateModelId && project?.modelStatus !== "succeeded") { - const { data } = await replicateClient.get( - `/v1/trainings/${project.replicateModelId}` - ); + try { + const training = await replicate.trainings.get( + project.replicateModelId + ); - await db.project.update({ - where: { id: project.id }, - data: { modelVersionId: data.version, modelStatus: data?.status }, - }); + const version = training?.output?.version?.split?.(":")?.[1]; + + await db.project.update({ + where: { id: project.id }, + data: { + modelVersionId: version, + modelStatus: training?.status, + }, + }); + } catch (error) { + console.log({ error }); + } } } @@ -55,6 +64,7 @@ export async function POST(request: Request) { instanceClass: instanceClass || "person", instanceName: process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN!, credits: Number(process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT) || 50, + version: "V2", }, }); diff --git a/src/components/home/Pricing.tsx b/src/components/home/Pricing.tsx index 225efcc..edbce67 100644 --- a/src/components/home/Pricing.tsx +++ b/src/components/home/Pricing.tsx @@ -1,4 +1,4 @@ -import React from "react"; +import { formatStudioPrice } from "@/core/utils/prices"; import { Box, List, @@ -8,8 +8,8 @@ import { Tag, Text, } from "@chakra-ui/react"; +import React from "react"; import { HiBadgeCheck } from "react-icons/hi"; -import { formatStudioPrice } from "@/core/utils/prices"; export const CheckedListItem = ({ children, @@ -77,9 +77,7 @@ const Pricing = () => { {process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT} avatars 4K generation - - 30 AI prompt assists - + AI prompt refinement Craft your own prompt Sponsorship development 🖤 diff --git a/src/components/pages/StudioPage.tsx b/src/components/pages/StudioPage.tsx index 08d1299..766916a 100644 --- a/src/components/pages/StudioPage.tsx +++ b/src/components/pages/StudioPage.tsx @@ -18,7 +18,7 @@ export interface IStudioPageProps { hasImageInputAvailable: boolean; } -const StudioPage = ({ project, hasImageInputAvailable }: IStudioPageProps) => ( +const StudioPage = ({ project }: IStudioPageProps) => ( @@ -32,7 +32,7 @@ const StudioPage = ({ project, hasImageInputAvailable }: IStudioPageProps) => ( Back to Dashboard - + diff --git a/src/components/projects/FormPayment.tsx b/src/components/projects/FormPayment.tsx index 92d46ef..124593a 100644 --- a/src/components/projects/FormPayment.tsx +++ b/src/components/projects/FormPayment.tsx @@ -81,9 +81,7 @@ const FormPayment = ({ {process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT} avatars 4K generation - - 30 AI prompt assists - + AI prompt refinement Your Studio will be deleted 24 hours after your credits are exhausted diff --git a/src/components/projects/PromptPanel.tsx b/src/components/projects/PromptPanel.tsx index ace7f3c..5e9de14 100644 --- a/src/components/projects/PromptPanel.tsx +++ b/src/components/projects/PromptPanel.tsx @@ -7,8 +7,8 @@ import { Flex, HStack, Icon, + Input, Text, - Textarea, VStack, } from "@chakra-ui/react"; import { Project, Shot } from "@prisma/client"; @@ -18,13 +18,8 @@ import { BsLightbulb } from "react-icons/bs"; import { FaCameraRetro } from "react-icons/fa"; import { useMutation } from "react-query"; import PromptsDrawer from "./PromptsDrawer"; -import PromptImage from "./PromptImage"; -const PromptPanel = ({ - hasImageInputAvailable, -}: { - hasImageInputAvailable: Boolean; -}) => { +const PromptPanel = () => { const { project, shotCredits, @@ -86,7 +81,6 @@ const PromptPanel = ({ - {hasImageInputAvailable && } -