Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions apps/mobile/src/features/auth/lib/oauth.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";

vi.mock("expo-auth-session", () => ({
makeRedirectUri: () => "posthog://callback",
AuthRequest: class {},
}));

vi.mock("expo-web-browser", () => ({
maybeCompleteAuthSession: () => {},
}));

import { refreshAccessToken, TokenRefreshError } from "./oauth";

const originalFetch = global.fetch;

function mockResponse(status: number, body: unknown): Response {
return new Response(JSON.stringify(body), {
status,
statusText: `Status ${status}`,
headers: { "Content-Type": "application/json" },
});
}

describe("refreshAccessToken", () => {
beforeEach(() => {
global.fetch = vi.fn() as unknown as typeof fetch;
});

afterEach(() => {
global.fetch = originalFetch;
});

it("returns the parsed token response on success", async () => {
vi.mocked(global.fetch).mockResolvedValueOnce(
mockResponse(200, { access_token: "fresh", expires_in: 3600 }),
);

const result = await refreshAccessToken("refresh", "us");

expect(result.access_token).toBe("fresh");
});

it.each([
{ name: "401", status: 401, body: {} },
{ name: "403", status: 403, body: {} },
{
name: "400 invalid_grant",
status: 400,
body: { error: "invalid_grant" },
},
{
name: "400 invalid_token",
status: 400,
body: { error: "invalid_token" },
},
])("classifies $name as auth_error", async ({ status, body }) => {
vi.mocked(global.fetch).mockResolvedValueOnce(mockResponse(status, body));

await expect(refreshAccessToken("refresh", "us")).rejects.toMatchObject({
errorCode: "auth_error",
});
});

it.each([
{ name: "invalid_client", body: { error: "invalid_client" } },
{ name: "invalid_request", body: { error: "invalid_request" } },
])("classifies a 400 $name as unknown_error", async ({ body }) => {
vi.mocked(global.fetch).mockResolvedValueOnce(mockResponse(400, body));

await expect(refreshAccessToken("refresh", "us")).rejects.toMatchObject({
errorCode: "unknown_error",
});
});

it.each([500, 502, 503])(
"classifies a %i as server_error",
async (status) => {
vi.mocked(global.fetch).mockResolvedValueOnce(mockResponse(status, {}));

await expect(refreshAccessToken("refresh", "us")).rejects.toMatchObject({
errorCode: "server_error",
});
},
);

it("classifies a thrown fetch as network_error", async () => {
vi.mocked(global.fetch).mockRejectedValueOnce(new Error("offline"));

const error = await refreshAccessToken("refresh", "us").catch((e) => e);

expect(error).toBeInstanceOf(TokenRefreshError);
expect(error.errorCode).toBe("network_error");
});
});
76 changes: 64 additions & 12 deletions apps/mobile/src/features/auth/lib/oauth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,78 @@ export async function exchangeCodeForToken(
return response.json();
}

export type OAuthRefreshErrorCode =
| "auth_error"
| "server_error"
| "network_error"
| "unknown_error";

export class TokenRefreshError extends Error {
readonly errorCode: OAuthRefreshErrorCode;

constructor(errorCode: OAuthRefreshErrorCode, message: string) {
super(message);
this.name = "TokenRefreshError";
this.errorCode = errorCode;
}
}

async function parseOAuthErrorCode(response: Response): Promise<string | null> {
try {
const body = (await response.json()) as { error?: unknown };
return typeof body.error === "string" ? body.error : null;
} catch {
return null;
}
}

export async function refreshAccessToken(
refreshToken: string,
region: CloudRegion,
): Promise<OAuthTokenResponse> {
const cloudUrl = getCloudUrlFromRegion(region);

const response = await fetch(`${cloudUrl}/oauth/token`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "refresh_token",
refresh_token: refreshToken,
client_id: getOauthClientIdFromRegion(region),
}),
});
let response: Response;
try {
response = await fetch(`${cloudUrl}/oauth/token`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "refresh_token",
refresh_token: refreshToken,
client_id: getOauthClientIdFromRegion(region),
}),
});
} catch (error) {
throw new TokenRefreshError(
"network_error",
error instanceof Error ? error.message : "Token refresh network error",
);
}

if (!response.ok) {
throw new Error(`Token refresh failed: ${response.statusText}`);
// 401/403 are always auth failures. A 400 only means a dead refresh token
// when the OAuth error is invalid_grant/invalid_token; other 400s like
// invalid_client are config bugs that must not sign the user out, or they
// could never log back in with the same broken config.
const oauthErrorCode =
response.status === 400 ? await parseOAuthErrorCode(response) : null;
const isAuthError =
response.status === 401 ||
response.status === 403 ||
oauthErrorCode === "invalid_grant" ||
oauthErrorCode === "invalid_token";
const errorCode: OAuthRefreshErrorCode = isAuthError
? "auth_error"
: response.status >= 500
? "server_error"
: "unknown_error";
throw new TokenRefreshError(
errorCode,
`Token refresh failed: ${response.status} ${response.statusText}`,
);
}

