Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work with pools that don't support prepared statements #1147

Merged
merged 6 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,54 @@ impl Client {
query::query(&self.inner, statement, params).await
}

/// Like `query`, but requires the types of query parameters to be explicitly specified.
///
/// Compared to `query`, this method allows performing queries without three round trips (for
/// prepare, execute, and close) by requiring the caller to specify parameter values along with
/// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
/// supported (such as Cloudflare Workers with Hyperdrive).
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the
/// parameter of the list provided, 1-indexed.
///
/// # Examples
///
/// ```no_run
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
/// use tokio_postgres::types::ToSql;
/// use tokio_postgres::types::Type;
/// use futures_util::{pin_mut, TryStreamExt};
///
/// let rows = client.query_typed(
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
/// ).await?;
///
/// for row in rows {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// # Ok(())
/// # }
/// ```
pub async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
fn slice_iter<'a>(
s: &'a [(&'a (dyn ToSql + Sync), Type)],
) -> impl ExactSizeIterator<Item = (&'a dyn ToSql, Type)> + 'a {
s.iter()
.map(|(param, param_type)| (*param as _, param_type.clone()))
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to factor this into a separate function here since it's only being called one place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Earlier thought was to allow access to the raw RowStream.


query::query_typed(&self.inner, statement, slice_iter(params))
.await?
.try_collect()
.await
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
Expand Down
23 changes: 23 additions & 0 deletions tokio-postgres/src/generic_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ pub trait GenericClient: private::Sealed {
I: IntoIterator<Item = P> + Sync + Send,
I::IntoIter: ExactSizeIterator;

/// Like [`Client::query_typed`]
async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error>;

/// Like [`Client::prepare`].
async fn prepare(&self, query: &str) -> Result<Statement, Error>;

Expand Down Expand Up @@ -139,6 +146,14 @@ impl GenericClient for Client {
self.query_raw(statement, params).await
}

async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_typed(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down Expand Up @@ -229,6 +244,14 @@ impl GenericClient for Transaction<'_> {
self.query_raw(statement, params).await
}

async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_typed(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
})
}

async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
pub(crate) async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
Expand Down
95 changes: 92 additions & 3 deletions tokio-postgres/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::prepare::get_type;
use crate::types::{BorrowToSql, IsNull};
use crate::{Error, Portal, Row, Statement};
use crate::{Column, Error, Portal, Row, Statement};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
use postgres_protocol::message::frontend;
use postgres_types::Type;
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
Expand Down Expand Up @@ -57,6 +61,71 @@ where
})
}

pub async fn query_typed<'a, P, I>(
client: &Arc<InnerClient>,
query: &str,
params: I,
) -> Result<RowStream, Error>
where
P: BorrowToSql,
I: IntoIterator<Item = (P, Type)>,
I::IntoIter: ExactSizeIterator,
{
let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip();

let params = params.into_iter();

let param_oids = param_types.iter().map(|t| t.oid()).collect::<Vec<_>>();

let params = params.into_iter();

let buf = client.with_buf(|buf| {
frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;

encode_bind_with_statement_name_and_param_types("", &param_types, params, "", buf)?;

frontend::describe(b'S', "", buf).map_err(Error::encode)?;

frontend::execute("", 0, buf).map_err(Error::encode)?;

frontend::sync(buf);

Ok(buf.split().freeze())
})?;

let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

loop {
match responses.next().await? {
Message::ParseComplete
| Message::BindComplete
| Message::ParameterDescription(_)
| Message::NoData => {}
Message::RowDescription(row_description) => {
let mut columns: Vec<Column> = vec![];
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, field.type_oid()).await?;
let column = Column {
name: field.name().to_string(),
table_oid: Some(field.table_oid()).filter(|n| *n != 0),
column_id: Some(field.column_id()).filter(|n| *n != 0),
r#type: type_,
};
columns.push(column);
}
return Ok(RowStream {
statement: Statement::unnamed(vec![], columns),
responses,
rows_affected: None,
_p: PhantomPinned,
});
}
_ => return Err(Error::unexpected_message()),
}
}
}

pub async fn query_portal(
client: &InnerClient,
portal: &Portal,
Expand Down Expand Up @@ -164,7 +233,27 @@ where
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let param_types = statement.params();
encode_bind_with_statement_name_and_param_types(
statement.name(),
statement.params(),
params,
portal,
buf,
)
}

