Skip to content

Commit

Permalink
feat(ai): add upscaling pipeline (#3077)
Browse files Browse the repository at this point in the history
* add upscale image support using stabilityai/stable-diffusion-x4-upscaler model

* fix(ai): fix ai-worker client bindings

This commit ensures that the right golang client bindings response and
request types are used. It also cleans up the codebase a bit.

---------

Co-authored-by: Mike Zupper <[email protected]>
  • Loading branch information
rickstaa and mikezupper authored Jun 10, 2024
1 parent 0d51a78 commit b2933cf
Show file tree
Hide file tree
Showing 11 changed files with 229 additions and 3 deletions.
19 changes: 19 additions & 0 deletions ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,25 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMu
return &resp, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["upscale"]
if !ok {
return nil, errors.New("upscale response file not found")
}

data, err := os.ReadFile(fname)
if err != nil {
return nil, err
}

var resp worker.ImageResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, err
}

return &resp, nil
}

func (w *FileWorker) Warm(ctx context.Context, containerName, modelID string) error {
return nil
}
Expand Down
12 changes: 12 additions & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,18 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
constraints[core.Capability_ImageToVideo].Models[config.ModelID] = modelConstraint

n.SetBasePriceForCap("default", core.Capability_ImageToVideo, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit))
case "upscale":
_, ok := constraints[core.Capability_Upscale]
if !ok {
aiCaps = append(aiCaps, core.Capability_Upscale)
constraints[core.Capability_Upscale] = &core.Constraints{
Models: make(map[string]*core.ModelConstraint),
}
}

constraints[core.Capability_Upscale].Models[config.ModelID] = modelConstraint

n.SetBasePriceForCap("default", core.Capability_Upscale, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit))
}

if len(aiCaps) > 0 {
Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type AI interface {
TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const (
Capability_TextToImage
Capability_ImageToImage
Capability_ImageToVideo
Capability_Upscale
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -104,6 +105,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_TextToImage: "Text to image",
Capability_ImageToImage: "Image to image",
Capability_ImageToVideo: "Image to video",
Capability_Upscale: "Upscale",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -192,6 +194,7 @@ func OptionalCapabilities() []Capability {
Capability_TextToImage,
Capability_ImageToImage,
Capability_ImageToVideo,
Capability_Upscale,
}
}

Expand Down
8 changes: 8 additions & 0 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVi
return orch.node.imageToVideo(ctx, req)
}

func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.upscale(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
Expand Down Expand Up @@ -937,6 +941,10 @@ func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImage
return n.AIWorker.ImageToImage(ctx, req)
}

func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.Upscale(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
numVideos := 1
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/golang/protobuf v1.5.3
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.0.7
github.com/livepeer/ai-worker v0.0.8
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240120150405-de94555cdc69
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -531,8 +531,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.0.7 h1:mctm5jswdlMLUjaHeMhv3u4QtOAZ75jJhxgUdlzY5dU=
github.com/livepeer/ai-worker v0.0.7/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
github.com/livepeer/ai-worker v0.0.8 h1:FAjYJgSOaZslA06Wb6MolYohI30IMIujDTB26nfw8YE=
github.com/livepeer/ai-worker v0.0.8/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
43 changes: 43 additions & 0 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/text-to-image", oapiReqValidator(lp.TextToImage()))
lp.transRPC.Handle("/image-to-image", oapiReqValidator(lp.ImageToImage()))
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))

return nil
}
Expand Down Expand Up @@ -108,6 +109,29 @@ func (h *lphttp) ImageToVideo() http.Handler {
})
}

func (h *lphttp) Upscale() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.UpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
Expand Down Expand Up @@ -156,6 +180,25 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return orch.ImageToImage(ctx, v)
}

imageRdr, err := v.Image.Reader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}
config, _, err := image.DecodeConfig(imageRdr)
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.UpscaleMultipartRequestBody:
pipeline = "upscale"
cap = core.Capability_Upscale
modelID = *v.ModelId
submitFn = func(ctx context.Context) (*worker.ImageResponse, error) {
return orch.Upscale(ctx, v)
}

