Skip to content

Commit

Permalink
feat: switch to flux (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
baptadn authored Dec 7, 2024
1 parent d095697 commit 2d9d597
Show file tree
Hide file tree
Showing 23 changed files with 262 additions and 170 deletions.
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"next-auth": "^4.24.3",
"next-s3-upload": "^0.3.3",
"nodemailer": "^6.9.0",
"openai": "^3.1.0",
"openai": "^4.75.0",
"plaiceholder": "^2.5.0",
"react": "18.2.0",
"react-advanced-cropper": "^0.17.0",
Expand All @@ -51,6 +51,7 @@
"react-medium-image-zoom": "^5.1.8",
"react-parallax-tilt": "^1.7.77",
"react-query": "^3.39.2",
"replicate": "^1.0.1",
"sharp": "^0.31.2",
"smartcrop-sharp": "^2.0.6",
"stripe": "^11.1.0",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- CreateEnum
CREATE TYPE "ProjectVersion" AS ENUM ('V1', 'V2');

-- AlterTable
ALTER TABLE "Project" ADD COLUMN "version" "ProjectVersion" NOT NULL DEFAULT 'V1';
6 changes: 6 additions & 0 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ model Project {
credits Int @default(100)
promptWizardCredits Int @default(20)
Payment Payment[]
version ProjectVersion @default(V1)
}

model Shot {
Expand Down Expand Up @@ -115,3 +116,8 @@ model Payment {
@@map("payments")
}

enum ProjectVersion {
V1
V2
}
16 changes: 1 addition & 15 deletions src/app/(auth)/studio/[id]/page.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import StudioPage from "@/components/pages/StudioPage";
import replicateClient from "@/core/clients/replicate";
import db from "@/core/db";
import { getCurrentSessionRedirect } from "@/lib/sessions";
import { Metadata } from "next";
Expand Down Expand Up @@ -38,20 +37,7 @@ const Studio = async ({ params }: { params: { id: string } }) => {
notFound();
}

const { data: model } = await replicateClient.get(
`https://api.replicate.com/v1/models/${process.env.REPLICATE_USERNAME}/${project.id}/versions/${project.modelVersionId}`
);

const hasImageInputAvailable = Boolean(
model.openapi_schema?.components?.schemas?.Input?.properties?.image?.title
);

return (
<StudioPage
project={project}
hasImageInputAvailable={hasImageInputAvailable}
/>
);
return <StudioPage project={project} />;
};

export default Studio;
29 changes: 12 additions & 17 deletions src/app/api/projects/[id]/predictions/[predictionId]/hd/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { authOptions } from "@/app/api/auth/[...nextauth]/route";
import replicateClient from "@/core/clients/replicate";
import { replicate } from "@/core/clients/replicate";
import db from "@/core/db";
import { getServerSession } from "next-auth";
import { NextResponse } from "next/server";
Expand Down Expand Up @@ -32,9 +32,7 @@ export async function GET(
);
}

const { data: prediction } = await replicateClient.get(
`https://api.replicate.com/v1/predictions/${shot.hdPredictionId}`
);
const prediction = await replicate.predictions.get(shot.hdPredictionId!);

