From 5aea0456810b3e0f8fa29db43e423836c151fab4 Mon Sep 17 00:00:00 2001 From: Vikhyath Mondreti Date: Sun, 12 Apr 2026 12:04:45 -0700 Subject: [PATCH] fix(billing): unblock on payment success --- .../sim/lib/billing/webhooks/invoices.test.ts | 239 +++++++++++++++++ apps/sim/lib/billing/webhooks/invoices.ts | 249 +++++++++++------- 2 files changed, 398 insertions(+), 90 deletions(-) create mode 100644 apps/sim/lib/billing/webhooks/invoices.test.ts diff --git a/apps/sim/lib/billing/webhooks/invoices.test.ts b/apps/sim/lib/billing/webhooks/invoices.test.ts new file mode 100644 index 0000000000..63b14c04f8 --- /dev/null +++ b/apps/sim/lib/billing/webhooks/invoices.test.ts @@ -0,0 +1,239 @@ +/** + * @vitest-environment node + */ +import type Stripe from 'stripe' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockBlockOrgMembers, mockDbSelect, mockLogger, mockUnblockOrgMembers, selectResponses } = + vi.hoisted(() => { + const selectResponses: Array<{ limitResult?: unknown; whereResult?: unknown }> = [] + const mockDbSelect = vi.fn(() => { + const nextResponse = selectResponses.shift() + + if (!nextResponse) { + throw new Error('No queued db.select response') + } + + const builder = { + from: vi.fn(() => builder), + where: vi.fn(() => builder), + limit: vi.fn(async () => nextResponse.limitResult ?? nextResponse.whereResult ?? []), + then: (resolve: (value: unknown) => unknown, reject?: (reason: unknown) => unknown) => + Promise.resolve(nextResponse.whereResult ?? nextResponse.limitResult ?? []).then( + resolve, + reject + ), + } + + return builder + }) + + return { + mockBlockOrgMembers: vi.fn(), + mockDbSelect, + mockLogger: { + debug: vi.fn(), + error: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + }, + mockUnblockOrgMembers: vi.fn(), + selectResponses, + } + }) + +vi.mock('@sim/db', () => ({ + db: { + select: mockDbSelect, + }, +})) + +vi.mock('@sim/db/schema', () => ({ + member: { + organizationId: 'member.organizationId', + role: 'member.role', + userId: 'member.userId', + }, + organization: {}, + subscription: { + referenceId: 'subscription.referenceId', + stripeSubscriptionId: 'subscription.stripeSubscriptionId', + }, + user: { + email: 'user.email', + id: 'user.id', + name: 'user.name', + }, + userStats: { + billingBlocked: 'userStats.billingBlocked', + billingBlockedReason: 'userStats.billingBlockedReason', + userId: 'userStats.userId', + }, +})) + +vi.mock('@sim/logger', () => ({ + createLogger: vi.fn(() => mockLogger), +})) + +vi.mock('drizzle-orm', () => ({ + and: vi.fn(() => 'and'), + eq: vi.fn(() => 'eq'), + inArray: vi.fn(() => 'inArray'), + isNull: vi.fn(() => 'isNull'), + ne: vi.fn(() => 'ne'), + or: vi.fn(() => 'or'), +})) + +vi.mock('@/components/emails', () => ({ + PaymentFailedEmail: vi.fn(), + getEmailSubject: vi.fn(), + renderCreditPurchaseEmail: vi.fn(), +})) + +vi.mock('@/lib/billing/core/billing', () => ({ + calculateSubscriptionOverage: vi.fn(), +})) + +vi.mock('@/lib/billing/credits/balance', () => ({ + addCredits: vi.fn(), + getCreditBalance: vi.fn(), + removeCredits: vi.fn(), +})) + +vi.mock('@/lib/billing/credits/purchase', () => ({ + setUsageLimitForCredits: vi.fn(), +})) + +vi.mock('@/lib/billing/organizations/membership', () => ({ + blockOrgMembers: mockBlockOrgMembers, + unblockOrgMembers: mockUnblockOrgMembers, +})) + +vi.mock('@/lib/billing/plan-helpers', () => ({ + isEnterprise: vi.fn(() => false), + isOrgPlan: vi.fn((plan: string | null | undefined) => Boolean(plan?.startsWith('team'))), + isTeam: vi.fn((plan: string | null | undefined) => Boolean(plan?.startsWith('team'))), +})) + +vi.mock('@/lib/billing/stripe-client', () => ({ + requireStripeClient: vi.fn(), +})) + +vi.mock('@/lib/core/utils/urls', () => ({ + getBaseUrl: vi.fn(() => 'https://sim.test'), +})) + +vi.mock('@/lib/messaging/email/mailer', () => ({ + sendEmail: vi.fn(), +})) + +vi.mock('@/lib/messaging/email/utils', () => ({ + getPersonalEmailFrom: vi.fn(() => ({ + from: 'billing@sim.test', + replyTo: 'support@sim.test', + })), +})) + +vi.mock('@/lib/messaging/email/validation', () => ({ + quickValidateEmail: vi.fn(() => ({ isValid: true })), +})) + +vi.mock('@react-email/render', () => ({ + render: vi.fn(), +})) + +import { handleInvoicePaymentFailed, handleInvoicePaymentSucceeded } from './invoices' + +function queueSelectResponse(response: { limitResult?: unknown; whereResult?: unknown }) { + selectResponses.push(response) +} + +function createInvoiceEvent( + type: 'invoice.payment_failed' | 'invoice.payment_succeeded', + invoice: Partial +): Stripe.Event { + return { + data: { + object: invoice as Stripe.Invoice, + }, + id: `evt_${type}`, + type, + } as Stripe.Event +} + +describe('invoice billing recovery', () => { + beforeEach(() => { + vi.clearAllMocks() + selectResponses.length = 0 + mockBlockOrgMembers.mockResolvedValue(2) + mockUnblockOrgMembers.mockResolvedValue(2) + }) + + it('blocks org members when a metadata-backed invoice payment fails', async () => { + queueSelectResponse({ + limitResult: [ + { + id: 'sub-db-1', + plan: 'team_8000', + referenceId: 'org-1', + stripeSubscriptionId: 'sub_stripe_1', + }, + ], + }) + + await handleInvoicePaymentFailed( + createInvoiceEvent('invoice.payment_failed', { + amount_due: 3582, + attempt_count: 2, + customer: 'cus_123', + customer_email: 'owner@sim.test', + hosted_invoice_url: 'https://stripe.test/invoices/in_123', + id: 'in_123', + metadata: { + billingPeriod: '2026-04', + subscriptionId: 'sub_stripe_1', + type: 'overage_threshold_billing_org', + }, + }) + ) + + expect(mockBlockOrgMembers).toHaveBeenCalledWith('org-1', 'payment_failed') + expect(mockUnblockOrgMembers).not.toHaveBeenCalled() + }) + + it('unblocks org members when the matching metadata-backed invoice payment succeeds', async () => { + queueSelectResponse({ + limitResult: [ + { + id: 'sub-db-1', + plan: 'team_8000', + referenceId: 'org-1', + stripeSubscriptionId: 'sub_stripe_1', + }, + ], + }) + queueSelectResponse({ + whereResult: [{ userId: 'owner-1' }, { userId: 'member-1' }], + }) + queueSelectResponse({ + whereResult: [{ blocked: false }, { blocked: false }], + }) + + await handleInvoicePaymentSucceeded( + createInvoiceEvent('invoice.payment_succeeded', { + amount_paid: 3582, + billing_reason: 'manual', + customer: 'cus_123', + id: 'in_123', + metadata: { + billingPeriod: '2026-04', + subscriptionId: 'sub_stripe_1', + type: 'overage_threshold_billing_org', + }, + }) + ) + + expect(mockUnblockOrgMembers).toHaveBeenCalledWith('org-1', 'payment_failed') + expect(mockBlockOrgMembers).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/lib/billing/webhooks/invoices.ts b/apps/sim/lib/billing/webhooks/invoices.ts index 635aa8314a..398e40804c 100644 --- a/apps/sim/lib/billing/webhooks/invoices.ts +++ b/apps/sim/lib/billing/webhooks/invoices.ts @@ -24,7 +24,7 @@ import { quickValidateEmail } from '@/lib/messaging/email/validation' const logger = createLogger('StripeInvoiceWebhooks') -const OVERAGE_INVOICE_TYPES = new Set([ +const METADATA_SUBSCRIPTION_INVOICE_TYPES = new Set([ 'overage_billing', 'overage_threshold_billing', 'overage_threshold_billing_org', @@ -35,6 +35,116 @@ function parseDecimal(value: string | number | null | undefined): number { return Number.parseFloat(value.toString()) } +type InvoiceSubscriptionResolutionSource = + | 'parent.subscription_details.subscription' + | 'metadata.subscriptionId' + | 'none' + +interface InvoiceSubscriptionContext { + invoiceType: string | null + resolutionSource: InvoiceSubscriptionResolutionSource + stripeSubscriptionId: string | null +} + +type BillingSubscription = typeof subscriptionTable.$inferSelect + +interface ResolvedInvoiceSubscription extends InvoiceSubscriptionContext { + sub: BillingSubscription + stripeSubscriptionId: string +} + +function resolveInvoiceSubscriptionContext(invoice: Stripe.Invoice): InvoiceSubscriptionContext { + const invoiceType = invoice.metadata?.type ?? null + const canResolveFromMetadata = !!( + invoiceType && METADATA_SUBSCRIPTION_INVOICE_TYPES.has(invoiceType) + ) + const metadataSubscriptionId = + canResolveFromMetadata && + typeof invoice.metadata?.subscriptionId === 'string' && + invoice.metadata.subscriptionId.length > 0 + ? invoice.metadata.subscriptionId + : null + + const parentSubscription = invoice.parent?.subscription_details?.subscription + const parentSubscriptionId = + typeof parentSubscription === 'string' ? parentSubscription : (parentSubscription?.id ?? null) + + if ( + parentSubscriptionId && + metadataSubscriptionId && + parentSubscriptionId !== metadataSubscriptionId + ) { + logger.warn('Invoice has conflicting subscription identifiers', { + invoiceId: invoice.id, + invoiceType, + metadataSubscriptionId, + parentSubscriptionId, + }) + } + + if (parentSubscriptionId) { + return { + invoiceType, + resolutionSource: 'parent.subscription_details.subscription', + stripeSubscriptionId: parentSubscriptionId, + } + } + + if (metadataSubscriptionId) { + return { + invoiceType, + resolutionSource: 'metadata.subscriptionId', + stripeSubscriptionId: metadataSubscriptionId, + } + } + + return { + invoiceType, + resolutionSource: 'none', + stripeSubscriptionId: null, + } +} + +async function resolveInvoiceSubscription( + invoice: Stripe.Invoice, + handlerName: string +): Promise { + const subscriptionContext = resolveInvoiceSubscriptionContext(invoice) + + if (!subscriptionContext.stripeSubscriptionId) { + logger.info('No subscription found on invoice; skipping handler', { + handlerName, + invoiceId: invoice.id, + invoiceType: subscriptionContext.invoiceType, + resolutionSource: subscriptionContext.resolutionSource, + }) + return null + } + + const records = await db + .select() + .from(subscriptionTable) + .where(eq(subscriptionTable.stripeSubscriptionId, subscriptionContext.stripeSubscriptionId)) + .limit(1) + + if (records.length === 0) { + logger.warn('Subscription not found in database for invoice', { + handlerName, + invoiceId: invoice.id, + invoiceType: subscriptionContext.invoiceType, + resolutionSource: subscriptionContext.resolutionSource, + stripeSubscriptionId: subscriptionContext.stripeSubscriptionId, + }) + return null + } + + return { + ...subscriptionContext, + stripeSubscriptionId: subscriptionContext.stripeSubscriptionId, + sub: records[0], + } +} + /** * Create a billing portal URL for a Stripe customer */ @@ -462,21 +572,12 @@ export async function handleInvoicePaymentSucceeded(event: Stripe.Event) { return } - // Handle subscription invoices - const subscription = invoice.parent?.subscription_details?.subscription - const stripeSubscriptionId = typeof subscription === 'string' ? subscription : subscription?.id - if (!stripeSubscriptionId) { + const resolvedInvoice = await resolveInvoiceSubscription(invoice, 'invoice.payment_succeeded') + if (!resolvedInvoice) { return } - const records = await db - .select() - .from(subscriptionTable) - .where(eq(subscriptionTable.stripeSubscriptionId, stripeSubscriptionId)) - .limit(1) - - if (records.length === 0) return - const sub = records[0] + const { sub } = resolvedInvoice // Only reset usage here if the tenant was previously blocked; otherwise invoice.created already reset it let wasBlocked = false @@ -550,27 +651,13 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) { try { const invoice = event.data.object as Stripe.Invoice - const invoiceType = invoice.metadata?.type - const isOverageInvoice = !!(invoiceType && OVERAGE_INVOICE_TYPES.has(invoiceType)) - let stripeSubscriptionId: string | undefined - - if (isOverageInvoice) { - // Overage invoices store subscription ID in metadata - stripeSubscriptionId = invoice.metadata?.subscriptionId as string | undefined - } else { - // Regular subscription invoices have it in parent.subscription_details - const subscription = invoice.parent?.subscription_details?.subscription - stripeSubscriptionId = typeof subscription === 'string' ? subscription : subscription?.id - } - - if (!stripeSubscriptionId) { - logger.info('No subscription found on invoice; skipping payment failed handler', { - invoiceId: invoice.id, - isOverageInvoice, - }) + const resolvedInvoice = await resolveInvoiceSubscription(invoice, 'invoice.payment_failed') + if (!resolvedInvoice) { return } + const { invoiceType, resolutionSource, stripeSubscriptionId, sub } = resolvedInvoice + // Extract and validate customer ID const customerId = invoice.customer if (!customerId || typeof customerId !== 'string') { @@ -593,75 +680,57 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) { attemptCount, customerEmail: invoice.customer_email, hostedInvoiceUrl: invoice.hosted_invoice_url, - isOverageInvoice, - invoiceType: isOverageInvoice ? 'overage' : 'subscription', + invoiceType: invoiceType ?? 'subscription', + resolutionSource, }) // Block users after first payment failure if (attemptCount >= 1) { - const records = await db - .select() - .from(subscriptionTable) - .where(eq(subscriptionTable.stripeSubscriptionId, stripeSubscriptionId)) - .limit(1) - - if (records.length > 0) { - const sub = records[0] + logger.error('Payment failure - blocking users', { + customerId, + attemptCount, + invoiceId: invoice.id, + invoiceType: invoiceType ?? 'subscription', + resolutionSource, + stripeSubscriptionId, + }) - logger.error('Payment failure - blocking users', { - invoiceId: invoice.id, - customerId, - attemptCount, - isOverageInvoice, - stripeSubscriptionId, + if (isOrgPlan(sub.plan)) { + const memberCount = await blockOrgMembers(sub.referenceId, 'payment_failed') + logger.info('Blocked team/enterprise members due to payment failure', { + invoiceType: invoiceType ?? 'subscription', + memberCount, + organizationId: sub.referenceId, }) - - if (isOrgPlan(sub.plan)) { - const memberCount = await blockOrgMembers(sub.referenceId, 'payment_failed') - logger.info('Blocked team/enterprise members due to payment failure', { - organizationId: sub.referenceId, - memberCount, - isOverageInvoice, - }) - } else { - // Don't overwrite dispute blocks (dispute > payment_failed priority) - await db - .update(userStats) - .set({ billingBlocked: true, billingBlockedReason: 'payment_failed' }) - .where( - and( - eq(userStats.userId, sub.referenceId), - or( - ne(userStats.billingBlockedReason, 'dispute'), - isNull(userStats.billingBlockedReason) - ) + } else { + await db + .update(userStats) + .set({ billingBlocked: true, billingBlockedReason: 'payment_failed' }) + .where( + and( + eq(userStats.userId, sub.referenceId), + or( + ne(userStats.billingBlockedReason, 'dispute'), + isNull(userStats.billingBlockedReason) ) ) - logger.info('Blocked user due to payment failure', { - userId: sub.referenceId, - isOverageInvoice, - }) - } + ) + logger.info('Blocked user due to payment failure', { + invoiceType: invoiceType ?? 'subscription', + userId: sub.referenceId, + }) + } - // Send payment failure notification emails - // Only send on FIRST failure (attempt_count === 1), not on Stripe's automatic retries - // This prevents spamming users with duplicate emails every 3-5-7 days - if (attemptCount === 1) { - await sendPaymentFailureEmails(sub, invoice, customerId) - logger.info('Payment failure email sent on first attempt', { - invoiceId: invoice.id, - customerId, - }) - } else { - logger.info('Skipping payment failure email on retry attempt', { - invoiceId: invoice.id, - attemptCount, - customerId, - }) - } + if (attemptCount === 1) { + await sendPaymentFailureEmails(sub, invoice, customerId) + logger.info('Payment failure email sent on first attempt', { + customerId, + invoiceId: invoice.id, + }) } else { - logger.warn('Subscription not found in database for failed payment', { - stripeSubscriptionId, + logger.info('Skipping payment failure email on retry attempt', { + attemptCount, + customerId, invoiceId: invoice.id, }) }