Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Frame interpolation go #3135

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
20 changes: 20 additions & 0 deletions ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,27 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.GenImageToVide
return &resp, nil
}

// FrameInterpolation returns a canned video response for the
// frame-interpolation pipeline.
//
// The response is decoded from the JSON file registered in w.files under the
// "frame-interpolation" key. Neither ctx nor req is consulted. An error is
// returned when no such file is registered, when the file cannot be read, or
// when its content cannot be decoded into a worker.VideoResponse.
func (w *FileWorker) FrameInterpolation(ctx context.Context, req worker.FrameInterpolationMultipartRequestBody) (*worker.VideoResponse, error) {
	path, found := w.files["frame-interpolation"]
	if !found {
		return nil, errors.New("frame-interpolation response file not found")
	}

	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}

	out := new(worker.VideoResponse)
	if decodeErr := json.Unmarshal(raw, out); decodeErr != nil {
		return nil, decodeErr
	}
	return out, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {

fname, ok := w.files["upscale"]
if !ok {
return nil, errors.New("upscale response file not found")
Expand Down
17 changes: 16 additions & 1 deletion cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,21 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)
}
n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)

case "frame-interpolation":
_, ok := capabilityConstraints[core.Capability_FrameInterpolation]
if !ok {
aiCaps = append(aiCaps, core.Capability_FrameInterpolation)
capabilityConstraints[core.Capability_FrameInterpolation] = &core.PerCapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_FrameInterpolation].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_FrameInterpolation, config.ModelID, autoPrice)
}

case "llm":
_, ok := capabilityConstraints[core.Capability_LLM]
Expand All @@ -1334,6 +1348,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_LLM, config.ModelID, autoPrice)
}