if (prediction.output) {
shot = await db.shot.update({
Expand Down Expand Up @@ -77,22 +75,19 @@ export async function POST(
);
}

const { data } = await replicateClient.post(
`https://api.replicate.com/v1/predictions`,
{
input: {
image: shot.outputUrl,
upscale: 8,
face_upsample: true,
codeformer_fidelity: 1,
},
version: process.env.REPLICATE_HD_VERSION_MODEL_ID,
}
);
const prediction = await replicate.predictions.create({
version: process.env.REPLICATE_HD_VERSION_MODEL_ID!,
input: {
image: shot.outputUrl,
upscale: 8,
face_upsample: true,
codeformer_fidelity: 1,
},
});

shot = await db.shot.update({
where: { id: shot.id },
data: { hdStatus: "PENDING", hdPredictionId: data.id },
data: { hdStatus: "PENDING", hdPredictionId: prediction.id },
});

return NextResponse.json({ shot });
Expand Down
8 changes: 3 additions & 5 deletions src/app/api/projects/[id]/predictions/[predictionId]/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { authOptions } from "@/app/api/auth/[...nextauth]/route";
import replicateClient from "@/core/clients/replicate";
import { replicate } from "@/core/clients/replicate";
import db from "@/core/db";
import { extractSeedFromLogs } from "@/core/utils/predictions";
import { getServerSession } from "next-auth";
Expand Down Expand Up @@ -27,9 +27,7 @@ export async function GET(
where: { projectId: project.id, id: predictionId },
});

const { data: prediction } = await replicateClient.get(
`https://api.replicate.com/v1/predictions/${shot.replicateId}`
);
const prediction = await replicate.predictions.get(shot.replicateId);

const outputUrl = prediction.output?.[0];
let blurhash = null;
Expand All @@ -39,7 +37,7 @@ export async function GET(
blurhash = base64;
}

const seedNumber = extractSeedFromLogs(prediction.logs);
const seedNumber = extractSeedFromLogs(prediction.logs!);

shot = await db.shot.update({
where: { id: shot.id },
Expand Down
55 changes: 39 additions & 16 deletions src/app/api/projects/[id]/predictions/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { authOptions } from "@/app/api/auth/[...nextauth]/route";
import replicateClient from "@/core/clients/replicate";
import openai from "@/core/clients/openai";
import { replicate } from "@/core/clients/replicate";
import db from "@/core/db";
import { replacePromptToken } from "@/core/utils/predictions";
import { prompts } from "@/core/utils/prompts";
import { getServerSession } from "next-auth";
import { NextResponse } from "next/server";

Expand All @@ -27,25 +29,46 @@ export async function POST(
return NextResponse.json({ message: "No credit" }, { status: 400 });
}

const { data } = await replicateClient.post(
`https://api.replicate.com/v1/predictions`,
{
input: {
prompt: replacePromptToken(prompt, project),
negative_prompt:
process.env.REPLICATE_NEGATIVE_PROMPT ||
"cropped face, cover face, cover visage, mutated hands",
...(image && { image }),
...(seed && { seed }),
},
version: project.modelVersionId,
}
);
const instruction = `${process.env.OPENAI_API_SEED_PROMPT}
${prompts.slice(0, 5).map(
(style) => `${style.label}: ${style.prompt}
`
)}
Keyword: ${prompt}
`;

const chatCompletion = await openai.chat.completions.create({
messages: [{ role: "user", content: instruction }],
model: "gpt-4-turbo",
temperature: 0.5,
max_tokens: 200,
});

let refinedPrompt = chatCompletion.choices?.[0]?.message?.content?.trim();

const prediction = await replicate.predictions.create({
model: `${process.env.REPLICATE_USERNAME}/${project.id}`,
version: project.modelVersionId!,
input: {
prompt: `${replacePromptToken(
`${refinedPrompt}. This a portrait of ${project.instanceName} @me and not another person.`,
project
)}`,
negative_prompt:
process.env.REPLICATE_NEGATIVE_PROMPT ||
"cropped face, cover face, cover visage, mutated hands",
...(image && { image }),
...(seed && { seed }),
},
});

const shot = await db.shot.create({
data: {
prompt,
replicateId: data.id,
replicateId: prediction.id,
status: "starting",
projectId: project.id,
},
Expand Down
27 changes: 18 additions & 9 deletions src/app/api/projects/[id]/prompter/route.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { authOptions } from "@/app/api/auth/[...nextauth]/route";
import openai from "@/core/clients/openai";
import db from "@/core/db";
import { prompts } from "@/core/utils/prompts";
import { getServerSession } from "next-auth";
import { NextResponse } from "next/server";

Expand Down Expand Up @@ -30,17 +31,24 @@ export async function POST(
}

try {
const completion = await openai.createCompletion({
model: "text-davinci-003",
temperature: 0.7,
max_tokens: 256,
top_p: 1,
prompt: `${process.env.OPENAI_API_SEED_PROMPT}
${keyword}:`,
const instruction = `${process.env.OPENAI_API_SEED_PROMPT}
${prompts.map(
(style) => `${style.label}: ${style.prompt}
`
)}
Keyword: ${keyword}
`;

const chatCompletion = await openai.chat.completions.create({
messages: [{ role: "user", content: instruction }],
model: "gpt-4",
temperature: 0.5,
});

const prompt = completion.data.choices?.[0].text!.trim();
const prompt = chatCompletion.choices?.[0]?.message?.content?.trim();

if (prompt) {
project = await db.project.update({
Expand All @@ -56,6 +64,7 @@ ${keyword}:`,
promptWizardCredits: project.promptWizardCredits,
});
} catch (e) {
console.log({ e });
return NextResponse.json({ success: false }, { status: 400 });
}
}
43 changes: 21 additions & 22 deletions src/app/api/projects/[id]/train/route.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import replicateClient from "@/core/clients/replicate";
import { replicate } from "@/core/clients/replicate";
import db from "@/core/db";
import { getRefinedInstanceClass } from "@/core/utils/predictions";
import { getServerSession } from "next-auth";
import { NextResponse } from "next/server";
import { authOptions } from "../../../auth/[...nextauth]/route";
Expand All @@ -25,35 +24,35 @@ export async function POST(
},
});

const instanceClass = getRefinedInstanceClass(project.instanceClass);
await replicate.models.create(process.env.REPLICATE_USERNAME!, project.id, {
description: project.id,
visibility: "private",
hardware: "gpu-t4",
});

const responseReplicate = await replicateClient.post(
"/v1/trainings",
const training = await replicate.trainings.create(
"ostris",
"flux-dev-lora-trainer",
"e440909d3512c31646ee2e0c7d6f6f4923224863a6a10c494606e79fb5844497",
{
destination: `${process.env.REPLICATE_USERNAME}/${project.id}`,
input: {
instance_prompt: `a photo of a ${process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN} ${instanceClass}`,
class_prompt: `a photo of a ${instanceClass}`,
instance_data: `https://${process.env.S3_UPLOAD_BUCKET}.s3.amazonaws.com/${project.id}.zip`,
max_train_steps: Number(process.env.REPLICATE_MAX_TRAIN_STEPS || 3000),
num_class_images: 200,
learning_rate: 1e-6,
},
model: `${process.env.REPLICATE_USERNAME}/${project.id}`,
webhook_completed: `${process.env.NEXTAUTH_URL}/api/webhooks/completed`,
},
{
headers: {
Authorization: `Token ${process.env.REPLICATE_API_TOKEN}`,
"Content-Type": "application/json",
trigger_word: process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN,
input_images: `https://${process.env.S3_UPLOAD_BUCKET}.s3.amazonaws.com/${project.id}.zip`,
//max_train_steps: Number(process.env.REPLICATE_MAX_TRAIN_STEPS || 3000),
//num_class_images: 200,
//learning_rate: 1e-6,
webhook: `${process.env.NEXTAUTH_URL}/api/webhooks/completed`,
},
}
);

const replicateModelId = responseReplicate.data.id as string;

project = await db.project.update({
where: { id: project.id },
data: { replicateModelId: replicateModelId, modelStatus: "processing" },
data: {
replicateModelId: training.id,
modelStatus: "processing",
},
});

return NextResponse.json({ project });
Expand Down
26 changes: 18 additions & 8 deletions src/app/api/projects/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import replicateClient from "@/core/clients/replicate";
import { replicate } from "@/core/clients/replicate";
import s3Client from "@/core/clients/s3";
import db from "@/core/db";
import { createZipFolder } from "@/core/utils/assets";
Expand All @@ -22,14 +22,23 @@ export async function GET() {

for (const project of projects) {
if (project?.replicateModelId && project?.modelStatus !== "succeeded") {
const { data } = await replicateClient.get(
`/v1/trainings/${project.replicateModelId}`
);
try {
const training = await replicate.trainings.get(
project.replicateModelId
);

await db.project.update({
where: { id: project.id },
data: { modelVersionId: data.version, modelStatus: data?.status },
});
const version = training?.output?.version?.split?.(":")?.[1];

await db.project.update({
where: { id: project.id },
data: {
modelVersionId: version,
modelStatus: training?.status,
},
});
} catch (error) {
console.log({ error });
}
}
}

Expand All @@ -55,6 +64,7 @@ export async function POST(request: Request) {
instanceClass: instanceClass || "person",
instanceName: process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN!,
credits: Number(process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT) || 50,
version: "V2",
},
});

Expand Down
Loading

0 comments on commit 2d9d597

Please sign in to comment.