From b6c4be58c43ef914db8eb47975bd2626781b1840 Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Thu, 1 Apr 2021 15:13:06 +0200 Subject: [PATCH] add copy_both_simple method Signed-off-by: Petros Angelatos --- postgres-protocol/src/message/backend.rs | 33 ++ tokio-postgres/src/client.rs | 15 +- tokio-postgres/src/connection.rs | 20 ++ tokio-postgres/src/copy_both.rs | 390 +++++++++++++++++++++++ tokio-postgres/src/lib.rs | 1 + tokio-postgres/tests/test/copy_both.rs | 89 ++++++ tokio-postgres/tests/test/main.rs | 1 + 7 files changed, 547 insertions(+), 2 deletions(-) create mode 100644 tokio-postgres/src/copy_both.rs create mode 100644 tokio-postgres/tests/test/copy_both.rs diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 45e5c4074..22afd5ac4 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -93,6 +94,7 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +192,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -524,6 +536,27 @@ impl CopyOutResponseBody { } } +pub struct CopyBothResponseBody { + format: u8, + len: u16, + storage: Bytes, +} + +impl CopyBothResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + pub struct DataRowBody { storage: Bytes, len: u16, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 2a865d68d..e89dbbeb1 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,7 @@ use crate::codec::BackendMessages; use crate::config::{Host, SslMode}; use crate::connection::{Request, RequestMessages}; +use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; @@ -11,8 +12,9 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, - Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, + copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, + CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, + TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -461,6 +463,15 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy + /// data. + pub async fn copy_both_simple(&self, query: &str) -> Result, Error> + where + T: Buf + 'static + Send, + { + copy_both::copy_both_simple(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index b6805f76c..e98056dcc 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,4 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -21,6 +22,7 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), + CopyBoth(CopyBothReceiver), } pub struct Request { @@ -259,6 +261,24 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } + RequestMessages::CopyBoth(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { + trace!("poll_write: finished copy_both request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_both stream"); + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + } } } } diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..903b3ed8e --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,390 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{simple_query, Error}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::channel::mpsc; +use futures::stream::FusedStream; +use futures::{ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// The state machine of CopyBothReceiver +/// +/// ```ignore +/// CopyBoth +/// / \ | +/// v v v +/// CopyOut CopyIn CopyError +/// \ / +/// v v +/// CopyNone +/// | +/// v +/// CopyComplete +/// | +/// v +/// CommandComplete +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CopyBothState { + /// Initial state where CopyData messages can go in both directions + CopyBoth, + /// The server->client stream is closed and we're in CopyIn mode + CopyIn, + /// The client->server stream is closed and we're in CopyOut mode + CopyOut, + /// Both directions are closed, we waiting for CommandComplete messages + CopyNone, + /// We have received the first CommandComplete message for the copy + CopyComplete, + /// We have received the final CommandComplete message for the statement + CommandComplete, + /// An error message from the server was received + CopyError, +} + +/// A CopyBothReceiver is responsible for handling the CopyBoth subprotocol. It ensures that no +/// matter what the users do with their CopyBothDuplex handle we're always going to send the +/// correct messages to the backend in order to restore the connection into a usable state. +/// +/// ```ignore +/// | +/// | +/// | +/// pg -> Connection -> CopyBothReceiver ---+---> CopyBothDuplex +/// | / \ +/// | v v +/// | Sink Stream +/// ``` +pub struct CopyBothReceiver { + /// Receiver of backend messages from the underlying [Connection](crate::Connection) + message_receiver: mpsc::Receiver>, + /// Receiver of frontend messages sent by the user using + sink_receiver: mpsc::Receiver, + /// Sender of CopyData contents to be consumed by the user using + stream_sender: mpsc::Sender>, + /// The current state of the subprotocol + state: CopyBothState, + /// Holds a buffered message until we are ready to send it to the user's stream + buffered_message: Option>, +} + +impl CopyBothReceiver { + pub(crate) fn new( + message_receiver: mpsc::Receiver>, + sink_receiver: mpsc::Receiver, + stream_sender: mpsc::Sender>, + ) -> CopyBothReceiver { + CopyBothReceiver { + message_receiver, + sink_receiver, + stream_sender, + state: CopyBothState::CopyBoth, + buffered_message: None, + } + } + + /// Convenience method to set the subprotocol into an unexpected message state + fn unexpected_message(&mut self) { + self.sink_receiver.close(); + self.buffered_message = Some(Err(Error::unexpected_message())); + self.state = CopyBothState::CopyError; + } + + /// Attempts to send a buffered message to the user. If the user has dropped their handle then + /// the message is simply dropped. + fn poll_stream(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if self.buffered_message.is_some() { + // If the receiver is gone we'll just drop the message + let _ = ready!(self.stream_sender.poll_ready(cx)); + let message = self.buffered_message.take().unwrap(); + let _ = self.stream_sender.start_send(message); + } + Poll::Ready(()) + } + + /// Attempts to receive a frontend message from the user. If the user has dropped their handle + /// then it simply returns None. + fn poll_sink(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.sink_receiver.is_terminated() { + match ready!(self.sink_receiver.poll_next_unpin(cx)) { + None => { + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + return Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))); + } + msg => return Poll::Ready(msg), + } + } + Poll::Ready(None) + } + + /// Processes messages from the backend and optionally produces a fronend message to send back + fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll> { + use CopyBothState::*; + + // Deliver the buffered message (if any) to the user to ensure we can potentially buffer a + // new one in response to a server message + ready!(self.poll_stream(cx)); + + match ready!(self.message_receiver.poll_next_unpin(cx)) { + Some(Ok(Message::CopyData(body))) => match self.state { + CopyBoth | CopyOut => { + self.buffered_message = Some(Ok(body.into_bytes())); + } + _ => self.unexpected_message(), + }, + // The server->client stream is done + Some(Ok(Message::CopyDone)) => { + match self.state { + CopyBoth => self.state = CopyIn, + CopyOut => self.state = CopyNone, + _ => self.unexpected_message(), + }; + } + // The server indicated an error, terminate our sides if we haven't already + Some(Ok(Message::ErrorResponse(error))) => { + match self.state { + CopyBoth | CopyOut | CopyIn => { + self.sink_receiver.close(); + self.buffered_message = Some(Err(Error::db(error))); + self.state = CopyError; + + let mut buf = BytesMut::new(); + frontend::sync(&mut buf); + return Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))); + } + _ => self.unexpected_message(), + }; + } + Some(Ok(Message::CommandComplete(_))) => { + match self.state { + CopyNone => self.state = CopyComplete, + CopyComplete => { + self.stream_sender.close_channel(); + self.sink_receiver.close(); + self.state = CommandComplete; + } + _ => self.unexpected_message(), + }; + } + Some(Err(err)) => { + self.buffered_message = Some(Err(err)); + self.state = CommandComplete; + } + Some(Ok(_)) => self.unexpected_message(), + None => {} + } + // Nothing to send + Poll::Ready(None) + } +} + +/// The [Connection](crate::Connection) will keep polling this stream until it is exhausted. This +/// is the mechanism that drives the CopyBoth subprotocol forward +impl Stream for CopyBothReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use CopyBothState::*; + + loop { + // If the server never sends the initial CopyBothResponse then the connection will + // never enter CopyBoth mode. In this case we return None immediately to get the + // protocol back into query processing mode + if self.message_receiver.is_terminated() { + return Poll::Ready(None); + } + + match self.state { + CopyBoth | CopyIn => { + match self.poll_sink(cx) { + Poll::Ready(Some(msg)) => return Poll::Ready(Some(msg)), + Poll::Ready(None) => { + self.state = match self.state { + CopyBoth => CopyOut, + CopyIn => CopyNone, + _ => unreachable!(), + } + } + // We don't want to wait if the user hasn't sent anything + Poll::Pending => {} + } + } + CopyError => { + // Send the stored error to the user before tearing down our send side + ready!(self.poll_stream(cx)); + self.stream_sender.close_channel(); + self.state = CommandComplete; + } + CopyOut | CopyNone | CopyComplete => {} + CommandComplete => return Poll::Ready(None), + } + + if let Some(msg) = ready!(self.poll_backend(cx)) { + return Poll::Ready(Some(msg)); + } + } + } +} + +pin_project! { + /// A duplex stream for consuming streaming replication data. + /// + /// Users should ensure that CopyBothDuplex is dropped before attempting to await on a new + /// query. This will ensure that the connection returns into normal processing mode. + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; + /// + /// // ⚠️ INCORRECT ⚠️ + /// client.query("SELECT 1", &[]).await; // hangs forever + /// + /// // duplex_stream drop-ed here + /// } + /// ``` + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; + /// + /// // ✅ CORRECT ✅ + /// drop(duplex_stream); + /// + /// client.query("SELECT 1", &[]).await; + /// } + /// ``` + pub struct CopyBothDuplex { + #[pin] + sink_sender: mpsc::Sender, + #[pin] + stream_receiver: mpsc::Receiver>, + buf: BytesMut, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl Stream for CopyBothDuplex { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream_receiver.poll_next(cx) + } +} + +impl Sink for CopyBothDuplex +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sink_sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sink_sender + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sink_sender + .as_mut() + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed())?; + } + + this.sink_sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + this.sink_sender.disconnect(); + Poll::Ready(Ok(())) + } +} + +pub async fn copy_both_simple( + client: &InnerClient, + query: &str, +) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy both query {}", query); + + let buf = simple_query::encode(client, query)?; + + let (message_sender, message_receiver) = mpsc::channel(0); + let (stream_sender, stream_receiver) = mpsc::channel(0); + let (mut sink_sender, sink_receiver) = mpsc::channel(0); + + let receiver = CopyBothReceiver::new(message_receiver, sink_receiver, stream_sender); + let mut responses = client.send(RequestMessages::CopyBoth(receiver))?; + + sink_sender + .send(FrontendMessage::Raw(buf)) + .await + .map_err(|_| Error::closed())?; + + match responses.next().await? { + Message::CopyBothResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + // We've entered CopyBoth mode, all backend messages should now go to the sub-protocol handler + // to ensure that we always send the correct messages regardless of what the user does with + // their CopyBothDuplex handle. + tokio::spawn(responses.map(Ok).forward(message_sender)); + + Ok(CopyBothDuplex { + stream_receiver, + sink_sender, + buf: BytesMut::new(), + _p: PhantomPinned, + _p2: PhantomData, + }) +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 6dd0b0151..3479d66c2 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -156,6 +156,7 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; diff --git a/tokio-postgres/tests/test/copy_both.rs b/tokio-postgres/tests/test/copy_both.rs new file mode 100644 index 000000000..3824e66d6 --- /dev/null +++ b/tokio-postgres/tests/test/copy_both.rs @@ -0,0 +1,89 @@ +use futures::{future, StreamExt, TryStreamExt}; +use crate::{error::SqlState, Client, NoTls, SimpleQueryMessage, SimpleQueryRow}; + +async fn q(client: &Client, query: &str) -> Vec { + let msgs = client.simple_query(query).await.unwrap(); + + msgs.into_iter() + .filter_map(|msg| match msg { + SimpleQueryMessage::Row(row) => Some(row), + _ => None, + }) + .collect() +} + +#[tokio::test] +async fn copy_both_error() { + let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database"; + let (client, connection) = crate::connect(conninfo, NoTls).await.unwrap(); + tokio::spawn(connection); + + let err = client + .copy_both_simple::("START_REPLICATION SLOT undefined LOGICAL 0000/0000") + .await + .err() + .unwrap(); + + assert_eq!(err.code(), Some(&SqlState::UNDEFINED_OBJECT)); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} + +#[tokio::test] +async fn copy_both() { + let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database"; + let (client, connection) = crate::connect(conninfo, NoTls).await.unwrap(); + tokio::spawn(connection); + + q(&client, "DROP TABLE IF EXISTS replication").await; + q(&client, "CREATE TABLE replication (i text)").await; + + let slot_query = "CREATE_REPLICATION_SLOT slot TEMPORARY LOGICAL \"test_decoding\""; + let lsn = q(&client, slot_query).await[0] + .get("consistent_point") + .unwrap() + .to_owned(); + + // We will attempt to read this from the other end + q(&client, "BEGIN").await; + let xid = q(&client, "SELECT txid_current()").await[0] + .get("txid_current") + .unwrap() + .to_owned(); + q(&client, "INSERT INTO replication VALUES ('processed')").await; + q(&client, "COMMIT").await; + + // Insert a second row to generate unprocessed messages in the stream + q(&client, "INSERT INTO replication VALUES ('ignored')").await; + + let query = format!("START_REPLICATION SLOT slot LOGICAL {}", lsn); + let duplex_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let expected = vec![ + format!("BEGIN {}", xid), + "table public.replication: INSERT: i[text]:'processed'".to_string(), + format!("COMMIT {}", xid), + ]; + + let actual: Vec<_> = duplex_stream + // Process only XLogData messages + .try_filter(|buf| future::ready(buf[0] == b'w')) + // Playback the stream until the first expected message + .try_skip_while(|buf| future::ready(Ok(!buf.ends_with(expected[0].as_ref())))) + // Take only the expected number of messsage, the rest will be discarded by tokio_postgres + .take(expected.len()) + .try_collect() + .await + .unwrap(); + + for (msg, ending) in actual.into_iter().zip(expected.into_iter()) { + assert!(msg.ends_with(ending.as_ref())); + } + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index c0b4bf202..76e129de6 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -17,6 +17,7 @@ use tokio_postgres::{ }; mod binary_copy; +mod copy_both; mod parse; #[cfg(feature = "runtime")] mod runtime;