Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose Sender::extract_v2 for bindings #382

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 24 additions & 39 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,7 @@ impl App {
.extract_v2_req()
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
println!("Got a request from the sender. Responding with a Payjoin proposal.");
let http = http_agent()?;
let res = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
let res = post_request(req).await?;
payjoin_proposal
.process_res(res.bytes().await?.to_vec(), ohttp_ctx)
.map_err(|e| anyhow!("Failed to deserialize response {}", e))?;
Expand Down Expand Up @@ -197,31 +190,17 @@ impl App {
}

async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result<Psbt> {
let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?;
println!("Posting Original PSBT Payload request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
println!("Sent fallback transaction");
match ctx {
payjoin::send::Context::V2(ctx) => {
match req_ctx.extract_v2(self.config.ohttp_relay.clone()) {
Ok((req, ctx)) => {
println!("Posting Original PSBT Payload request...");
let response = post_request(req).await?;
println!("Sent fallback transaction");
let v2_ctx = Arc::new(
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
);
loop {
let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
let response = post_request(req).await?;
match v2_ctx.process_response(
&mut response.bytes().await?.to_vec().as_slice(),
ohttp_ctx,
Expand All @@ -239,8 +218,12 @@ impl App {
}
}
}
payjoin::send::Context::V1(ctx) => {
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Err(_) => {
let (req, v1_ctx) = req_ctx.extract_v1()?;
println!("Posting Original PSBT Payload request...");
let response = post_request(req).await?;
println!("Sent fallback transaction");
match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(psbt) => Ok(psbt),
Err(re) => {
println!("{}", re);
Expand All @@ -259,15 +242,7 @@ impl App {
loop {
let (req, context) = session.extract_req()?;
println!("Polling receive request...");
let http = http_agent()?;
let ohttp_response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;

let ohttp_response = post_request(req).await?;
let proposal = session
.process_res(ohttp_response.bytes().await?.to_vec().as_slice(), context)
.map_err(|_| anyhow!("GET fallback failed"))?;
Expand Down Expand Up @@ -407,6 +382,16 @@ async fn handle_interrupt(tx: watch::Sender<()>) {
let _ = tx.send(());
}

async fn post_request(req: payjoin::Request) -> Result<reqwest::Response> {
let http = http_agent()?;
http.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)
}

fn map_reqwest_err(e: reqwest::Error) -> anyhow::Error {
match e.status() {
Some(status_code) => anyhow!("HTTP request failed: {} {}", status_code, e),
Expand Down
48 changes: 9 additions & 39 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,46 +268,22 @@ impl Sender {
))
}

/// Extract serialized Request and Context from a Payjoin Proposal. Automatically selects the correct version.
///
/// In order to support polling, this may need to be called many times to be encrypted with
/// new unique nonces to make independent OHTTP requests.
/// Extract serialized Request and Context from a Payjoin Proposal.
///
/// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver
/// This method requires the `rs` pubkey to be extracted from the endpoint
/// and has no fallback to v1.
#[cfg(feature = "v2")]
pub fn extract_highest_version(
&mut self,
pub fn extract_v2(
&self,
ohttp_relay: Url,
) -> Result<(Request, Context), CreateRequestError> {
) -> Result<(Request, V2PostContext), CreateRequestError> {
use crate::uri::UrlExt;

if let Some(expiry) = self.endpoint.exp() {
if std::time::SystemTime::now() > expiry {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}

match self.extract_rs_pubkey() {
Ok(rs) => self.extract_v2(ohttp_relay, rs),
Err(e) => {
log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e);
let (req, context_v1) = self.extract_v1()?;
Ok((req, Context::V1(context_v1)))
}
}
}

/// Extract serialized Request and Context from a Payjoin Proposal.
///
/// This method requires the `rs` pubkey to be extracted from the endpoint
/// and has no fallback to v1.
#[cfg(feature = "v2")]
fn extract_v2(
&mut self,
ohttp_relay: Url,
rs: HpkePublicKey,
) -> Result<(Request, Context), CreateRequestError> {
use crate::uri::UrlExt;
let rs = self.extract_rs_pubkey()?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
&self.psbt,
Expand All @@ -329,7 +305,7 @@ impl Sender {
log::debug!("ohttp_relay_url: {:?}", ohttp_relay);
Ok((
Request::new_v2(ohttp_relay, body),
Context::V2(V2PostContext {
V2PostContext {
endpoint: self.endpoint.clone(),
psbt_ctx: PsbtContext {
original_psbt: self.psbt.clone(),
Expand All @@ -341,7 +317,7 @@ impl Sender {
},
hpke_ctx,
ohttp_ctx,
}),
},
))
}

Expand All @@ -366,12 +342,6 @@ impl Sender {
pub fn endpoint(&self) -> &Url { &self.endpoint }
}

pub enum Context {
V1(V1Context),
#[cfg(feature = "v2")]
V2(V2PostContext),
}

#[derive(Debug, Clone)]
pub struct V1Context {
psbt_context: PsbtContext,
Expand Down
32 changes: 10 additions & 22 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ mod integration {
use bitcoin::Address;
use http::StatusCode;
use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal};
use payjoin::send::Context;
use payjoin::{OhttpKeys, PjUri, UriExt};
use reqwest::{Client, ClientBuilder, Error, Response};
use testcontainers_modules::redis::Redis;
Expand Down Expand Up @@ -285,9 +284,9 @@ mod integration {
Some(std::time::SystemTime::now()),
)
.build();
let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)?
let expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)?
.build_non_incentivizing(FeeRate::BROADCAST_MIN)?;
match expired_req_ctx.extract_highest_version(directory.to_owned()) {
match expired_req_ctx.extract_v2(directory.to_owned()) {
// Internal error types are private, so check against a string
Err(err) => assert!(err.to_string().contains("expired")),
_ => assert!(false, "Expired send session should error"),
Expand Down Expand Up @@ -355,14 +354,10 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_sweep_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (Request { url, body, content_type, .. }, send_ctx) =
req_ctx.extract_highest_version(directory.to_owned())?;
let send_ctx = match send_ctx {
Context::V2(ctx) => ctx,
_ => panic!("V2 context expected"),
};
req_ctx.extract_v2(directory.to_owned())?;
let response = agent
.post(url.clone())
.header("Content-Type", content_type)
Expand Down Expand Up @@ -521,10 +516,10 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_sweep_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (Request { url, body, content_type, .. }, post_ctx) =
req_ctx.extract_highest_version(directory.to_owned())?;
req_ctx.extract_v2(directory.to_owned())?;
let response = agent
.post(url.clone())
.header("Content-Type", content_type)
Expand All @@ -534,11 +529,8 @@ mod integration {
.unwrap();
log::info!("Response: {:#?}", &response);
assert!(response.status().is_success());
let get_ctx = match post_ctx {
Context::V2(ctx) =>
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
_ => panic!("V2 context expected"),
};
let get_ctx =
post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?;
let (Request { url, body, content_type, .. }, ohttp_ctx) =
get_ctx.extract_req(directory.to_owned())?;
let response = agent
Expand Down Expand Up @@ -622,9 +614,9 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_original_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?;
let (req, ctx) = req_ctx.extract_v1()?;
let headers = HeaderMock::new(&req.body, req.content_type);

// **********************
Expand All @@ -636,10 +628,6 @@ mod integration {
// **********************
// Inside the Sender:
// Sender checks, signs, finalizes, extracts, and broadcasts
let ctx = match ctx {
Context::V1(ctx) => ctx,
_ => panic!("V1 context expected"),
};
let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?;
let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?;
sender.send_raw_transaction(&payjoin_tx)?;
Expand Down
Loading