diff --git a/apps/sim/lib/billing/constants.ts b/apps/sim/lib/billing/constants.ts index d9a3c39154..ce20115db8 100644 --- a/apps/sim/lib/billing/constants.ts +++ b/apps/sim/lib/billing/constants.ts @@ -34,6 +34,11 @@ export const SEARCH_TOOL_COST = 0.01 */ export const DEFAULT_OVERAGE_THRESHOLD = 100 +/** + * Maximum time to wait on billing coordination row locks before retrying later. + */ +export const BILLING_LOCK_TIMEOUT_MS = 5_000 + /** * Available credit tiers. Each tier maps a credit amount to the underlying dollar cost. * 1 credit = $0.005, so credits = dollars * 200. diff --git a/apps/sim/lib/billing/organizations/membership.ts b/apps/sim/lib/billing/organizations/membership.ts index 93add34dff..9551396e24 100644 --- a/apps/sim/lib/billing/organizations/membership.ts +++ b/apps/sim/lib/billing/organizations/membership.ts @@ -926,34 +926,6 @@ export async function removeUserFromOrganization( ) } - let capturedUsage = 0 - if (!skipBillingLogic) { - const [departingUserStats] = await tx - .select({ currentPeriodCost: userStats.currentPeriodCost }) - .from(userStats) - .where(eq(userStats.userId, userId)) - .limit(1) - - if (departingUserStats?.currentPeriodCost) { - const usage = toNumber(toDecimal(departingUserStats.currentPeriodCost)) - if (usage > 0) { - await tx - .update(organization) - .set({ - departedMemberUsage: sql`${organization.departedMemberUsage} + ${usage}`, - }) - .where(eq(organization.id, organizationId)) - - await tx - .update(userStats) - .set({ currentPeriodCost: '0' }) - .where(eq(userStats.userId, userId)) - - capturedUsage = usage - } - } - } - const [targetUser] = await tx .select({ email: user.email }) .from(user) @@ -979,7 +951,44 @@ export async function removeUserFromOrganization( .from(workspace) .where(eq(workspace.organizationId, organizationId)) + const captureDepartedUsage = async () => { + if (skipBillingLogic) return 0 + + await tx + .select({ id: organization.id }) + .from(organization) + .where(eq(organization.id, organizationId)) + .for('update') + .limit(1) + + const [departingUserStats] = await tx + .select({ currentPeriodCost: userStats.currentPeriodCost }) + .from(userStats) + .where(eq(userStats.userId, userId)) + .for('update') + .limit(1) + + const usage = toNumber(toDecimal(departingUserStats?.currentPeriodCost)) + if (usage <= 0) return 0 + + await tx + .update(organization) + .set({ + departedMemberUsage: sql`${organization.departedMemberUsage} + ${usage}`, + }) + .where(eq(organization.id, organizationId)) + + await tx + .update(userStats) + .set({ currentPeriodCost: '0' }) + .where(eq(userStats.userId, userId)) + + return usage + } + if (orgWorkspaces.length === 0) { + const capturedUsage = await captureDepartedUsage() + return { workspaceIdsToRevoke: [] as string[], usageCaptured: capturedUsage, @@ -1022,6 +1031,7 @@ export async function removeUserFromOrganization( workspaceIds, userId, }) + const capturedUsage = await captureDepartedUsage() return { workspaceIdsToRevoke: deletedPerms.map((row) => row.entityId), diff --git a/apps/sim/lib/billing/threshold-billing.test.ts b/apps/sim/lib/billing/threshold-billing.test.ts new file mode 100644 index 0000000000..042dabaebf --- /dev/null +++ b/apps/sim/lib/billing/threshold-billing.test.ts @@ -0,0 +1,528 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCalculateSubscriptionOverage, + mockComputeOrgOverageAmount, + mockDbSelect, + mockDbTransaction, + mockEnqueueOutboxEvent, + mockGetEffectiveBillingStatus, + mockGetHighestPrioritySubscription, + mockGetOrganizationSubscriptionUsable, + mockHasUsableSubscriptionAccess, + mockIsEnterprise, + mockIsFree, + mockIsOrgScopedSubscription, + mockIsOrganizationBillingBlocked, + mockTxExecute, + mockTxSelect, + mockTxStatsLimit, + mockTxUpdate, +} = vi.hoisted(() => ({ + mockCalculateSubscriptionOverage: vi.fn(), + mockComputeOrgOverageAmount: vi.fn(), + mockDbSelect: vi.fn(), + mockDbTransaction: vi.fn(), + mockEnqueueOutboxEvent: vi.fn(), + mockGetEffectiveBillingStatus: vi.fn(), + mockGetHighestPrioritySubscription: vi.fn(), + mockGetOrganizationSubscriptionUsable: vi.fn(), + mockHasUsableSubscriptionAccess: vi.fn(), + mockIsEnterprise: vi.fn(), + mockIsFree: vi.fn(), + mockIsOrgScopedSubscription: vi.fn(), + mockIsOrganizationBillingBlocked: vi.fn(), + mockTxExecute: vi.fn(), + mockTxSelect: vi.fn(), + mockTxStatsLimit: vi.fn(), + mockTxUpdate: vi.fn(), +})) + +vi.mock('@sim/db', () => ({ + db: { + select: mockDbSelect, + transaction: mockDbTransaction, + }, +})) + +vi.mock('@sim/db/schema', () => ({ + member: { + organizationId: 'member.organizationId', + role: 'member.role', + userId: 'member.userId', + }, + organization: { + creditBalance: 'organization.creditBalance', + departedMemberUsage: 'organization.departedMemberUsage', + id: 'organization.id', + }, + subscription: { + id: 'subscription.id', + stripeCustomerId: 'subscription.stripeCustomerId', + }, + userStats: { + billedOverageThisPeriod: 'userStats.billedOverageThisPeriod', + creditBalance: 'userStats.creditBalance', + currentPeriodCost: 'userStats.currentPeriodCost', + lastPeriodCost: 'userStats.lastPeriodCost', + proPeriodCostSnapshot: 'userStats.proPeriodCostSnapshot', + proPeriodCostSnapshotAt: 'userStats.proPeriodCostSnapshotAt', + userId: 'userStats.userId', + }, +})) + +vi.mock('@/lib/billing/core/access', () => ({ + getEffectiveBillingStatus: mockGetEffectiveBillingStatus, + isOrganizationBillingBlocked: mockIsOrganizationBillingBlocked, +})) + +vi.mock('@/lib/billing/core/billing', () => ({ + calculateSubscriptionOverage: mockCalculateSubscriptionOverage, + computeOrgOverageAmount: mockComputeOrgOverageAmount, +})) + +vi.mock('@/lib/billing/core/subscription', () => ({ + getHighestPrioritySubscription: mockGetHighestPrioritySubscription, + getOrganizationSubscriptionUsable: mockGetOrganizationSubscriptionUsable, +})) + +vi.mock('@/lib/billing/plan-helpers', () => ({ + isEnterprise: mockIsEnterprise, + isFree: mockIsFree, +})) + +vi.mock('@/lib/billing/subscriptions/utils', () => ({ + hasUsableSubscriptionAccess: mockHasUsableSubscriptionAccess, + isOrgScopedSubscription: mockIsOrgScopedSubscription, +})) + +vi.mock('@/lib/billing/webhooks/outbox-handlers', () => ({ + OUTBOX_EVENT_TYPES: { + STRIPE_THRESHOLD_OVERAGE_INVOICE: 'stripe.threshold-overage-invoice', + }, +})) + +vi.mock('@/lib/core/config/env', () => ({ + env: {}, + envNumber: vi.fn((_value: string | undefined, fallback: number) => fallback), +})) + +vi.mock('@/lib/core/outbox/service', () => ({ + enqueueOutboxEvent: mockEnqueueOutboxEvent, +})) + +import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing' + +interface MockTx { + execute: typeof mockTxExecute + select: typeof mockTxSelect + update: typeof mockTxUpdate +} + +const userSubscription = { + id: 'sub-db-1', + plan: 'pro', + referenceId: 'user-1', + seats: 1, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_stripe_1', + status: 'active', +} + +function buildSelectChain(rows: T[]) { + const chain = { + from: vi.fn(() => chain), + leftJoin: vi.fn(() => chain), + innerJoin: vi.fn(() => chain), + where: vi.fn(() => result), + } + const result = { + limit: vi.fn(async () => rows), + then: (resolve: (value: T[]) => unknown, reject?: (reason: unknown) => unknown) => + Promise.resolve(rows).then(resolve, reject), + } + + return { + from: chain.from, + } +} + +function buildPersonalSelectChain(customerId = 'cus_1') { + return buildSelectChain([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '0', + stripeCustomerId: customerId, + }, + ]) +} + +function buildPersonalSnapshotSelectChain({ + currentPeriodCost = '0', + proPeriodCostSnapshot = '0', + proPeriodCostSnapshotAt = null, + lastPeriodCost = '0', +}: { + currentPeriodCost?: string + proPeriodCostSnapshot?: string + proPeriodCostSnapshotAt?: Date | null + lastPeriodCost?: string +}) { + return buildSelectChain([ + { + currentPeriodCost, + proPeriodCostSnapshot, + proPeriodCostSnapshotAt, + lastPeriodCost, + }, + ]) +} + +function buildStatsSelectChain() { + const result = { + for: vi.fn(() => result), + limit: mockTxStatsLimit, + then: (resolve: (value: unknown[]) => unknown, reject?: (reason: unknown) => unknown) => + Promise.resolve(mockTxStatsLimit()).then(resolve, reject), + } + + return { + from: vi.fn(() => ({ + leftJoin: vi.fn(() => ({ + innerJoin: vi.fn(() => ({ + where: vi.fn(() => result), + })), + })), + where: vi.fn(() => result), + })), + } +} + +function buildUpdateChain() { + return { + set: vi.fn(() => ({ + where: vi.fn(async () => []), + })), + } +} + +describe('checkAndBillOverageThreshold', () => { + beforeEach(() => { + vi.clearAllMocks() + + mockGetHighestPrioritySubscription.mockResolvedValue(userSubscription) + mockGetEffectiveBillingStatus.mockResolvedValue({ billingBlocked: false }) + mockHasUsableSubscriptionAccess.mockReturnValue(true) + mockIsFree.mockReturnValue(false) + mockIsEnterprise.mockReturnValue(false) + mockIsOrgScopedSubscription.mockReturnValue(false) + mockDbSelect.mockImplementation(() => buildPersonalSelectChain()) + mockTxSelect.mockImplementation(() => buildStatsSelectChain()) + mockTxUpdate.mockImplementation(() => buildUpdateChain()) + mockTxExecute.mockResolvedValue(undefined) + mockDbTransaction.mockImplementation(async (callback: (tx: MockTx) => Promise) => + callback({ execute: mockTxExecute, select: mockTxSelect, update: mockTxUpdate }) + ) + }) + + it('does not lock user_stats when calculated overage is below threshold', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(99) + + await checkAndBillOverageThreshold('user-1') + + expect(mockCalculateSubscriptionOverage).toHaveBeenCalledWith({ + id: userSubscription.id, + plan: userSubscription.plan, + referenceId: userSubscription.referenceId, + seats: userSubscription.seats, + periodStart: userSubscription.periodStart, + periodEnd: userSubscription.periodEnd, + }) + expect(mockDbTransaction).not.toHaveBeenCalled() + expect(mockDbSelect).toHaveBeenCalledTimes(1) + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + }) + + it('calculates overage before opening the short user_stats transaction', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(250) + mockTxStatsLimit.mockResolvedValue([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '0', + billedOverageThisPeriod: '0', + creditBalance: '0', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockCalculateSubscriptionOverage).toHaveBeenCalled() + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockCalculateSubscriptionOverage.mock.invocationCallOrder[0]).toBeLessThan( + mockDbTransaction.mock.invocationCallOrder[0] + ) + expect(mockTxExecute).toHaveBeenCalledTimes(1) + expect(mockEnqueueOutboxEvent).toHaveBeenCalledTimes(1) + }) + + it('rechecks billed overage while locked before enqueueing an invoice', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(250) + mockTxStatsLimit.mockResolvedValue([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '0', + billedOverageThisPeriod: '200', + creditBalance: '0', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockTxExecute).toHaveBeenCalledTimes(1) + expect(mockTxUpdate).not.toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + }) + + it('skips personal threshold billing when locked usage inputs changed', async () => { + mockCalculateSubscriptionOverage.mockResolvedValue(250) + mockDbSelect + .mockImplementationOnce(() => buildPersonalSnapshotSelectChain({ currentPeriodCost: '250' })) + .mockImplementationOnce(() => buildPersonalSelectChain()) + mockTxStatsLimit.mockResolvedValue([ + { + currentPeriodCost: '0', + proPeriodCostSnapshot: '0', + proPeriodCostSnapshotAt: null, + lastPeriodCost: '250', + billedOverageThisPeriod: '0', + creditBalance: '0', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + }) + + it('computes organization overage before opening the locked transaction', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'owner-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '0' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '25' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockComputeOrgOverageAmount).toHaveBeenCalledWith({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + organizationId: userSubscription.referenceId, + pooledCurrentPeriodCost: 350, + departedMemberUsage: 25, + memberIds: ['owner-1'], + }) + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockComputeOrgOverageAmount.mock.invocationCallOrder[0]).toBeLessThan( + mockDbTransaction.mock.invocationCallOrder[0] + ) + expect(mockTxExecute).toHaveBeenCalledTimes(1) + expect(mockEnqueueOutboxEvent).toHaveBeenCalledTimes(1) + }) + + it('skips stale organization overage when locked usage inputs changed', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'owner-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '0' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '75' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '75', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + }) + + it('rechecks organization billed overage on the locked owner tracker', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'owner-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '200' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '25' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + }) + + it('skips stale organization overage when owner identity changed', async () => { + mockIsOrgScopedSubscription.mockReturnValue(true) + mockIsOrganizationBillingBlocked.mockResolvedValue(false) + mockGetOrganizationSubscriptionUsable.mockResolvedValue({ + plan: 'team', + seats: 2, + periodStart: new Date('2026-05-01T00:00:00.000Z'), + periodEnd: new Date('2026-06-01T00:00:00.000Z'), + stripeSubscriptionId: 'sub_team_1', + stripeCustomerId: 'cus_team_1', + }) + mockDbSelect.mockImplementationOnce(() => + buildSelectChain([ + { + userId: 'owner-1', + role: 'owner', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + { + userId: 'member-1', + role: 'member', + currentPeriodCost: '25', + departedMemberUsage: '25', + }, + ]) + ) + mockComputeOrgOverageAmount.mockResolvedValue({ + totalOverage: 250, + baseSubscriptionAmount: 100, + effectiveUsage: 350, + }) + mockTxStatsLimit + .mockResolvedValueOnce([{ userId: 'member-1' }]) + .mockResolvedValueOnce([{ billedOverageThisPeriod: '0' }]) + .mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '25' }]) + .mockResolvedValueOnce([ + { + userId: 'owner-1', + role: 'member', + currentPeriodCost: '350', + departedMemberUsage: '25', + }, + { + userId: 'member-1', + role: 'owner', + currentPeriodCost: '25', + departedMemberUsage: '25', + }, + ]) + + await checkAndBillOverageThreshold('user-1') + + expect(mockDbTransaction).toHaveBeenCalled() + expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled() + expect(mockTxUpdate).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/lib/billing/threshold-billing.ts b/apps/sim/lib/billing/threshold-billing.ts index 1219481b7d..86156b6adc 100644 --- a/apps/sim/lib/billing/threshold-billing.ts +++ b/apps/sim/lib/billing/threshold-billing.ts @@ -1,8 +1,8 @@ import { db } from '@sim/db' import { member, organization, subscription, userStats } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { eq, inArray, sql } from 'drizzle-orm' -import { DEFAULT_OVERAGE_THRESHOLD } from '@/lib/billing/constants' +import { and, eq, sql } from 'drizzle-orm' +import { BILLING_LOCK_TIMEOUT_MS, DEFAULT_OVERAGE_THRESHOLD } from '@/lib/billing/constants' import { getEffectiveBillingStatus, isOrganizationBillingBlocked } from '@/lib/billing/core/access' import { calculateSubscriptionOverage, computeOrgOverageAmount } from '@/lib/billing/core/billing' import { @@ -22,6 +22,22 @@ import { enqueueOutboxEvent } from '@/lib/core/outbox/service' const logger = createLogger('ThresholdBilling') const OVERAGE_THRESHOLD = envNumber(env.OVERAGE_THRESHOLD_DOLLARS, DEFAULT_OVERAGE_THRESHOLD) +const USAGE_TOTAL_EPSILON = 0.000001 + +interface PersonalUsageSnapshot { + currentPeriodCost: number + proPeriodCostSnapshot: number + proPeriodCostSnapshotAt: Date | null + lastPeriodCost: number +} + +interface OrganizationUsageSnapshot { + memberIds: string[] + ownerId: string + memberSignature: string + pooledCurrentPeriodCost: number + departedMemberUsage: number +} export async function checkAndBillOverageThreshold(userId: string): Promise { try { @@ -53,7 +69,57 @@ export async function checkAndBillOverageThreshold(userId: string): Promise { + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) + const statsRecords = await tx .select() .from(userStats) @@ -67,15 +133,16 @@ export async function checkAndBillOverageThreshold(userId: string): Promise ({ userId: m.userId, role: m.role })), + memberCount: memberUsageRows.length, + members: memberUsageRows.map((m) => ({ userId: m.userId, role: m.role })), }) - if (members.length === 0) { + if (memberUsageRows.length === 0) { logger.warn('No members found for organization', { organizationId }) return } - const owner = members.find((m) => m.role === 'owner') - if (!owner) { + const usageSnapshot = buildOrganizationUsageSnapshot(memberUsageRows) + if (!usageSnapshot) { logger.error( 'Organization has no owner when running threshold billing — data integrity issue, skipping', { organizationId } @@ -260,17 +312,80 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): logger.debug('Found organization owner, starting transaction', { organizationId, - ownerId: owner.userId, + ownerId: usageSnapshot.ownerId, + }) + + const { + totalOverage: currentOverage, + baseSubscriptionAmount: basePrice, + effectiveUsage: effectiveTeamUsage, + } = await computeOrgOverageAmount({ + plan: orgSubscription.plan, + seats: orgSubscription.seats ?? null, + periodStart: orgSubscription.periodStart ?? null, + periodEnd: orgSubscription.periodEnd ?? null, + organizationId, + pooledCurrentPeriodCost: usageSnapshot.pooledCurrentPeriodCost, + departedMemberUsage: usageSnapshot.departedMemberUsage, + memberIds: usageSnapshot.memberIds, }) + if (currentOverage < threshold) { + logger.debug('Organization threshold billing check below threshold before locking', { + organizationId, + totalTeamUsage: usageSnapshot.pooledCurrentPeriodCost + usageSnapshot.departedMemberUsage, + effectiveTeamUsage, + basePrice, + currentOverage, + threshold, + }) + return + } + + // Validate Stripe identifiers BEFORE mutating credits/trackers. + const stripeSubscriptionId = orgSubscription.stripeSubscriptionId + if (!stripeSubscriptionId) { + logger.error('No Stripe subscription ID for organization', { organizationId }) + return + } + + const customerId = orgSubscription.stripeCustomerId + if (!customerId) { + logger.error('No Stripe customer ID for organization', { organizationId }) + return + } + + const periodEnd = orgSubscription.periodEnd + ? Math.floor(orgSubscription.periodEnd.getTime() / 1000) + : Math.floor(Date.now() / 1000) + const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7) + const totalOverageCents = Math.round(currentOverage * 100) + await db.transaction(async (tx) => { - // Lock both owner stats and organization rows + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) + + const lockedOwnerRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(and(eq(member.organizationId, organizationId), eq(member.role, 'owner'))) + .for('update') + .limit(1) + const lockedOwnerId = lockedOwnerRows[0]?.userId + if (!lockedOwnerId) { + logger.error('Organization owner not found after locking organization', { organizationId }) + return + } + const ownerStatsLock = await tx .select() .from(userStats) - .where(eq(userStats.userId, owner.userId)) + .where(eq(userStats.userId, lockedOwnerId)) .for('update') .limit(1) + if (ownerStatsLock.length === 0) { + logger.error('Owner stats not found', { organizationId, ownerId: lockedOwnerId }) + return + } const orgLock = await tx .select() @@ -279,58 +394,46 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): .for('update') .limit(1) - if (ownerStatsLock.length === 0) { - logger.error('Owner stats not found', { organizationId, ownerId: owner.userId }) + if (orgLock.length === 0) { + logger.error('Organization not found', { organizationId }) return } - if (orgLock.length === 0) { - logger.error('Organization not found', { organizationId }) + const lockedMemberUsageRows = await tx + .select({ + userId: member.userId, + role: member.role, + currentPeriodCost: userStats.currentPeriodCost, + departedMemberUsage: organization.departedMemberUsage, + }) + .from(member) + .leftJoin(userStats, eq(member.userId, userStats.userId)) + .innerJoin(organization, eq(organization.id, member.organizationId)) + .where(eq(member.organizationId, organizationId)) + + const lockedUsageSnapshot = buildOrganizationUsageSnapshot(lockedMemberUsageRows) + if ( + !lockedUsageSnapshot || + lockedOwnerId !== usageSnapshot.ownerId || + !organizationUsageSnapshotMatches(usageSnapshot, lockedUsageSnapshot) + ) { + logger.debug('Organization usage changed during threshold billing check; retry later', { + organizationId, + usageSnapshot, + lockedUsageSnapshot, + lockedOwnerId, + }) return } - let pooledCurrentPeriodCost = toNumber(toDecimal(ownerStatsLock[0].currentPeriodCost)) const totalBilledOverage = toNumber(toDecimal(ownerStatsLock[0].billedOverageThisPeriod)) const orgCreditBalance = toNumber(toDecimal(orgLock[0].creditBalance)) - const nonOwnerIds = members.filter((m) => m.userId !== owner.userId).map((m) => m.userId) - - if (nonOwnerIds.length > 0) { - const memberStatsRows = await tx - .select({ - userId: userStats.userId, - currentPeriodCost: userStats.currentPeriodCost, - }) - .from(userStats) - .where(inArray(userStats.userId, nonOwnerIds)) - - for (const stats of memberStatsRows) { - pooledCurrentPeriodCost += toNumber(toDecimal(stats.currentPeriodCost)) - } - } - - const departedMemberUsage = toNumber(toDecimal(orgLock[0].departedMemberUsage)) - - const { - totalOverage: currentOverage, - baseSubscriptionAmount: basePrice, - effectiveUsage: effectiveTeamUsage, - } = await computeOrgOverageAmount({ - plan: orgSubscription.plan, - seats: orgSubscription.seats ?? null, - periodStart: orgSubscription.periodStart ?? null, - periodEnd: orgSubscription.periodEnd ?? null, - organizationId, - pooledCurrentPeriodCost, - departedMemberUsage, - memberIds: members.map((m) => m.userId), - }) - const unbilledOverage = Math.max(0, currentOverage - totalBilledOverage) logger.debug('Organization threshold billing check', { organizationId, - totalTeamUsage: pooledCurrentPeriodCost + departedMemberUsage, + totalTeamUsage: usageSnapshot.pooledCurrentPeriodCost + usageSnapshot.departedMemberUsage, effectiveTeamUsage, basePrice, currentOverage, @@ -343,19 +446,6 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): return } - // Validate Stripe identifiers BEFORE mutating credits/trackers. - const stripeSubscriptionId = orgSubscription.stripeSubscriptionId - if (!stripeSubscriptionId) { - logger.error('No Stripe subscription ID for organization', { organizationId }) - return - } - - const customerId = orgSubscription.stripeCustomerId - if (!customerId) { - logger.error('No Stripe customer ID for organization', { organizationId }) - return - } - let amountToBill = unbilledOverage let creditsApplied = 0 @@ -384,7 +474,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): .set({ billedOverageThisPeriod: sql`${userStats.billedOverageThisPeriod} + ${unbilledOverage}`, }) - .where(eq(userStats.userId, owner.userId)) + .where(eq(userStats.userId, lockedOwnerId)) logger.info('Credits fully covered org threshold overage', { organizationId, @@ -394,12 +484,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): return } - const periodEnd = orgSubscription.periodEnd - ? Math.floor(orgSubscription.periodEnd.getTime() / 1000) - : Math.floor(Date.now() / 1000) - const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7) const amountCents = Math.round(amountToBill * 100) - const totalOverageCents = Math.round(currentOverage * 100) // Bump billed tracker and enqueue Stripe invoice atomically. // See user-path above for the full retry-invariant reasoning. @@ -408,7 +493,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): .set({ billedOverageThisPeriod: sql`${userStats.billedOverageThisPeriod} + ${unbilledOverage}`, }) - .where(eq(userStats.userId, owner.userId)) + .where(eq(userStats.userId, lockedOwnerId)) await enqueueOutboxEvent(tx, OUTBOX_EVENT_TYPES.STRIPE_THRESHOLD_OVERAGE_INVOICE, { customerId, @@ -430,7 +515,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): logger.info('Queued organization threshold overage invoice for Stripe', { organizationId, - ownerId: owner.userId, + ownerId: lockedOwnerId, creditsApplied, amountBilled: amountToBill, totalProcessed: unbilledOverage, @@ -444,3 +529,92 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string): }) } } + +async function getPersonalUsageSnapshot(userId: string): Promise { + const [stats] = await db + .select({ + currentPeriodCost: userStats.currentPeriodCost, + proPeriodCostSnapshot: userStats.proPeriodCostSnapshot, + proPeriodCostSnapshotAt: userStats.proPeriodCostSnapshotAt, + lastPeriodCost: userStats.lastPeriodCost, + }) + .from(userStats) + .where(eq(userStats.userId, userId)) + .limit(1) + + return stats ? personalUsageSnapshotFromStats(stats) : null +} + +function personalUsageSnapshotFromStats(stats: { + currentPeriodCost: string | number | null + proPeriodCostSnapshot: string | number | null + proPeriodCostSnapshotAt: Date | null + lastPeriodCost: string | number | null +}): PersonalUsageSnapshot { + return { + currentPeriodCost: toNumber(toDecimal(stats.currentPeriodCost)), + proPeriodCostSnapshot: toNumber(toDecimal(stats.proPeriodCostSnapshot)), + proPeriodCostSnapshotAt: stats.proPeriodCostSnapshotAt, + lastPeriodCost: toNumber(toDecimal(stats.lastPeriodCost)), + } +} + +function personalUsageSnapshotMatches( + expected: PersonalUsageSnapshot, + actual: PersonalUsageSnapshot +): boolean { + return ( + Math.abs(expected.currentPeriodCost - actual.currentPeriodCost) <= USAGE_TOTAL_EPSILON && + Math.abs(expected.proPeriodCostSnapshot - actual.proPeriodCostSnapshot) <= + USAGE_TOTAL_EPSILON && + Math.abs(expected.lastPeriodCost - actual.lastPeriodCost) <= USAGE_TOTAL_EPSILON && + nullableDateTime(expected.proPeriodCostSnapshotAt) === + nullableDateTime(actual.proPeriodCostSnapshotAt) + ) +} + +function buildOrganizationUsageSnapshot( + rows: { + userId: string + role: string + currentPeriodCost: string | number | null + departedMemberUsage: string | number | null + }[] +): OrganizationUsageSnapshot | null { + const owner = rows.find((row) => row.role === 'owner') + if (!owner) return null + + const sortedRows = [...rows].sort((a, b) => a.userId.localeCompare(b.userId)) + let pooledCurrentPeriodCost = 0 + for (const row of sortedRows) { + pooledCurrentPeriodCost += toNumber(toDecimal(row.currentPeriodCost)) + } + + return { + memberIds: sortedRows.map((row) => row.userId), + ownerId: owner.userId, + memberSignature: sortedRows + .map( + (row) => + `${row.userId}:${row.role}:${toNumber(toDecimal(row.currentPeriodCost)).toFixed(6)}` + ) + .join('|'), + pooledCurrentPeriodCost, + departedMemberUsage: toNumber(toDecimal(owner.departedMemberUsage)), + } +} + +function organizationUsageSnapshotMatches( + expected: OrganizationUsageSnapshot, + actual: OrganizationUsageSnapshot +): boolean { + return ( + expected.ownerId === actual.ownerId && + expected.memberSignature === actual.memberSignature && + Math.abs(expected.departedMemberUsage - actual.departedMemberUsage) <= USAGE_TOTAL_EPSILON + ) +} + +function nullableDateTime(value: Date | null): number | null { + return value?.getTime() ?? null +} diff --git a/apps/sim/lib/billing/webhooks/invoices.test.ts b/apps/sim/lib/billing/webhooks/invoices.test.ts index 9d601c33a8..eabf87666b 100644 --- a/apps/sim/lib/billing/webhooks/invoices.test.ts +++ b/apps/sim/lib/billing/webhooks/invoices.test.ts @@ -103,6 +103,7 @@ vi.mock('@react-email/render', () => ({ import { handleInvoicePaymentFailed, handleInvoicePaymentSucceeded, + resetUsageForSubscription, } from '@/lib/billing/webhooks/invoices' interface SelectResponse { @@ -127,6 +128,7 @@ function installSelectResponseQueue() { throw new Error('No queued db.select response') } const builder = { + for: vi.fn(() => builder), limit: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), orderBy: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), returning: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), @@ -223,4 +225,40 @@ describe('invoice billing recovery', () => { expect(mockUnblockOrgMembers).toHaveBeenCalledWith('org-1', 'payment_failed') expect(mockBlockOrgMembers).not.toHaveBeenCalled() }) + + it('coordinates org usage reset with owner tracker and organization locks', async () => { + queueSelectResponse({ limitResult: [{ userId: 'owner-1' }] }) + queueSelectResponse({ limitResult: [{ userId: 'owner-1' }] }) + queueSelectResponse({ limitResult: [{ id: 'org-1' }] }) + queueSelectResponse({ whereResult: [{ userId: 'owner-1' }, { userId: 'member-1' }] }) + queueSelectResponse({ + whereResult: [ + { userId: 'owner-1', current: '125', currentCopilot: '10' }, + { userId: 'member-1', current: '75', currentCopilot: '5' }, + ], + }) + queueSelectResponse({ whereResult: [] }) + queueSelectResponse({ whereResult: [] }) + + await resetUsageForSubscription({ plan: 'team', referenceId: 'org-1' }) + + expect(dbChainMockFns.transaction).toHaveBeenCalledTimes(1) + expect(dbChainMockFns.update).toHaveBeenCalledTimes(2) + expect(Object.keys(dbChainMockFns.select.mock.calls[0][0] ?? {})).toEqual(['userId']) + expect(Object.keys(dbChainMockFns.select.mock.calls[1][0] ?? {})).toEqual(['userId']) + expect(Object.keys(dbChainMockFns.select.mock.calls[2][0] ?? {})).toEqual(['id']) + + const statsReset = dbChainMockFns.set.mock.calls[0][0] as Record + expect(statsReset.currentPeriodCost).not.toBe('0') + expect(statsReset.currentPeriodCopilotCost).not.toBe('0') + expect(statsReset.lastPeriodCost).toMatchObject({ + toSQL: expect.any(Function), + }) + expect((statsReset.lastPeriodCost as { toSQL: () => { sql: string } }).toSQL().sql).toContain( + 'CASE' + ) + expect( + (statsReset.currentPeriodCost as { toSQL: () => { sql: string } }).toSQL().sql + ).toContain('GREATEST') + }) }) diff --git a/apps/sim/lib/billing/webhooks/invoices.ts b/apps/sim/lib/billing/webhooks/invoices.ts index bed1a7834e..f3f6bd4057 100644 --- a/apps/sim/lib/billing/webhooks/invoices.ts +++ b/apps/sim/lib/billing/webhooks/invoices.ts @@ -11,6 +11,7 @@ import { createLogger } from '@sim/logger' import { and, eq, inArray, isNull, ne, or, sql } from 'drizzle-orm' import type Stripe from 'stripe' import { getEmailSubject, PaymentFailedEmail, renderCreditPurchaseEmail } from '@/components/emails' +import { BILLING_LOCK_TIMEOUT_MS } from '@/lib/billing/constants' import { calculateSubscriptionOverage, isSubscriptionOrgScoped } from '@/lib/billing/core/billing' import { addCredits, getCreditBalanceForEntity } from '@/lib/billing/credits/balance' import { setUsageLimitForCredits } from '@/lib/billing/credits/purchase' @@ -388,40 +389,86 @@ export async function getBilledOverageForSubscription(sub: { export async function resetUsageForSubscription(sub: { plan: string | null; referenceId: string }) { if (await isSubscriptionOrgScoped(sub)) { - const membersRows = await db - .select({ userId: member.userId }) - .from(member) - .where(eq(member.organizationId, sub.referenceId)) + await db.transaction(async (tx) => { + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) - for (const m of membersRows) { - const currentStats = await db - .select({ - current: userStats.currentPeriodCost, - currentCopilot: userStats.currentPeriodCopilotCost, - }) - .from(userStats) - .where(eq(userStats.userId, m.userId)) + const ownerRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(and(eq(member.organizationId, sub.referenceId), eq(member.role, 'owner'))) + .for('update') .limit(1) - if (currentStats.length > 0) { - const current = currentStats[0].current || '0' - const currentCopilot = currentStats[0].currentCopilot || '0' - await db + + const ownerId = ownerRows[0]?.userId + if (ownerId) { + await tx + .select({ userId: userStats.userId }) + .from(userStats) + .where(eq(userStats.userId, ownerId)) + .for('update') + .limit(1) + } + + await tx + .select({ id: organization.id }) + .from(organization) + .where(eq(organization.id, sub.referenceId)) + .for('update') + .limit(1) + + const membersRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(eq(member.organizationId, sub.referenceId)) + + const memberIds = membersRows.map((row) => row.userId) + if (memberIds.length > 0) { + const memberStatsRows = await tx + .select({ + userId: userStats.userId, + current: userStats.currentPeriodCost, + currentCopilot: userStats.currentPeriodCopilotCost, + }) + .from(userStats) + .where(inArray(userStats.userId, memberIds)) + + const statsUserIds = memberStatsRows.map((row) => row.userId) + if (statsUserIds.length === 0) { + await tx + .update(organization) + .set({ departedMemberUsage: '0' }) + .where(eq(organization.id, sub.referenceId)) + return + } + + const currentCostByUser = sql.join( + memberStatsRows.map((row) => sql`WHEN ${row.userId} THEN ${row.current ?? '0'}`), + sql` ` + ) + const currentCopilotCostByUser = sql.join( + memberStatsRows.map((row) => sql`WHEN ${row.userId} THEN ${row.currentCopilot ?? '0'}`), + sql` ` + ) + const capturedCurrentCost = sql`CASE ${userStats.userId} ${currentCostByUser} ELSE '0' END` + const capturedCurrentCopilotCost = sql`CASE ${userStats.userId} ${currentCopilotCostByUser} ELSE '0' END` + + await tx .update(userStats) .set({ - lastPeriodCost: current, - lastPeriodCopilotCost: currentCopilot, - currentPeriodCost: sql`GREATEST(0, ${userStats.currentPeriodCost} - ${current}::decimal)`, - currentPeriodCopilotCost: sql`GREATEST(0, ${userStats.currentPeriodCopilotCost} - ${currentCopilot}::decimal)`, + lastPeriodCost: capturedCurrentCost, + lastPeriodCopilotCost: capturedCurrentCopilotCost, + currentPeriodCost: sql`GREATEST(0, ${userStats.currentPeriodCost} - (${capturedCurrentCost})::decimal)`, + currentPeriodCopilotCost: sql`GREATEST(0, ${userStats.currentPeriodCopilotCost} - (${capturedCurrentCopilotCost})::decimal)`, billedOverageThisPeriod: '0', }) - .where(eq(userStats.userId, m.userId)) + .where(inArray(userStats.userId, statsUserIds)) } - } - await db - .update(organization) - .set({ departedMemberUsage: '0' }) - .where(eq(organization.id, sub.referenceId)) + await tx + .update(organization) + .set({ departedMemberUsage: '0' }) + .where(eq(organization.id, sub.referenceId)) + }) } else { const currentStats = await db .select({ @@ -859,36 +906,29 @@ export async function handleInvoiceFinalized(event: Stripe.Event) { const entityType = (await isSubscriptionOrgScoped(sub)) ? 'organization' : 'user' const entityId = sub.referenceId - // Resolve the userStats row that holds the `billedOverageThisPeriod` - // tracker. Org subs: the owner's row. Personal: the user's own row. - // Throw if an org has no owner — returning early would cache a - // "successful" no-op, and the next cycle's tracker would still - // reflect this cycle's billed amount, breaking future overage math. - let trackerUserId: string - if (entityType === 'organization') { - const ownerRows = await db - .select({ userId: member.userId }) - .from(member) - .where(and(eq(member.organizationId, entityId), eq(member.role, 'owner'))) - .limit(1) - const ownerId = ownerRows[0]?.userId - if (!ownerId) { - throw new Error( - `Organization ${entityId} has no owner member; cannot process invoice finalization` - ) + // Phase 1 — atomic commit. Resolve org owners inside the transaction, + // then lock the tracker row so `billedOverageThisPeriod` is serialized + // against threshold billing, resets, owner transfers, and retries. + const phase1 = await db.transaction(async (tx) => { + await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${BILLING_LOCK_TIMEOUT_MS}ms'`)) + + let trackerUserId = entityId + if (entityType === 'organization') { + const ownerRows = await tx + .select({ userId: member.userId }) + .from(member) + .where(and(eq(member.organizationId, entityId), eq(member.role, 'owner'))) + .for('update') + .limit(1) + const ownerId = ownerRows[0]?.userId + if (!ownerId) { + throw new Error( + `Organization ${entityId} has no owner member; cannot process invoice finalization` + ) + } + trackerUserId = ownerId } - trackerUserId = ownerId - } else { - trackerUserId = entityId - } - // Phase 1 — atomic commit. Lock the tracker row first so we read - // `billedOverageThisPeriod` serialized against concurrent events; - // then read the credit balance, decrement it, and bump the - // tracker to `totalOverage`. On retry, the locked re-read sees - // `billed == totalOverage` → `remaining == 0` → credit removal - // skipped. That's the invariant preventing double-deduction. - const phase1 = await db.transaction(async (tx) => { const trackerRows = await tx .select({ billed: userStats.billedOverageThisPeriod }) .from(userStats)