imageRdr, err := v.Image.Reader()
if err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
Expand Down
49 changes: 49 additions & 0 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func startAIMediaServer(ls *LivepeerServer) error {

ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(ls.TextToImage()))
ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(ls.ImageToImage()))
ls.HTTPMux.Handle("/upscale", oapiReqValidator(ls.Upscale()))
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())

Expand Down Expand Up @@ -271,6 +272,54 @@ func (ls *LivepeerServer) ImageToVideo() http.Handler {
})
}

func (ls *LivepeerServer) Upscale() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

multiRdr, err := r.MultipartReader()
if err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

var req worker.UpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received Upscale request imageSize=%v prompt=%v model_id=%v", req.Image.FileSize(), req.Prompt, *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(string(core.RandomManifestID())),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processUpscale(ctx, params, req)
if err != nil {
var e *ServiceUnavailableError
if errors.As(err, &e) {
respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
return
}
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed Upscale request imageSize=%v prompt=%v model_id=%v took=%v", req.Image.FileSize(), req.Prompt, *req.ModelId, took)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
})
}

func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down
90 changes: 90 additions & 0 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const processingRetryTimeout = 2 * time.Second
const defaultTextToImageModelID = "stabilityai/sdxl-turbo"
const defaultImageToImageModelID = "stabilityai/sdxl-turbo"
const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt"
const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"

type ServiceUnavailableError struct {
err error
Expand Down Expand Up @@ -301,6 +302,86 @@ func submitImageToVideo(ctx context.Context, params aiRequestParams, sess *AISes
return &res, nil
}

func processUpscale(ctx context.Context, params aiRequestParams, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
resp, err := processAIRequest(ctx, params, req)
if err != nil {
return nil, err
}

newMedia := make([]worker.Media, len(resp.Images))
for i, media := range resp.Images {
var data bytes.Buffer
writer := bufio.NewWriter(&data)
if err := worker.ReadImageB64DataUrl(media.Url, writer); err != nil {
return nil, err
}
writer.Flush()

name := string(core.RandomManifestID()) + ".png"
newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(data.Bytes()), nil, 0)
if err != nil {
return nil, err
}

newMedia[i] = worker.Media{Nsfw: media.Nsfw, Seed: media.Seed, Url: newUrl}
}

resp.Images = newMedia

return resp, nil
}

func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
var buf bytes.Buffer
mw, err := worker.NewUpscaleMultipartWriter(&buf, req)
if err != nil {
return nil, err
}

client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
return nil, err
}

imageRdr, err := req.Image.Reader()
if err != nil {
return nil, err
}
config, _, err := image.DecodeConfig(imageRdr)
if err != nil {
return nil, err
}
outPixels := int64(config.Height) * int64(config.Width)

setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels)
if err != nil {
return nil, err
}
defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)

start := time.Now()
resp, err := client.UpscaleWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders)
took := time.Since(start)
if err != nil {
return nil, err
}

if resp.JSON200 == nil {
// TODO: Replace trim newline with better error spec from O
return nil, errors.New(strings.TrimSuffix(string(resp.Body), "\n"))
}

// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
if balUpdate != nil {
balUpdate.Status = ReceivedChange
}

// TODO: Refine this rough estimate in future iterations
sess.LatencyScore = took.Seconds() / float64(outPixels)

return resp.JSON200, nil
}

func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (*worker.ImageResponse, error) {
var cap core.Capability
var modelID string
Expand Down Expand Up @@ -334,6 +415,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
return submitImageToVideo(ctx, params, sess, v)
}
case worker.UpscaleMultipartRequestBody:
cap = core.Capability_Upscale
modelID = defaultUpscaleModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
return submitUpscale(ctx, params, sess, v)
}
default:
return nil, errors.New("unknown AI request type")
}
Expand Down
1 change: 1 addition & 0 deletions server/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type Orchestrator interface {
TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error)
Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
}

// Balance describes methods for a session's balance maintenance
Expand Down

0 comments on commit b2933cf

Please sign in to comment.