Skip to content

Commit

Permalink
add upscale image support using stabilityai/stable-diffusion-x4-upsca…
Browse files Browse the repository at this point in the history
…ler model
  • Loading branch information
mikezupper committed Jun 5, 2024
1 parent aa8ae45 commit 1b9dc5a
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 1 deletion.
19 changes: 19 additions & 0 deletions ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,25 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMu
return &resp, nil
}

// Upscale serves a canned upscale response from disk.
//
// It looks up the file registered under the "upscale" key in w.files, reads
// it, and decodes its JSON contents into a worker.ImageResponse. ctx and req
// are not consulted. An error is returned when no file is registered, when
// the file cannot be read, or when its contents are not valid JSON for the
// response type.
func (w *FileWorker) Upscale(ctx context.Context, req worker.UpscaleImageMultipartRequestBody) (*worker.ImageResponse, error) {
	path, found := w.files["upscale"]
	if !found {
		return nil, errors.New("upscale response file not found")
	}

	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}

	out := new(worker.ImageResponse)
	if decodeErr := json.Unmarshal(raw, out); decodeErr != nil {
		return nil, decodeErr
	}
	return out, nil
}

// Warm is a no-op for FileWorker: it ignores ctx, containerName and modelID
// and always returns nil.
func (w *FileWorker) Warm(ctx context.Context, containerName, modelID string) error {
return nil
}
Expand Down
12 changes: 12 additions & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,18 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
constraints[core.Capability_ImageToImage].Models[config.ModelID] = modelConstraint

n.SetBasePriceForCap("default", core.Capability_ImageToImage, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit))
case "upscale":
_, ok := constraints[core.Capability_Upscale]
if !ok {
aiCaps = append(aiCaps, core.Capability_Upscale)
constraints[core.Capability_Upscale] = &core.Constraints{
Models: make(map[string]*core.ModelConstraint),
}
}

constraints[core.Capability_Upscale].Models[config.ModelID] = modelConstraint

n.SetBasePriceForCap("default", core.Capability_Upscale, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit))
case "image-to-video":
_, ok := constraints[core.Capability_ImageToVideo]
if !ok {
Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
type AI interface {
TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
Upscale(context.Context, worker.UpscaleImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const (
Capability_TextToImage
Capability_ImageToImage
Capability_ImageToVideo
Capability_Upscale
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -104,6 +105,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_TextToImage: "Text to image",
Capability_ImageToImage: "Image to image",
Capability_ImageToVideo: "Image to video",
Capability_Upscale: "Upscale",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -192,6 +194,7 @@ func OptionalCapabilities() []Capability {
Capability_TextToImage,
Capability_ImageToImage,
Capability_ImageToVideo,
Capability_Upscale,
}
}

Expand Down
7 changes: 6 additions & 1 deletion core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ func (orch *orchestrator) TextToImage(ctx context.Context, req worker.TextToImag
// ImageToImage forwards the image-to-image request to the underlying node
// and returns its response and error unchanged.
func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToImage(ctx, req)
}

// Upscale forwards the upscale request to the underlying node and returns
// its response and error unchanged.
func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleImageMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.upscale(ctx, req)
}
// ImageToVideo forwards the image-to-video request to the underlying node
// and returns its response and error unchanged.
func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToVideo(ctx, req)
}
Expand Down Expand Up @@ -936,6 +938,9 @@ func (n *LivepeerNode) textToImage(ctx context.Context, req worker.TextToImageJS
// imageToImage hands the request to the node's AI worker (n.AIWorker) and
// returns whatever the worker returns.
func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.ImageToImage(ctx, req)
}
// upscale hands the request to the node's AI worker (n.AIWorker) and returns
// whatever the worker returns.
func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleImageMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.Upscale(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
Expand Down
43 changes: 43 additions & 0 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/text-to-image", oapiReqValidator(lp.TextToImage()))
lp.transRPC.Handle("/image-to-image", oapiReqValidator(lp.ImageToImage()))
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))

return nil
}
Expand Down Expand Up @@ -108,6 +109,29 @@ func (h *lphttp) ImageToVideo() http.Handler {
})
}

