Skip to content

Commit

Permalink
Fix api interaction
Browse files Browse the repository at this point in the history
  • Loading branch information
Keavon authored and TrueDoctor committed May 20, 2024
1 parent 1c8ddd2 commit d1936e6
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 109 deletions.
109 changes: 0 additions & 109 deletions node-graph/gstd/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,112 +16,3 @@ pub struct PostNode<Body> {
async fn post_node(url: String, body: String) -> reqwest::Response {
reqwest::Client::new().post(url).body(body).send().await.unwrap()
}

#[cfg(feature = "serde")]
async fn image_to_image(image: ImageFrame<SRGBA8>, prompt: String) -> reqwest::Result<ImageFrame<SRGBA8>> {
let png_bytes = image.image.to_png();
// let base64 = base64::encode(png_bytes);
// post to cloudflare image to image endpoint using reqwest
let payload = PayloadBuilder::new().guidance(7.5).image(png_bytes.to_vec()).num_steps(20).prompt(prompt).strength(1).build();

let client = Client::new();
let account_id = "023e105f4ecef8ad9ca31a8372d0c353";
let response = client
.post(format!("https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/bytedance/stable-diffusion-xl-lightning"))
.json(&payload)
.header("Content-Type", "application/json")
.header("Authorization", "Bearer 123")
.send()
.await?;

let text = response.text().await?;
println!("{}", text);

Ok(image)
}
use reqwest::Client;
use serde::Serialize;

#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Default)]
struct PayloadBuilder {
guidance: Option<f64>,
image: Option<Vec<u8>>,
mask: Option<Vec<u32>>,
num_steps: Option<u32>,
prompt: Option<String>,
strength: Option<u32>,
}

impl PayloadBuilder {
fn new() -> Self {
Self::default()
}

fn guidance(mut self, value: f64) -> Self {
self.guidance = Some(value);
self
}

fn image(mut self, value: Vec<u8>) -> Self {
self.image = Some(value);
self
}

fn mask(mut self, value: Vec<u32>) -> Self {
self.mask = Some(value);
self
}

fn num_steps(mut self, value: u32) -> Self {
self.num_steps = Some(value);
self
}

fn prompt(mut self, value: String) -> Self {
self.prompt = Some(value);
self
}

fn strength(mut self, value: u32) -> Self {
self.strength = Some(value);
self
}

fn build(self) -> Payload {
Payload {
guidance: self.guidance.unwrap_or_default(),
image: self.image.unwrap_or_default(),
mask: self.mask.unwrap_or_default(),
num_steps: self.num_steps.unwrap_or_default(),
prompt: self.prompt.unwrap_or_default(),
strength: self.strength.unwrap_or_default(),
}
}
}

#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct Payload {
guidance: f64,
image: Vec<u8>,
mask: Vec<u32>,
num_steps: u32,
prompt: String,
strength: u32,
}

#[cfg(test)]
mod test {
use super::*;
use graphene_core::{raster::Image, Color};
#[tokio::test]
async fn test_cloudflare() {
let test_image = ImageFrame {
image: Image::new(100, 100, SRGBA8::from(Color::RED)),
..Default::default()
};
let result = image_to_image(test_image, "make green".into()).await;
dbg!(result);
panic!("show result");
}
}
138 changes: 138 additions & 0 deletions node-graph/gstd/src/imaginate_v2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use graphene_core::raster::{Image, ImageFrame, Pixel, SRGBA8};

use crate::Node;

async fn image_to_image(image_frame: ImageFrame<SRGBA8>, prompt: String) -> reqwest::Result<ImageFrame<SRGBA8>> {
let png_bytes = image_frame.image.to_png();
//let base64 = base64::encode(png_bytes);
// post to cloudflare image to image endpoint using reqwest
let payload = PayloadBuilder::new()
.guidance(7.5)
.image(png_bytes.to_vec())
//.mask(png_bytes.to_vec())
.num_steps(20)
.prompt(prompt)
.strength(1);

let client = Client::new();
let account_id = "xxx";
let api_key = "123";
let request = client
//.post(format!("https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/bytedance/stable-diffusion-xl-base-1.0"))
//.post(format!("https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/stabilityai/stable-diffusion-xl-base-1.0"))
/*.post(format!(
"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/runwayml/stable-diffusion-v1-5-inpainting"
))*/
.post(format!("https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/runwayml/stable-diffusion-v1-5-img2img"))
.json(&payload)
.header("Authorization", format!("Bearer {api_key}"));
//println!("{}", serde_json::to_string(&payload).unwrap());
let response = dbg!(request).send().await?;

#[derive(Debug, serde::Deserialize)]
struct Response {
result: String,
success: bool,
};

match response.error_for_status_ref() {
Ok(_) => (),
Err(_) => panic!("{}", response.text().await?),
}
//let text: Response = response.json().await?;
/*let text = response.text().await?;
let text = Response {
result: serde_json::Value::String(text),
success: false,
};
dbg!(&text);*/

let bytes = response.bytes().await?;
//let bytes = &[];

let image = image::load_from_memory_with_format(&bytes[..], image::ImageFormat::Png).unwrap();
let width = image.width();
let height = image.height();
let image = image.to_rgba8();
let data = image.as_raw();
let color_data = bytemuck::cast_slice(data).to_owned();
let image = Image {
width,
height,
data: color_data,
base64_string: None,
};

let image_frame = ImageFrame { image, ..image_frame };
Ok(image_frame)
}
use reqwest::Client;
use serde::Serialize;

#[derive(Default, Serialize)]
struct PayloadBuilder {
#[serde(skip_serializing_if = "Option::is_none")]
guidance: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
image: Option<Vec<u8>>,
#[serde(skip_serializing_if = "Option::is_none")]
mask: Option<Vec<u8>>,
#[serde(skip_serializing_if = "Option::is_none")]
num_steps: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
strength: Option<u32>,
}

impl PayloadBuilder {
fn new() -> Self {
Self::default()
}

fn guidance(mut self, value: f64) -> Self {
self.guidance = Some(value);
self
}

fn image(mut self, value: Vec<u8>) -> Self {
self.image = Some(value);
self
}

fn mask(mut self, value: Vec<u8>) -> Self {
self.mask = Some(value);
self
}

fn num_steps(mut self, value: u32) -> Self {
self.num_steps = Some(value);
self
}

fn prompt(mut self, value: String) -> Self {
self.prompt = Some(value);
self
}

fn strength(mut self, value: u32) -> Self {
self.strength = Some(value);
self
}
}

#[cfg(test)]
mod test {
use super::*;
use graphene_core::{raster::Image, Color};
#[tokio::test]
async fn test_cloudflare() {
let test_image = ImageFrame {
image: Image::new(1024, 1024, SRGBA8::from(Color::RED)),
..Default::default()
};
let result = image_to_image(test_image, "make green".into()).await;
dbg!(result.unwrap());
panic!("show result");
}
}
3 changes: 3 additions & 0 deletions node-graph/gstd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pub mod raster;

pub mod http;

#[cfg(feature = "serde")]
pub mod imaginate_v2;

pub mod any;

#[cfg(feature = "gpu")]
Expand Down

0 comments on commit d1936e6

Please sign in to comment.