fn encode_bind_with_statement_name_and_param_types<P, I>(
statement_name: &str,
param_types: &[Type],
params: I,
portal: &str,
buf: &mut BytesMut,
) -> Result<(), Error>
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let params = params.into_iter();

if param_types.len() != params.len() {
Expand All @@ -181,7 +270,7 @@ where
let mut error_idx = 0;
let r = frontend::bind(
portal,
statement.name(),
statement_name,
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
Expand Down
13 changes: 13 additions & 0 deletions tokio-postgres/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ struct StatementInner {

impl Drop for StatementInner {
fn drop(&mut self) {
if self.name.is_empty() {
// Unnamed statements don't need to be closed
return;
}
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', &self.name, buf).unwrap();
Expand Down Expand Up @@ -46,6 +50,15 @@ impl Statement {
}))
}

pub(crate) fn unnamed(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: String::new(),
params,
columns,
}))
}

pub(crate) fn name(&self) -> &str {
&self.0.name
}
Expand Down
9 changes: 9 additions & 0 deletions tokio-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,15 @@ impl<'a> Transaction<'a> {
query::query_portal(self.client.inner(), portal, max_rows).await
}

/// Like `Client::query_typed`.
pub async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.client.query_typed(statement, params).await
}

/// Like `Client::copy_in`.
pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
where
Expand Down
106 changes: 106 additions & 0 deletions tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -952,3 +952,109 @@ async fn deferred_constraint() {
.await
.unwrap_err();
}

#[tokio::test]
async fn query_typed_no_transaction() {
let client = connect("user=postgres").await;

client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (
name TEXT,
age INT
);
INSERT INTO foo (name, age) VALUES ('alice', 20), ('bob', 30), ('carol', 40);
",
)
.await
.unwrap();

let rows: Vec<tokio_postgres::Row> = client
.query_typed(
"SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
&[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
)
.await
.unwrap();

assert_eq!(rows.len(), 2);
let first_row = &rows[0];
assert_eq!(first_row.get::<_, &str>(0), "bob");
assert_eq!(first_row.get::<_, i32>(1), 30);
assert_eq!(first_row.get::<_, &str>(2), "literal");
assert_eq!(first_row.get::<_, i32>(3), 5);

let second_row = &rows[1];
assert_eq!(second_row.get::<_, &str>(0), "carol");
assert_eq!(second_row.get::<_, i32>(1), 40);
assert_eq!(second_row.get::<_, &str>(2), "literal");
assert_eq!(second_row.get::<_, i32>(3), 5);
}

#[tokio::test]
async fn query_typed_with_transaction() {
let mut client = connect("user=postgres").await;

client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (
name TEXT,
age INT
);
",
)
.await
.unwrap();

let transaction = client.transaction().await.unwrap();

let rows: Vec<tokio_postgres::Row> = transaction
.query_typed(
"INSERT INTO foo (name, age) VALUES ($1, $2), ($3, $4), ($5, $6) returning name, age",
&[
(&"alice", Type::TEXT),
(&20i32, Type::INT4),
(&"bob", Type::TEXT),
(&30i32, Type::INT4),
(&"carol", Type::TEXT),
(&40i32, Type::INT4),
],
)
.await
.unwrap();
let inserted_values: Vec<(String, i32)> = rows
.iter()
.map(|row| (row.get::<_, String>(0), row.get::<_, i32>(1)))
.collect();
assert_eq!(
inserted_values,
[
("alice".to_string(), 20),
("bob".to_string(), 30),
("carol".to_string(), 40)
]
);

let rows: Vec<tokio_postgres::Row> = transaction
.query_typed(
"SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
&[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
)
.await
.unwrap();

assert_eq!(rows.len(), 2);
let first_row = &rows[0];
assert_eq!(first_row.get::<_, &str>(0), "bob");
assert_eq!(first_row.get::<_, i32>(1), 30);
assert_eq!(first_row.get::<_, &str>(2), "literal");
assert_eq!(first_row.get::<_, i32>(3), 5);

let second_row = &rows[1];
assert_eq!(second_row.get::<_, &str>(0), "carol");
assert_eq!(second_row.get::<_, i32>(1), 40);
assert_eq!(second_row.get::<_, &str>(2), "literal");
assert_eq!(second_row.get::<_, i32>(3), 5);
}
Loading