// Upscale returns the HTTP handler that accepts multipart upscale requests.
//
// The handler tags the request context with the client IP, binds the
// multipart form into a worker.UpscaleImageMultipartRequestBody and passes it
// to handleAIRequest. A request whose body is not multipart, or whose form
// cannot be bound to the request type, is answered with 400 Bad Request.
func (h *lphttp) Upscale() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		orch := h.orchestrator

		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		var req worker.UpscaleImageMultipartRequestBody
		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			// Binding fails because of what the client sent, so report it as
			// a client error instead of an internal server error.
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		handleAIRequest(ctx, w, r, orch, req)
	})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
Expand Down Expand Up @@ -156,6 +180,25 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return orch.ImageToImage(ctx, v)
}

imageRdr, err := v.Image.Reader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}
config, _, err := image.DecodeConfig(imageRdr)
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.UpscaleImageMultipartRequestBody:
pipeline = "upscale"
cap = core.Capability_Upscale
modelID = *v.ModelId
submitFn = func(ctx context.Context) (*worker.ImageResponse, error) {
return orch.Upscale(ctx, v)
}

imageRdr, err := v.Image.Reader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
Expand Down
49 changes: 49 additions & 0 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func startAIMediaServer(ls *LivepeerServer) error {

ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(ls.TextToImage()))
ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(ls.ImageToImage()))
ls.HTTPMux.Handle("/upscale", oapiReqValidator(ls.Upscale()))
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())

Expand Down Expand Up @@ -161,6 +162,54 @@ func (ls *LivepeerServer) ImageToImage() http.Handler {
})
}

// Upscale returns the HTTP handler that accepts multipart upscale requests,
// runs them through processUpscale and writes the result as JSON.
//
// Responses:
//   - 400 when the body is not multipart or cannot be bound to the request type
//   - 503 when processing fails with a *ServiceUnavailableError
//   - 500 for any other processing error
//   - 200 with the JSON-encoded image response on success
func (ls *LivepeerServer) Upscale() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
		requestID := string(core.RandomManifestID())
		ctx = clog.AddVal(ctx, "request_id", requestID)

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		var req worker.UpscaleImageMultipartRequestBody
		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		// ModelId is a pointer and may be nil when the client omits it. Log
		// arguments are evaluated eagerly, so dereferencing it unguarded would
		// panic the handler. It is left empty here when unset.
		var modelID string
		if req.ModelId != nil {
			modelID = *req.ModelId
		}

		clog.V(common.VERBOSE).Infof(ctx, "Received Upscale request imageSize=%v prompt=%v model_id=%v", req.Image.FileSize(), req.Prompt, modelID)

		params := aiRequestParams{
			node:        ls.LivepeerNode,
			os:          drivers.NodeStorage.NewSession(string(core.RandomManifestID())),
			sessManager: ls.AISessionManager,
		}

		start := time.Now()
		resp, err := processUpscale(ctx, params, req)
		if err != nil {
			var e *ServiceUnavailableError
			if errors.As(err, &e) {
				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
				return
			}
			respondJsonError(ctx, w, err, http.StatusInternalServerError)
			return
		}

		took := time.Since(start)
		clog.V(common.VERBOSE).Infof(ctx, "Processed Upscale request imageSize=%v prompt=%v model_id=%v took=%v", req.Image.FileSize(), req.Prompt, modelID, took)

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		// The status line has already been written, so an encode failure can
		// no longer be reported to the client.
		_ = json.NewEncoder(w).Encode(resp)
	})
}

func (ls *LivepeerServer) ImageToVideo() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down
90 changes: 90 additions & 0 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const processingRetryTimeout = 2 * time.Second
// Default model IDs, one per AI pipeline.
const (
	defaultTextToImageModelID  = "stabilityai/sdxl-turbo"
	defaultImageToImageModelID = "stabilityai/sdxl-turbo"
	defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt"
	// defaultUpscaleModelID is the model used for upscale requests that do
	// not carry a model ID.
	defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
)

