diff --git a/api/server/handlers/billing/plan.go b/api/server/handlers/billing/plan.go index b321e77772..ea78fb4657 100644 --- a/api/server/handlers/billing/plan.go +++ b/api/server/handlers/billing/plan.go @@ -54,6 +54,22 @@ func (c *ListPlansHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { telemetry.AttributeKV{Key: "subscription_id", Value: plan.ID}, ) + endingBefore, err := c.Config().BillingManager.LagoClient.CheckCustomerCouponExpiration(ctx, proj.ID, proj.EnableSandbox) + if err != nil { + err := telemetry.Error(ctx, span, err, "error listing active coupons") + c.HandleAPIError(w, r, apierrors.NewErrInternal(err)) + return + } + + // If the customer has a coupon, use its end date instead of the trial end date + if endingBefore != "" { + plan.TrialInfo.EndingBefore = endingBefore + } + + telemetry.WithAttributes(span, + telemetry.AttributeKV{Key: "trial-ending-at", Value: plan.TrialInfo.EndingBefore}, + ) + c.WriteResult(w, r, plan) } diff --git a/api/types/billing_usage.go b/api/types/billing_usage.go index b5eb3d1922..3d4d657721 100644 --- a/api/types/billing_usage.go +++ b/api/types/billing_usage.go @@ -60,6 +60,14 @@ type BillingEvent struct { Timestamp string `json:"timestamp"` } +// AppliedCoupon represents an applied coupon in the billing system. +type AppliedCoupon struct { + Status string `json:"status"` + FrequencyDuration int `json:"frequency_duration"` + FrequencyDurationRemaining int `json:"frequency_duration_remaining"` + CreatedAt string `json:"created_at"` +} + // Wallet represents a customer credits wallet type Wallet struct { LagoID uuid.UUID `json:"lago_id,omitempty"` diff --git a/internal/billing/usage.go b/internal/billing/usage.go index 55b19b5d5e..50adcbc3fe 100644 --- a/internal/billing/usage.go +++ b/internal/billing/usage.go @@ -237,6 +237,28 @@ func (m LagoClient) ListCustomerCredits(ctx context.Context, projectID uint, san return response, nil } +func (m LagoClient) CheckCustomerCouponExpiration(ctx context.Context, projectID uint, sandboxEnabled bool) (trialEndDate string, err error) { + ctx, span := telemetry.NewSpan(ctx, "list-customer-coupons") + defer span.End() + + if projectID == 0 { + return trialEndDate, telemetry.Error(ctx, span, err, "project id empty") + } + customerID := m.generateLagoID(CustomerIDPrefix, projectID, sandboxEnabled) + couponList, err := m.listCustomerAppliedCoupons(ctx, customerID) + if err != nil { + return trialEndDate, telemetry.Error(ctx, span, err, "failed to list customer coupons") + } + + if len(couponList) == 0 { + return trialEndDate, nil + } + + appliedCoupon := couponList[0] + trialEndDate = time.Now().UTC().AddDate(0, appliedCoupon.FrequencyDurationRemaining, 0).Format(time.RFC3339) + return trialEndDate, nil +} + // CreateCreditsGrant will create a new credit grant for the customer with the specified amount func (m LagoClient) CreateCreditsGrant(ctx context.Context, projectID uint, name string, grantAmount int64, expiresAt *time.Time, sandboxEnabled bool) (err error) { ctx, span := telemetry.NewSpan(ctx, "create-credits-grant") @@ -498,7 +520,7 @@ func (m LagoClient) listCustomerWallets(ctx context.Context, customerID string) client := &http.Client{} resp, err := client.Do(req) if err != nil { - return walletList, telemetry.Error(ctx, span, err, "failed to get customer credits") + return walletList, telemetry.Error(ctx, span, err, "failed to get customer wallets") } response := struct { @@ -518,6 +540,43 @@ func (m LagoClient) listCustomerWallets(ctx context.Context, customerID string) return response.Wallets, nil } +func (m LagoClient) listCustomerAppliedCoupons(ctx context.Context, customerID string) (couponList []types.AppliedCoupon, err error) { + ctx, span := telemetry.NewSpan(ctx, "list-lago-customer-coupons") + defer span.End() + + // We manually do the request in this function because the Lago client has an issue + // with types for this specific request + url := fmt.Sprintf("%s/api/v1/applied_coupons?external_customer_id=%s&status=%s", lagoBaseURL, customerID, lago.AppliedCouponStatusActive) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return couponList, telemetry.Error(ctx, span, err, "failed to create coupons list request") + } + + req.Header.Set("Authorization", "Bearer "+m.lagoApiKey) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return couponList, telemetry.Error(ctx, span, err, "failed to get customer coupons") + } + + response := struct { + AppliedCoupons []types.AppliedCoupon `json:"applied_coupons"` + }{} + + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + return couponList, telemetry.Error(ctx, span, err, "failed to decode coupons list response") + } + + err = resp.Body.Close() + if err != nil { + return couponList, telemetry.Error(ctx, span, err, "failed to close response body") + } + + return response.AppliedCoupons, nil +} + func createUsageFromLagoUsage(lagoUsage lago.CustomerUsage) types.Usage { usage := types.Usage{} usage.FromDatetime = lagoUsage.FromDatetime.Format(time.RFC3339)