return response.json();
Expand Down
66 changes: 66 additions & 0 deletions apps/mobile/src/features/auth/stores/authStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ vi.mock("@react-native-async-storage/async-storage", () => ({
vi.mock("../lib/oauth", () => ({
performOAuthFlow: mockPerformOAuthFlow,
refreshAccessToken: mockRefreshAccessTokenRequest,
TokenRefreshError: class TokenRefreshError extends Error {
readonly errorCode: string;
constructor(errorCode: string, message: string) {
super(message);
this.errorCode = errorCode;
}
},
}));

vi.mock("../lib/secureStorage", () => ({
Expand Down Expand Up @@ -63,8 +70,20 @@ vi.mock("@/lib/queryClient", () => ({
}));

import { OAUTH_SCOPE_VERSION } from "../lib/constants";
import { TokenRefreshError } from "../lib/oauth";
import { useAuthStore } from "./authStore";

function expiredStoredTokens() {
return {
accessToken: "old-token",
refreshToken: "old-refresh",
expiresAt: Date.now() - 1_000,
cloudRegion: "us" as const,
scopedTeams: [42],
scopeVersion: OAUTH_SCOPE_VERSION,
};
}

describe("authStore", () => {
beforeEach(() => {
mockPerformOAuthFlow.mockReset();
Expand Down Expand Up @@ -135,4 +154,51 @@ describe("authStore", () => {
isLoading: false,
});
});

it("signs out when refreshing an expired token is rejected as auth_error", async () => {
mockGetTokens.mockResolvedValueOnce(expiredStoredTokens());
mockRefreshAccessTokenRequest.mockRejectedValueOnce(
new TokenRefreshError("auth_error", "invalid_grant"),
);

const initialized = await useAuthStore.getState().initializeAuth();

expect(initialized).toBe(false);
expect(mockDeleteTokens).toHaveBeenCalledOnce();
expect(useAuthStore.getState()).toMatchObject({
isAuthenticated: false,
isLoading: false,
});
});

it.each([
{
name: "server_error",
error: new TokenRefreshError("server_error", "5xx"),
},
{
name: "network_error",
error: new TokenRefreshError("network_error", "offline"),
},
{
name: "unknown_error (config 400)",
error: new TokenRefreshError("unknown_error", "invalid_client"),
},
])(
"keeps the session when an expired-token refresh fails with $name",
async ({ error }) => {
mockGetTokens.mockResolvedValueOnce(expiredStoredTokens());
mockRefreshAccessTokenRequest.mockRejectedValueOnce(error);

const initialized = await useAuthStore.getState().initializeAuth();

expect(initialized).toBe(true);
expect(mockDeleteTokens).not.toHaveBeenCalled();
expect(useAuthStore.getState()).toMatchObject({
oauthAccessToken: "old-token",
isAuthenticated: true,
isLoading: false,
});
},
);
});
63 changes: 30 additions & 33 deletions apps/mobile/src/features/auth/stores/authStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
import {
performOAuthFlow,
refreshAccessToken as refreshAccessTokenRequest,
TokenRefreshError,
} from "../lib/oauth";
import { deleteTokens, getTokens, saveTokens } from "../lib/secureStorage";
import type { CloudRegion, StoredTokens } from "../types";
Expand Down Expand Up @@ -78,6 +79,20 @@ function resolveActiveProjectId(
return scopedTeams[0] ?? null;
}

function isDeadRefreshToken(error: unknown): boolean {
return error instanceof TokenRefreshError && error.errorCode === "auth_error";
}

const CLEARED_AUTH_STATE = {
oauthAccessToken: null,
oauthRefreshToken: null,
tokenExpiry: null,
cloudRegion: null,
projectId: null,
scopedTeams: [],
isAuthenticated: false,
} satisfies Partial<AuthState>;

function maybeRegisterPushToken(): void {
if (!usePreferencesStore.getState().pushNotificationsEnabled) return;
usePushTokenStore
Expand Down Expand Up @@ -282,16 +297,7 @@ export const useAuthStore = create<AuthState>()(
if (tokens.scopeVersion !== OAUTH_SCOPE_VERSION) {
await deleteTokens();
queryClient.clear();
set({
oauthAccessToken: null,
oauthRefreshToken: null,
tokenExpiry: null,
cloudRegion: null,
projectId: null,
scopedTeams: [],
isLoading: false,
isAuthenticated: false,
});
set({ ...CLEARED_AUTH_STATE, isLoading: false });
return false;
}

Expand Down Expand Up @@ -320,20 +326,19 @@ export const useAuthStore = create<AuthState>()(
try {
await get().refreshAccessToken();
} catch (error) {
logger.error("Failed to refresh expired token:", error);
await deleteTokens();
queryClient.clear();
set({
oauthAccessToken: null,
oauthRefreshToken: null,
tokenExpiry: null,
cloudRegion: null,
projectId: null,
scopedTeams: [],
isLoading: false,
isAuthenticated: false,
});
return false;
if (isDeadRefreshToken(error)) {
logger.error("Refresh token rejected on startup; signing out");
await deleteTokens();
queryClient.clear();
set({ ...CLEARED_AUTH_STATE, isLoading: false });
return false;
}
// Transient (network/server) or config failure: keep the stored
// session so the next request's authedFetch retry can recover.
logger.warn(
"Token refresh failed transiently on startup; keeping session",
error,
);
}
}

Expand Down Expand Up @@ -362,15 +367,7 @@ export const useAuthStore = create<AuthState>()(
// Clear React Query cache to prevent data leakage between sessions
queryClient.clear();

set({
oauthAccessToken: null,
oauthRefreshToken: null,
tokenExpiry: null,
cloudRegion: null,
projectId: null,
scopedTeams: [],
isAuthenticated: false,
});
set(CLEARED_AUTH_STATE);
},
}),
{
Expand Down
Loading