type ServiceUnavailableError struct {
err error
Expand Down Expand Up @@ -200,6 +201,86 @@ func submitImageToImage(ctx context.Context, params aiRequestParams, sess *AISes
return resp.JSON200, nil
}

// processUpscale runs an upscale request through processAIRequest and then
// re-hosts every returned image in the request's storage session.
//
// Each image URL in the response is passed to worker.ReadImageB64DataUrl to
// obtain the image bytes, which are saved via params.os under a random
// "<id>.png" name. The returned response carries the saved URLs in place of
// the originals; the Nsfw and Seed fields are copied through unchanged.
//
// Any failure (processing, reading the image data, flushing the buffer or
// saving) aborts the whole request and returns the error.
func processUpscale(ctx context.Context, params aiRequestParams, req worker.UpscaleImageMultipartRequestBody) (*worker.ImageResponse, error) {
	resp, err := processAIRequest(ctx, params, req)
	if err != nil {
		return nil, err
	}

	newMedia := make([]worker.Media, len(resp.Images))
	for i, media := range resp.Images {
		var data bytes.Buffer
		writer := bufio.NewWriter(&data)
		if err := worker.ReadImageB64DataUrl(media.Url, writer); err != nil {
			return nil, err
		}
		// Flush reports any write error the bufio.Writer recorded. Check it
		// so a truncated image is never saved.
		if err := writer.Flush(); err != nil {
			return nil, err
		}

		name := string(core.RandomManifestID()) + ".png"
		newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(data.Bytes()), nil, 0)
		if err != nil {
			return nil, err
		}

		newMedia[i] = worker.Media{Nsfw: media.Nsfw, Seed: media.Seed, Url: newUrl}
	}

	resp.Images = newMedia

	return resp, nil
}

// submitUpscale sends one upscale request to the orchestrator behind sess and
// returns the decoded 200 response.
//
// Steps, in order: encode req as a multipart body, build an API client for
// the session's transcoder address, read the input image's dimensions,
// prepare payment for width*height pixels, issue the request, and finally
// record a latency score on the session. The balance update is always
// completed on return (deferred); it is marked ReceivedChange only when a
// 200 response was received.
//
// A non-200 response is turned into an error whose text is the response body
// with one trailing newline removed.
func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession, req worker.UpscaleImageMultipartRequestBody) (*worker.ImageResponse, error) {
	var body bytes.Buffer
	formWriter, err := worker.NewUpscaleMultipartWriter(&body, req)
	if err != nil {
		return nil, err
	}

	apiClient, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
	if err != nil {
		return nil, err
	}

	// Payment is computed from the pixel count of the submitted image.
	imgReader, err := req.Image.Reader()
	if err != nil {
		return nil, err
	}
	imgCfg, _, err := image.DecodeConfig(imgReader)
	if err != nil {
		return nil, err
	}
	pixels := int64(imgCfg.Height) * int64(imgCfg.Width)

	addHeaders, update, err := prepareAIPayment(ctx, sess, pixels)
	if err != nil {
		return nil, err
	}
	defer completeBalanceUpdate(sess.BroadcastSession, update)

	begin := time.Now()
	result, err := apiClient.UpscaleImageWithBodyWithResponse(ctx, formWriter.FormDataContentType(), &body, addHeaders)
	elapsed := time.Since(begin)
	if err != nil {
		return nil, err
	}

	if result.JSON200 == nil {
		// TODO: Replace trim newline with better error spec from O
		msg := strings.TrimSuffix(string(result.Body), "\n")
		return nil, errors.New(msg)
	}

	// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
	if update != nil {
		update.Status = ReceivedChange
	}

	// TODO: Refine this rough estimate in future iterations
	sess.LatencyScore = elapsed.Seconds() / float64(pixels)

	return result.JSON200, nil
}

func processImageToVideo(ctx context.Context, params aiRequestParams, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
resp, err := processAIRequest(ctx, params, req)
if err != nil {
Expand Down Expand Up @@ -328,6 +409,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
return submitImageToVideo(ctx, params, sess, v)
}
case worker.UpscaleImageMultipartRequestBody:
cap = core.Capability_Upscale
modelID = defaultUpscaleModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
return submitUpscale(ctx, params, sess, v)
}
default:
return nil, errors.New("unknown AI request type")
}
Expand Down
1 change: 1 addition & 0 deletions server/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type Orchestrator interface {
TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error)
Upscale(ctx context.Context, req worker.UpscaleImageMultipartRequestBody) (*worker.ImageResponse, error)
}

// Balance describes methods for a session's balance maintenance
Expand Down

0 comments on commit 1b9dc5a

Please sign in to comment.