diff --git a/async-nats/src/jetstream/consumer/pull.rs b/async-nats/src/jetstream/consumer/pull.rs index b3711842f..684d77f2e 100644 --- a/async-nats/src/jetstream/consumer/pull.rs +++ b/async-nats/src/jetstream/consumer/pull.rs @@ -1046,22 +1046,20 @@ impl futures::Stream for Stream { } if !self.batch_config.idle_heartbeat.is_zero() { - trace!("setting hearbeats"); - let timeout = self.batch_config.idle_heartbeat.saturating_mul(2); - self.heartbeat_timeout - .get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout))); - trace!("checking idle hearbeats"); - if let Some(hearbeat) = self.heartbeat_timeout.as_mut() { - match hearbeat.poll_unpin(cx) { - Poll::Ready(_) => { - self.heartbeat_timeout = None; - return Poll::Ready(Some(Err(MessagesError::new( - MessagesErrorKind::MissingHeartbeat, - )))); - } - Poll::Pending => (), + let timeout = self.batch_config.idle_heartbeat.saturating_mul(2); + match self + .heartbeat_timeout + .get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout))) + .poll_unpin(cx) + { + Poll::Ready(_) => { + self.heartbeat_timeout = None; + return Poll::Ready(Some(Err(MessagesError::new( + MessagesErrorKind::MissingHeartbeat, + )))); } + Poll::Pending => (), } } diff --git a/async-nats/src/jetstream/consumer/push.rs b/async-nats/src/jetstream/consumer/push.rs index e4278f685..0686ae96d 100644 --- a/async-nats/src/jetstream/consumer/push.rs +++ b/async-nats/src/jetstream/consumer/push.rs @@ -22,15 +22,14 @@ use crate::{ }; use bytes::Bytes; -use futures::future::BoxFuture; +use futures::{future::BoxFuture, FutureExt}; use serde::{Deserialize, Serialize}; #[cfg(feature = "server_2_10")] use std::collections::HashMap; use std::{ io::{self, ErrorKind}, pin::Pin, - sync::{Arc, Mutex}, - time::Instant, + sync::Arc, }; use std::{ sync::atomic::AtomicU64, @@ -41,6 +40,8 @@ use tokio::{sync::oneshot::error::TryRecvError, task::JoinHandle}; use tokio_retry::{strategy::ExponentialBackoff, Retry}; use tracing::{debug, trace}; +const ORDERED_IDLE_HEARTBEAT: Duration = Duration::from_secs(5); + impl Consumer { /// Returns a stream of messages for Push Consumer. /// @@ -105,7 +106,9 @@ impl Consumer { Ok(Messages { context: self.context.clone(), + config: self.config.clone(), subscriber, + heartbeat_sleep: None, }) } } @@ -113,42 +116,62 @@ impl Consumer { pub struct Messages { context: Context, subscriber: Subscriber, + config: Config, + heartbeat_sleep: Option>>, } impl futures::Stream for Messages { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + if !self.config.idle_heartbeat.is_zero() { + let heartbeat_sleep = self.config.idle_heartbeat.saturating_mul(2); + match self + .heartbeat_sleep + .get_or_insert_with(|| Box::pin(tokio::time::sleep(heartbeat_sleep))) + .poll_unpin(cx) + { + Poll::Ready(_) => { + return Poll::Ready(Some(Err(MessagesError::new( + MessagesErrorKind::MissingHeartbeat, + )))) + } + Poll::Pending => (), + } + } loop { match self.subscriber.receiver.poll_recv(cx) { - Poll::Ready(maybe_message) => match maybe_message { - Some(message) => match message.status { - Some(StatusCode::IDLE_HEARTBEAT) => { - if let Some(subject) = message.reply { - // TODO store pending_publish as a future and return errors from it - let client = self.context.client.clone(); - tokio::task::spawn(async move { - client - .publish(subject, Bytes::from_static(b"")) - .await - .unwrap(); - }); - } + Poll::Ready(maybe_message) => { + self.heartbeat_sleep = None; + match maybe_message { + Some(message) => match message.status { + Some(StatusCode::IDLE_HEARTBEAT) => { + if let Some(subject) = message.reply { + // TODO store pending_publish as a future and return errors from it + let client = self.context.client.clone(); + tokio::task::spawn(async move { + client + .publish(subject, Bytes::from_static(b"")) + .await + .unwrap(); + }); + } - continue; - } - Some(_) => { - continue; - } - None => { - return Poll::Ready(Some(Ok(jetstream::Message { - context: self.context.clone(), - message, - }))) - } - }, - None => return Poll::Ready(None), - }, + continue; + } + Some(_) => { + continue; + } + None => { + return Poll::Ready(Some(Ok(jetstream::Message { + context: self.context.clone(), + message, + }))) + } + }, + None => return Poll::Ready(None), + } + } Poll::Pending => return Poll::Pending, } } @@ -431,7 +454,7 @@ impl IntoConsumerConfig for OrderedConfig { max_ack_pending: 0, headers_only: self.headers_only, flow_control: true, - idle_heartbeat: Duration::from_secs(5), + idle_heartbeat: ORDERED_IDLE_HEARTBEAT, max_batch: 0, max_bytes: 0, max_expires: Duration::default(), @@ -454,12 +477,10 @@ impl Consumer { .await .map_err(|err| StreamError::with_source(StreamErrorKind::Other, err))?; - let last_seen = Arc::new(Mutex::new(Instant::now())); let last_sequence = Arc::new(AtomicU64::new(0)); let consumer_sequence = Arc::new(AtomicU64::new(0)); let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); let handle = tokio::task::spawn({ - let last_seen = last_seen.clone(); let stream_name = self.info.stream_name.clone(); let config = self.config.clone(); let mut context = self.context.clone(); @@ -470,31 +491,14 @@ impl Consumer { loop { let current_state = state.borrow().to_owned(); - match tokio::time::timeout( - Duration::from_secs(5), - context.client.state.changed(), - ) - .await + context.client.state.changed().await.unwrap(); + // State change notification received within the timeout + if state.borrow().to_owned() != State::Connected + || current_state == State::Connected { - Ok(_) => { - // State change notification received within the timeout - if state.borrow().to_owned() != State::Connected - || current_state == State::Connected - { - continue; - } - debug!("reconnected. trigger consumer recreation"); - } - Err(_) => { - debug!("heartbeat check"); - - if last_seen.lock().unwrap().elapsed() <= Duration::from_secs(10) { - trace!("last seen ok. wait"); - continue; - } - debug!("last seen not ok"); - } + continue; } + debug!("reconnected. trigger consumer recreation"); debug!( "idle heartbeats expired. recreating consumer s: {}, {:?}", @@ -514,7 +518,6 @@ impl Consumer { shutdown_tx.send(err).unwrap(); break; } - *last_seen.lock().unwrap() = Instant::now(); debug!("resetting consume sequence to 0"); consumer_sequence.store(0, Ordering::Relaxed); } @@ -528,9 +531,9 @@ impl Consumer { subscriber_future: None, stream_sequence: last_sequence, consumer_sequence, - last_seen, shutdown: shutdown_rx, handle, + heartbeat_sleep: None, }) } } @@ -542,9 +545,9 @@ pub struct Ordered<'a> { subscriber_future: Option>>, stream_sequence: Arc, consumer_sequence: Arc, - last_seen: Arc>, shutdown: tokio::sync::oneshot::Receiver, handle: JoinHandle<()>, + heartbeat_sleep: Option>>, } impl<'a> Drop for Ordered<'a> { @@ -558,6 +561,21 @@ impl<'a> futures::Stream for Ordered<'a> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match self + .heartbeat_sleep + .get_or_insert_with(|| { + Box::pin(tokio::time::sleep(ORDERED_IDLE_HEARTBEAT.saturating_mul(2))) + }) + .poll_unpin(cx) + { + Poll::Ready(_) => { + return Poll::Ready(Some(Err(OrderedError::new( + OrderedErrorKind::MissingHeartbeat, + )))) + } + Poll::Pending => (), + } + loop { match self.shutdown.try_recv() { Ok(err) => { @@ -625,7 +643,7 @@ impl<'a> futures::Stream for Ordered<'a> { match subscriber.receiver.poll_recv(cx) { Poll::Ready(maybe_message) => match maybe_message { Some(message) => { - *self.last_seen.lock().unwrap() = Instant::now(); + self.heartbeat_sleep = None; match message.status { Some(StatusCode::IDLE_HEARTBEAT) => { debug!("received idle heartbeats"); diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index 369fe5d11..1289d1457 100644 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -41,7 +41,6 @@ mod jetstream { use async_nats::jetstream::stream::{self, DiscardPolicy, StorageType}; use async_nats::jetstream::AckKind; use async_nats::ConnectOptions; - use bytes::Bytes; use futures::stream::{StreamExt, TryStreamExt}; use time::OffsetDateTime; use tokio_retry::Retry; @@ -1190,7 +1189,7 @@ mod jetstream { let mut iter = consumer.sequence(50).unwrap().take(10); while let Ok(Some(mut batch)) = iter.try_next().await { while let Ok(Some(message)) = batch.try_next().await { - assert_eq!(message.payload, Bytes::from(b"dat".as_ref())); + assert_eq!(message.payload, bytes::Bytes::from(b"dat".as_ref())); } } } @@ -1633,6 +1632,24 @@ mod jetstream { seen += 1; } assert_eq!(seen, 1000); + + let consumer = stream + .create_consumer(consumer::push::Config { + deliver_subject: "delivery".to_string(), + durable_name: Some("delete_me".to_string()), + idle_heartbeat: Duration::from_secs(5), + ..Default::default() + }) + .await + .unwrap(); + + stream.delete_consumer("delete_me").await.unwrap(); + + let mut messages = consumer.messages().await.unwrap(); + assert_eq!( + messages.next().await.unwrap().unwrap_err().kind(), + async_nats::jetstream::consumer::push::MessagesErrorKind::MissingHeartbeat + ); } #[tokio::test]