diff --git a/api/api.go b/api/api.go index 4ce19f4..878f657 100644 --- a/api/api.go +++ b/api/api.go @@ -201,8 +201,15 @@ func (a *API) initRouter() http.Handler { // get subscription info log.Infow("new route", "method", "GET", "path", planInfoEndpoint) r.Get(planInfoEndpoint, a.planInfoHandler) + // handle stripe webhook log.Infow("new route", "method", "POST", "path", subscriptionsWebhook) r.Post(subscriptionsWebhook, a.handleWebhook) + // handle stripe checkout session + log.Infow("new route", "method", "POST", "path", subscriptionsCheckout) + r.Post(subscriptionsCheckout, a.createSubscriptionCheckoutHandler) + // get stripe checkout session info + log.Infow("new route", "method", "GET", "path", subscriptionsCheckoutSession) + r.Get(subscriptionsCheckoutSession, a.checkoutSessionHandler) }) a.router = r return r diff --git a/api/docs.md b/api/docs.md index 4e96787..dcbfcb3 100644 --- a/api/docs.md +++ b/api/docs.md @@ -35,6 +35,9 @@ - [🏦 Plans](#-plans) - [🛒 Get Available Plans](#-get-plans) - [🛍️ Get Plan Info](#-get-plan-info) +- [Stripe](#-stripe) + - [] + - [] diff --git a/api/errors_definition.go b/api/errors_definition.go index 904a229..10747e2 100644 --- a/api/errors_definition.go +++ b/api/errors_definition.go @@ -54,4 +54,5 @@ var ( ErrGenericInternalServerError = Error{Code: 50002, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("internal server error")} ErrCouldNotCreateFaucetPackage = Error{Code: 50003, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("could not create faucet package")} ErrVochainRequestFailed = Error{Code: 50004, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("vochain request failed")} + ErrStripeError = Error{Code: 50005, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("stripe error")} ) diff --git a/api/organizations.go b/api/organizations.go index dddb6c6..db893cf 100644 --- a/api/organizations.go +++ b/api/organizations.go @@ -497,7 +497,7 @@ func (a *API) getOrganizationSubscriptionHandler(w http.ResponseWriter, r *http. return } if !org.Subscription.Active || - (org.Subscription.EndDate.After(time.Now()) && org.Subscription.StartDate.Before(time.Now())) { + org.Subscription.EndDate.Before(time.Now()) || org.Subscription.StartDate.After(time.Now()) { ErrOganizationSubscriptionIncative.Write(w) return } diff --git a/api/routes.go b/api/routes.go index 05ead71..126f9d4 100644 --- a/api/routes.go +++ b/api/routes.go @@ -66,4 +66,8 @@ const ( planInfoEndpoint = "/plans/{planID}" // POST /subscriptions/webhook to receive the subscription webhook from stripe subscriptionsWebhook = "/subscriptions/webhook" + // POST /subscriptions/checkout to create a new subscription + subscriptionsCheckout = "/subscriptions/checkout" + // GET /subscriptions/checkout/{sessionID} to get the checkout session information + subscriptionsCheckoutSession = "/subscriptions/checkout/{sessionID}" ) diff --git a/api/stripe.go b/api/stripe.go index 36b4997..1b8aee0 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -1,10 +1,13 @@ package api import ( + "encoding/json" "io" "net/http" + "strconv" "time" + "github.com/go-chi/chi/v5" "github.com/vocdoni/saas-backend/db" "go.vocdoni.io/dvote/log" ) @@ -84,6 +87,115 @@ func (a *API) handleWebhook(w http.ResponseWriter, r *http.Request) { return } log.Debugf("stripe webhook: subscription %s for organization %s processed successfully", subscription.ID, org.Address) + case "customer.subscription.updated", "customer.subscription.deleted": + customer, subscription, err := a.stripe.GetInfoFromEvent(*event) + if err != nil { + log.Errorf("stripe webhook: error getting info from event: %s\n", err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + address := subscription.Metadata["address"] + if len(address) == 0 { + log.Errorf("subscription %s does not contain an address in metadata", subscription.ID) + w.WriteHeader(http.StatusBadRequest) + return + } + org, _, err := a.db.Organization(address, false) + if err != nil || org == nil { + log.Errorf("could not update subscription %s, a corresponding organization with address %s was not found.", + subscription.ID, address) + log.Errorf("please do manually for creator %s \n Error: %s", customer.Email, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + if subscription.Status == "canceled" && org.Subscription.Active { + // replace organization subscription with the default plan + defaultPlan, err := a.db.DefaultPlan() + if err != nil || defaultPlan == nil { + ErrNoDefaultPLan.WithErr((err)).Write(w) + return + } + orgSubscription := &db.OrganizationSubscription{ + PlanID: defaultPlan.ID, + StartDate: time.Now(), + Active: true, + MaxCensusSize: defaultPlan.Organization.MaxCensus, + } + if err := a.db.SetOrganizationSubscription(org.Address, orgSubscription); err != nil { + log.Errorf("could not cancel subscription %s for organization %s: %s", subscription.ID, org.Address, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + } else if subscription.Status == "active" && !org.Subscription.Active { + org.Subscription.Active = true + if err := a.db.SetOrganization(org); err != nil { + log.Errorf("could activate organizations %s subscription to active: %s", org.Address, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + } + log.Debugf("stripe webhook: subscription %s for organization %s processed as %s successfully", + subscription.ID, org.Address, subscription.Status) } w.WriteHeader(http.StatusOK) } + +func (a *API) createSubscriptionCheckoutHandler(w http.ResponseWriter, r *http.Request) { + checkout := &SubscriptionCheckout{} + if err := json.NewDecoder(r.Body).Decode(checkout); err != nil { + ErrMalformedBody.Write(w) + return + } + + if checkout.LookupKey == "" || checkout.ReturnURL == "" || + checkout.Amount == "" || checkout.Address == "" { + ErrMalformedBody.Withf("Missing required fields").Write(w) + return + } + + lookupKey, err := strconv.ParseUint(checkout.LookupKey, 10, 64) + if err != nil { + ErrMalformedURLParam.Withf("Invalid plan lookup key: %v", err).Write(w) + return + } + + amount, err := strconv.ParseInt(checkout.Amount, 10, 64) + if err != nil { + ErrMalformedURLParam.Withf("Invalid census amount: %v", err).Write(w) + return + } + + plan, err := a.db.Plan(lookupKey) + if err != nil { + ErrMalformedURLParam.Withf("Plan not found: %v", err).Write(w) + return + } + + session, err := a.stripe.CreateSubscriptionCheckoutSession(plan.StripePriceID, checkout.ReturnURL, checkout.Address, amount) + if err != nil { + ErrStripeError.Withf("Cannot create session: %v", err).Write(w) + return + } + + data := &struct { + ClientSecret string `json:"clientSecret"` + }{ + ClientSecret: session.ClientSecret, + } + httpWriteJSON(w, data) +} + +func (a *API) checkoutSessionHandler(w http.ResponseWriter, r *http.Request) { + sessionID := chi.URLParam(r, "sessionID") + if sessionID == "" { + ErrMalformedURLParam.Withf("sessionID is required").Write(w) + return + } + status, err := a.stripe.RetrieveCheckoutSession(sessionID) + if err != nil { + ErrStripeError.Withf("Cannot get session: %v", err).Write(w) + return + } + + httpWriteJSON(w, status) +} diff --git a/api/types.go b/api/types.go index 3804bb1..db20f8f 100644 --- a/api/types.go +++ b/api/types.go @@ -183,3 +183,10 @@ type OrganizationSubscriptionInfo struct { Usage *db.OrganizationCounters `json:"usage"` Plan *db.Plan `json:"plan"` } + +type SubscriptionCheckout struct { + LookupKey string `json:"lookupKey"` + ReturnURL string `json:"returnURL"` + Amount string `json:"amount"` + Address string `json:"address"` +} diff --git a/db/types.go b/db/types.go index 71d1df8..04d632e 100644 --- a/db/types.go +++ b/db/types.go @@ -96,6 +96,7 @@ type Plan struct { ID uint64 `json:"id" bson:"_id"` Name string `json:"name" bson:"name"` StripeID string `json:"stripeID" bson:"stripeID"` + StripePriceID string `json:"stripePriceID" bson:"stripePriceID"` StartingPrice int64 `json:"startingPrice" bson:"startingPrice"` Default bool `json:"default" bson:"default"` Organization PlanLimits `json:"organization" bson:"organization"` diff --git a/stripe/stripe.go b/stripe/stripe.go index b30f080..6059bce 100644 --- a/stripe/stripe.go +++ b/stripe/stripe.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/stripe/stripe-go/v81" + "github.com/stripe/stripe-go/v81/checkout/session" "github.com/stripe/stripe-go/v81/customer" "github.com/stripe/stripe-go/v81/price" "github.com/stripe/stripe-go/v81/product" @@ -20,6 +21,12 @@ var ProductsIDs = []string{ "prod_RHurAb3OjkgJRy", // Custom } +type ReturnStatus struct { + Status string `json:"status"` + CustomerEmail string `json:"customer_email"` + SubscriptionStatus string `json:"subscription_status"` +} + // StripeClient is a client for interacting with the Stripe API. // It holds the necessary configuration such as the webhook secret. type StripeClient struct { @@ -138,7 +145,8 @@ func (s *StripeClient) GetPlans() ([]*db.Plan, error) { ID: uint64(i), Name: product.Name, StartingPrice: startingPrice, - StripeID: price.ID, + StripeID: productID, + StripePriceID: price.ID, Default: price.Metadata["Default"] == "true", Organization: organizationData, VotingTypes: votingTypesData, @@ -151,3 +159,57 @@ func (s *StripeClient) GetPlans() ([]*db.Plan, error) { } return plans, nil } + +func (s *StripeClient) CreateSubscriptionCheckoutSession( + priceID, returnURL, address string, amount int64, +) (*stripe.CheckoutSession, error) { + checkoutParams := &stripe.CheckoutSessionParams{ + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + LineItems: []*stripe.CheckoutSessionLineItemParams{ + { + Price: stripe.String(priceID), + AdjustableQuantity: &stripe.CheckoutSessionLineItemAdjustableQuantityParams{ + Enabled: stripe.Bool(true), + Minimum: stripe.Int64(1), + Maximum: stripe.Int64(1000), + }, + Quantity: stripe.Int64(amount), + }, + }, + UIMode: stripe.String(string(stripe.CheckoutSessionUIModeEmbedded)), + ReturnURL: stripe.String(returnURL + "/{CHECKOUT_SESSION_ID}"), + AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{ + Enabled: stripe.Bool(true), + }, + SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{ + Metadata: map[string]string{ + "address": address, + }, + }, + } + session, err := session.New(checkoutParams) + if err != nil { + return nil, err + } + + return session, nil +} + +// RetrieveCheckoutSession retrieves a checkout session from Stripe by session ID. +// It returns a ReturnStatus object and an error if any. +// The ReturnStatus object contains information about the session status, customer email, +// faucet package, recipient, and quantity. +func (s *StripeClient) RetrieveCheckoutSession(sessionID string) (*ReturnStatus, error) { + params := &stripe.CheckoutSessionParams{} + params.AddExpand("line_items") + sess, err := session.Get(sessionID, params) + if err != nil { + return nil, err + } + data := &ReturnStatus{ + Status: string(sess.Status), + CustomerEmail: sess.CustomerDetails.Email, + SubscriptionStatus: string(sess.Subscription.Status), + } + return data, nil +}