Commit

Add support for tags
playfulkittykat committed Oct 20, 2024
1 parent 660c1b7 commit 93f7f5d
Showing 6 changed files with 539 additions and 25 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -34,6 +34,8 @@ itertools = "0.10"
futures = { version = "0.3", default-features = false }
reqwest = { version = "0.11", default-features = false, features = ["json"] }
tokio = { optional = true, version = "1" }
serde_with = "3.11.0"
serde_repr = "0.1.19"

[dev-dependencies]
mockito = "0.30"
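
The two new dependencies support the code added below: serde_with provides the SerializeDisplay and DeserializeFromStr derives used by the new Cursor type in src/client.rs, while serde_repr is presumably pulled in for enums that (de)serialize as integers in the new src/tag.rs module.
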
27 changes: 27 additions & 0 deletions examples/tags_search.rs
@@ -0,0 +1,27 @@
use futures::prelude::*;
use rs621::{
client::Client,
tag::{Order, Query},
};

#[tokio::main]
async fn main() -> rs621::error::Result<()> {
let client = Client::new("https://e926.net", "MyProject/1.0 (by username on e621)")?;

println!("Top ten tags by post count!");

let result_stream = client
.tag_search(Query::new().limit(1).order(Order::Count))
.take(10);

futures::pin_mut!(result_stream);

while let Some(tag) = result_stream.next().await {
match tag {
Ok(tag) => println!("- {} with a score of {}", tag.name, tag.post_count),
Err(e) => println!("- couldn't load tag: {}", e),
}
}

Ok(())
}
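
Assuming the standard Cargo examples layout, the new example should be runnable from the repository root with:

cargo run --example tags_search
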
76 changes: 76 additions & 0 deletions src/client.rs
@@ -5,8 +5,11 @@ mod rate_limit;
#[path = "client/dummy_rate_limit.rs"]
mod rate_limit;

use std::{fmt, num::ParseIntError, str::FromStr};

use futures::Future;
use reqwest::Url;
use serde_with::{DeserializeFromStr, SerializeDisplay};

use {
super::error::{Error, Result},
@@ -55,6 +58,43 @@ pub(crate) type QueryFuture = Box<dyn Future<Output = Result<serde_json::Value>>
#[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
pub(crate) type QueryFuture = Box<dyn Future<Output = Result<serde_json::Value>>>;

/// Where to begin returning results from in paginated requests.
#[derive(Debug, PartialEq, Eq, Clone, Copy, SerializeDisplay, DeserializeFromStr)]
pub enum Cursor {
/// Begin at the given page. Actual offset depends on page size.
Page(u64),

/// Return page size items ordered before the given id.
Before(u64),

/// Return page size items ordered after the given id.
After(u64),
}

impl FromStr for Cursor {
type Err = ParseIntError;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let result = match s.chars().next() {
Some('a') => Self::After(s[1..].parse()?),
Some('b') => Self::Before(s[1..].parse()?),
None | Some(_) => Self::Page(s.parse()?),
};

Ok(result)
}
}

impl fmt::Display for Cursor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Page(p) => write!(f, "{}", p),
Self::Before(p) => write!(f, "b{}", p),
Self::After(p) => write!(f, "a{}", p),
}
}
}
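
Because the SerializeDisplay and DeserializeFromStr derives route serialization through the Display and FromStr impls above, a Cursor is (de)serialized in the same compact form used for the page query parameter ("3", "b123", "a123"). A minimal sanity-check sketch, assuming rs621::client::Cursor and serde_json are available:

use rs621::client::Cursor;

fn main() {
    // A plain number parses as a page offset; "b"/"a" prefixes select before/after cursors.
    assert_eq!("7".parse::<Cursor>(), Ok(Cursor::Page(7)));
    assert_eq!("b123".parse::<Cursor>(), Ok(Cursor::Before(123)));
    assert_eq!(Cursor::After(42).to_string(), "a42");

    // serde_with serializes through Display, so the JSON form is just that string.
    assert_eq!(serde_json::to_string(&Cursor::Before(123)).unwrap(), "\"b123\"");
}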

/// Client struct.
#[derive(Debug)]
pub struct Client {
@@ -179,6 +219,42 @@ impl Client {
.await
}

pub(crate) async fn get_json_endpoint_query<T, R>(&self, endpoint: &str, query: &T) -> Result<R>
where
T: serde::Serialize,
R: serde::de::DeserializeOwned,
{
let url = self.url(endpoint)?;
let future = self
.client
.get(url.clone())
.query(query)
.headers(self.headers.clone())
.send();

let res = self
.rate_limit
.clone()
.check(future)
.await
.map_err(|x| Error::CannotSendRequest(x.to_string()))?;

if res.status().is_success() {
res.json()
.await
.map_err(|e| Error::Serial(format!("{}", e)))
} else {
Err(Error::Http {
url,
code: res.status().as_u16(),
reason: match res.json::<serde_json::Value>().await {
Ok(v) => v["reason"].as_str().map(ToString::to_string),
Err(_) => None,
},
})
}
}
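
Unlike get_json_endpoint below, which takes a fully formed endpoint string, this helper serializes a typed query into the URL's query string via reqwest's .query() and deserializes the response into any DeserializeOwned type. A hypothetical call site, where the endpoint path and query struct are purely illustrative (not taken from the actual tag code):

#[derive(serde::Serialize)]
struct ExampleQuery {
    limit: u64,
    page: Cursor,
}

// From inside the crate, given `client: &Client`:
// let tags: serde_json::Value = client
//     .get_json_endpoint_query("/tags.json", &ExampleQuery { limit: 10, page: Cursor::Page(1) })
//     .await?;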

pub fn get_json_endpoint(
&self,
endpoint: &str,
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -101,3 +101,6 @@ pub mod post;

/// Pool management.
pub mod pool;

/// Tag management.
pub mod tag;
41 changes: 16 additions & 25 deletions src/post.rs
@@ -2,7 +2,7 @@ use crate::error::Error;

use {
super::{
-        client::{Client, QueryFuture},
+        client::{Client, Cursor, QueryFuture},
error::Result as Rs621Result,
},
chrono::{offset::Utc, DateTime},
@@ -250,13 +250,6 @@
}
}

-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
-pub enum SearchPage {
-    Page(u64),
-    BeforePost(u64),
-    AfterPost(u64),
-}

/// Iterator returning posts from a search query.
#[derive(Derivative)]
#[derivative(Debug)]
@@ -269,13 +262,13 @@ pub struct PostSearchStream<'a> {
#[derivative(Debug = "ignore")]
query_future: Option<Pin<QueryFuture>>,

-    next_page: SearchPage,
+    next_page: Cursor,
chunk: Vec<Rs621Result<Post>>,
ended: bool,
}

impl<'a> PostSearchStream<'a> {
-    fn new<T: Into<Query>>(client: &'a Client, query: T, page: SearchPage) -> Self {
+    fn new<T: Into<Query>>(client: &'a Client, query: T, page: Cursor) -> Self {
PostSearchStream {
client: client,
query: query.into(),
Expand Down Expand Up @@ -332,16 +325,14 @@ impl<'a> Stream for PostSearchStream<'a> {
// we now know what will be the next page
this.next_page = if this.query.ordered {
match this.next_page {
-                    SearchPage::Page(i) => SearchPage::Page(i + 1),
-                    _ => SearchPage::Page(1),
+                    Cursor::Page(i) => Cursor::Page(i + 1),
+                    _ => Cursor::Page(1),
}
} else {
match this.next_page {
-                    SearchPage::Page(_) => SearchPage::BeforePost(last_id),
-                    SearchPage::BeforePost(_) => {
-                        SearchPage::BeforePost(last_id)
-                    }
-                    SearchPage::AfterPost(_) => SearchPage::AfterPost(last_id),
+                    Cursor::Page(_) => Cursor::Before(last_id),
+                    Cursor::Before(_) => Cursor::Before(last_id),
+                    Cursor::After(_) => Cursor::After(last_id),
}
};

Expand Down Expand Up @@ -386,9 +377,9 @@ impl<'a> Stream for PostSearchStream<'a> {
"/posts.json?limit={}&page={}&tags={}",
ITER_CHUNK_SIZE,
match this.next_page {
-                    SearchPage::Page(i) => format!("{}", i),
-                    SearchPage::BeforePost(i) => format!("b{}", i),
-                    SearchPage::AfterPost(i) => format!("a{}", i),
+                    Cursor::Page(i) => format!("{}", i),
+                    Cursor::Before(i) => format!("b{}", i),
+                    Cursor::After(i) => format!("a{}", i),
},
this.query.url_encoded_tags
);
Expand Down Expand Up @@ -569,7 +560,7 @@ impl Client {
/// # Ok(()) }
/// ```
pub fn post_search<'a, T: Into<Query>>(&'a self, tags: T) -> PostSearchStream<'a> {
-        self.post_search_from_page(tags, SearchPage::Page(1))
+        self.post_search_from_page(tags, Cursor::Page(1))
}

/// Returns a Stream over all the posts matching the search query, starting from the given page.
@@ -579,13 +570,13 @@
/// # rs621::{client::Client, post::PostRating},
/// # futures::prelude::*,
/// # };
-    /// use rs621::post::SearchPage;
+    /// use rs621::client::Cursor;
/// # #[tokio::main]
/// # async fn main() -> rs621::error::Result<()> {
/// let client = Client::new("https://e926.net", "MyProject/1.0 (by username on e621)")?;
///
/// let mut post_stream = client
-    ///     .post_search_from_page(&["fluffy", "rating:s"][..], SearchPage::BeforePost(123456))
+    ///     .post_search_from_page(&["fluffy", "rating:s"][..], Cursor::Before(123456))
/// .take(3);
///
/// while let Some(post) = post_stream.next().await {
@@ -598,7 +589,7 @@
pub fn post_search_from_page<'a, T: Into<Query>>(
&'a self,
tags: T,
-        page: SearchPage,
+        page: Cursor,
) -> PostSearchStream<'a> {
PostSearchStream::new(self, tags, page)
}
@@ -856,7 +847,7 @@ mod tests {

assert_eq!(
client
-                .post_search_from_page(query, SearchPage::BeforePost(2269211))
+                .post_search_from_page(query, Cursor::Before(2269211))
.take(80)
.collect::<Vec<_>>()
.await,