diff --git a/src/goose.rs b/src/goose.rs index 638288fd..90c117e1 100644 --- a/src/goose.rs +++ b/src/goose.rs @@ -287,7 +287,8 @@ use downcast_rs::{impl_downcast, Downcast}; use http::method::Method; -use reqwest::{header, Client, ClientBuilder, RequestBuilder, Response}; +use reqwest::cookie::{CookieStore, Jar}; +use reqwest::{cookie, header, Client, ClientBuilder, RequestBuilder, Response}; use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -850,6 +851,8 @@ pub struct GooseUser { pub(crate) task_name: Option, session_data: Option>, + + cookie_store: Jar, } impl GooseUser { /// Create a new user state. @@ -879,6 +882,7 @@ impl GooseUser { slept: 0, task_name: None, session_data: None, + cookie_store: cookie::Jar::default(), }) } @@ -1533,8 +1537,7 @@ impl GooseUser { }; let started = Instant::now(); - let request = request_builder.build()?; - + let mut request = request_builder.build()?; // String version of request path. let path = match Url::parse(&request.url().to_string()) { Ok(u) => u.path().to_string(), @@ -1578,6 +1581,8 @@ impl GooseUser { self.weighted_users_index, ); + self.add_cookie_header(&mut request); + // Make the actual request. let response = self.client.execute(request).await; request_metric.set_response_time(started.elapsed().as_millis()); @@ -1638,6 +1643,14 @@ impl GooseUser { Ok(GooseResponse::new(request_metric, response)) } + fn add_cookie_header(&self, request: &mut reqwest::Request) { + if request.headers().get(header::COOKIE).is_none() { + if let Some(header) = self.cookie_store.cookies(request.url()) { + request.headers_mut().insert(header::COOKIE, header); + } + } + } + /// Tracks the time it takes for the current GooseUser to loop through all GooseTasks /// if Coordinated Omission Mitigation is enabled. pub(crate) async fn update_request_cadence(&mut self, thread_number: usize) { diff --git a/src/lib.rs b/src/lib.rs index 98ba0c25..07e8ca01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -459,7 +459,7 @@ use lazy_static::lazy_static; use nng::Socket; use rand::seq::SliceRandom; use rand::thread_rng; -use reqwest::Client; +use reqwest::{Client, ClientBuilder}; use std::collections::hash_map::DefaultHasher; use std::collections::BTreeMap; use std::hash::{Hash, Hasher}; @@ -822,7 +822,7 @@ impl GooseAttack { ) -> Result { let client = Client::builder() .user_agent(APP_USER_AGENT) - .cookie_store(true) + .cookie_store(false) // Enable gzip unless `--no-gzip` flag is enabled. .gzip(!configuration.no_gzip) .build()?; @@ -855,19 +855,18 @@ impl GooseAttack { /// /// #[tokio::main] /// async fn main() -> Result<(), GooseError> { - /// let client = Client::builder() - /// .build()?; + /// let client_builder = Client::builder(); /// /// GooseAttack::initialize()? /// .set_scheduler(GooseScheduler::Random) - /// .set_client(client); + /// .set_client(client_builder)?; /// /// Ok(()) /// } /// ``` - pub fn set_client(mut self, client: Client) -> Self { - self.client = client; - self + pub fn set_client(mut self, client_builder: ClientBuilder) -> Result { + self.client = client_builder.cookie_store(false).build()?; + Ok(self) } /// Define the order [`GooseTaskSet`](./goose/struct.GooseTaskSet.html)s are