diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index b3ab87e0ea..c3773b4764 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -18,7 +18,8 @@ use stripe::{ CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm, CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems, CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject, - EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus, + EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, + SubscriptionStatus, }; use util::{ResultExt, maybe}; @@ -445,6 +446,7 @@ struct ManageBillingSubscriptionBody { intent: ManageSubscriptionIntent, /// The ID of the subscription to manage. subscription_id: BillingSubscriptionId, + redirect_to: Option, } #[derive(Debug, Serialize)] @@ -542,6 +544,23 @@ async fn manage_billing_subscription( .map_or(false, |price| price.id == zed_pro_price_id) }); if is_on_zed_pro_trial { + let payment_methods = PaymentMethod::list( + &stripe_client, + &stripe::ListPaymentMethods { + customer: Some(stripe_subscription.customer.id()), + ..Default::default() + }, + ) + .await?; + + let has_payment_method = !payment_methods.data.is_empty(); + if !has_payment_method { + return Err(Error::http( + StatusCode::BAD_REQUEST, + "missing payment method".into(), + )); + } + // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early. Subscription::update( &stripe_client, @@ -596,7 +615,11 @@ async fn manage_billing_subscription( after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { - return_url: format!("{}/account", app.config.zed_dot_dev_url()), + return_url: format!( + "{}{path}", + app.config.zed_dot_dev_url(), + path = body.redirect_to.unwrap_or_else(|| "/account".to_string()) + ), }), ..Default::default() }),