case "segment-anything-2":
_, ok := capabilityConstraints[core.Capability_SegmentAnything2]
if !ok {
Expand Down
3 changes: 1 addition & 2 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"math"
"math/big"
"math/rand"
Expand Down Expand Up @@ -451,7 +450,7 @@ func ReadAtMost(r io.Reader, n int) ([]byte, error) {
// Reading one extra byte to check if input Reader
// had more than n bytes
limitedReader := io.LimitReader(r, int64(n)+1)
b, err := ioutil.ReadAll(limitedReader)
b, err := io.ReadAll(limitedReader)
if err == nil && len(b) > n {
return nil, errors.New("input bigger than max buffer size")
}
Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type AI interface {
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
FrameInterpolation(context.Context, worker.FrameInterpolationMultipartRequestBody) (*worker.VideoResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ const (
Capability_AudioToText Capability = 31
Capability_SegmentAnything2 Capability = 32
Capability_LLM Capability = 33
Capability_FrameInterpolation Capability = 34
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -117,6 +118,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_AudioToText: "Audio to text",
Capability_SegmentAnything2: "Segment anything 2",
Capability_LLM: "Large language model",
Capability_FrameInterpolation: "Frame Interpolation",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -205,6 +207,7 @@ func OptionalCapabilities() []Capability {
Capability_TextToImage,
Capability_ImageToImage,
Capability_ImageToVideo,
Capability_FrameInterpolation,
Capability_Upscale,
Capability_AudioToText,
Capability_SegmentAnything2,
Expand Down
99 changes: 97 additions & 2 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/big"
"net/url"
"os"
Expand All @@ -29,6 +29,7 @@ import (
"github.com/livepeer/go-livepeer/net"
"github.com/livepeer/go-livepeer/pm"
"github.com/livepeer/go-tools/drivers"
ffmpeg_go "github.com/u2takey/ffmpeg-go"

lpcrypto "github.com/livepeer/go-livepeer/crypto"
lpmon "github.com/livepeer/go-livepeer/monitor"
Expand Down Expand Up @@ -126,6 +127,10 @@ func (orch *orchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMult
return orch.node.upscale(ctx, req)
}

// FrameInterpolation forwards the frame-interpolation request to the
// orchestrator's node and returns the node's result unchanged.
func (orch *orchestrator) FrameInterpolation(ctx context.Context, req worker.FrameInterpolationMultipartRequestBody) (*worker.VideoResponse, error) {
	node := orch.node
	return node.FrameInterpolation(ctx, req)
}

// AudioToText forwards the audio-to-text request to the orchestrator's node
// and returns the node's result unchanged.
func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
	node := orch.node
	return node.AudioToText(ctx, req)
}
Expand Down Expand Up @@ -756,7 +761,7 @@ func (n *LivepeerNode) transcodeSeg(ctx context.Context, config transcodeConfig,
// Create input file from segment. Removed after claiming complete or error
fname := path.Join(n.WorkDir, inName)
fnamep = &fname
if err := ioutil.WriteFile(fname, seg.Data, 0644); err != nil {
if err := os.WriteFile(fname, seg.Data, 0644); err != nil {
clog.Errorf(ctx, "Transcoder cannot write file err=%q", err)
return terr(err)
}
Expand Down Expand Up @@ -967,6 +972,96 @@ func (n *LivepeerNode) serveTranscoder(stream net.Transcoder_RegisterTranscoderS
defer n.SetMaxSessions(n.GetCurrentCapacity())
}
}
// FrameInterpolation runs the frame-interpolation pipeline on the AI worker
// and transcodes every returned batch of frames into one MP4 segment.
//
// The input video is probed with ffprobe to determine the output resolution.
// Each batch in the worker response is transcoded into a single rendition,
// saved through the transcode result's OS.SaveData, and reported back as a
// one-element slice holding the stored URL. The Nsfw and Seed values of a
// batch's first frame are copied onto that batch's video entry.
//
// An error is returned when the worker call, the probe, a transcode or an
// upload fails, when the probe output contains no stream with usable
// dimensions and frame rate, or when a transcode yields no segments.
func (n *LivepeerNode) FrameInterpolation(ctx context.Context, req worker.FrameInterpolationMultipartRequestBody) (*worker.VideoResponse, error) {
	// Frame rate used for both the input frames and the encoded output.
	// Hardcoded to 24fps for now; can later be made selectable.
	const framerate = 24

	// Generate interpolated frames
	start := time.Now()
	resp, err := n.AIWorker.FrameInterpolation(ctx, req)
	if err != nil {
		return nil, err
	}
	clog.V(common.DEBUG).Infof(ctx, "Generating interpolated frames took=%v", time.Since(start))

	sessionID := string(RandomManifestID())

	video, err := req.Video.Reader()
	if err != nil {
		return nil, err
	}
	defer video.Close()

	probeData, err := ffmpeg_go.ProbeReader(video)
	if err != nil {
		return nil, err
	}

	// Checked decoding: a container without a video stream (or with an audio
	// stream listed first) yields an error instead of a panic.
	width, height, err := frameInterpolationDimensions(probeData)
	if err != nil {
		return nil, err
	}

	inProfile := ffmpeg.VideoProfile{
		Framerate:    framerate,
		FramerateDen: 1,
	}
	outProfile := ffmpeg.VideoProfile{
		Name:       "frame-interpolation",
		Resolution: fmt.Sprintf("%dx%d", width, height),
		Framerate:  framerate,
		Bitrate:    "6000k",
		Format:     ffmpeg.FormatMP4,
	}

	// Transcode frames into segments.
	videos := make([][]worker.Media, len(resp.Frames))
	for i, batch := range resp.Frames {
		// Create slice of frame urls for a batch
		urls := make([]string, len(batch))
		for j, frame := range batch {
			urls[j] = frame.Url
		}

		// Transcode slice of frame urls into a segment
		res := n.transcodeFrames(ctx, sessionID, urls, inProfile, outProfile)
		if res.Err != nil {
			return nil, res.Err
		}
		// Assume only single rendition right now, but never index blindly.
		if len(res.TranscodeData.Segments) == 0 {
			return nil, errors.New("frame-interpolation transcode returned no segments")
		}
		seg := res.TranscodeData.Segments[0]

		name := fmt.Sprintf("%v.mp4", RandomManifestID())
		uri, err := res.OS.SaveData(ctx, name, bytes.NewReader(seg.Data), nil, 0)
		if err != nil {
			return nil, err
		}

		media := worker.Media{Url: uri}
		// NOTE: Seed is consistent for video; NSFW check applies to first frame only.
		if len(batch) > 0 {
			media.Nsfw = batch[0].Nsfw
			media.Seed = batch[0].Seed
		}
		videos[i] = []worker.Media{media}
	}
	return &worker.VideoResponse{Frames: videos}, nil
}

// frameInterpolationDimensions extracts the width and height of the first
// stream in ffprobe's JSON output that reports positive dimensions. The
// stream's r_frame_rate must parse as "num/den"; the parsed rate itself is
// not used yet because the output frame rate is fixed by the caller.
func frameInterpolationDimensions(probeJSON string) (int, int, error) {
	var probe struct {
		Streams []struct {
			Width      int    `json:"width"`
			Height     int    `json:"height"`
			RFrameRate string `json:"r_frame_rate"`
		} `json:"streams"`
	}
	if err := json.Unmarshal([]byte(probeJSON), &probe); err != nil {
		return 0, 0, fmt.Errorf("parsing probe output: %w", err)
	}

	for _, s := range probe.Streams {
		// Streams without dimensions (e.g. audio) are skipped.
		if s.Width <= 0 || s.Height <= 0 {
			continue
		}
		var num, den int
		if _, err := fmt.Sscanf(s.RFrameRate, "%d/%d", &num, &den); err != nil {
			return 0, 0, fmt.Errorf("parsing frame rate %q: %w", s.RFrameRate, err)
		}
		return s.Width, s.Height, nil
	}
	return 0, 0, errors.New("no stream with video dimensions found in probe output")
}

func (n *LivepeerNode) textToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.TextToImage(ctx, req)
Expand Down
Loading