Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion packages/api-client/src/fetcher.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ describe("buildApiFetcher", () => {
};
return response;
};

beforeEach(() => {
vi.resetAllMocks();
vi.stubGlobal("fetch", mockFetch);
Expand Down
1 change: 0 additions & 1 deletion packages/api-client/src/fetcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export const buildApiFetcher: (
config: ApiFetcherConfig,
) => Parameters<typeof createApiClient>[0] = (config) => {
const userAgent = `posthog/desktop.hog.dev; version: ${config.appVersion}`;

const makeRequest = async (
input: Parameters<Parameters<typeof createApiClient>[0]["fetch"]>[0],
token: string,
Expand Down
43 changes: 43 additions & 0 deletions packages/api-client/src/posthog-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,49 @@ describe("PostHogAPIClient", () => {
);
});

it("returns the redirect URL when authorizing an MCP installation", async () => {
const fetch = vi.fn().mockResolvedValue({
ok: true,
status: 200,
json: async () => ({
redirect_url: "https://auth.example.com/authorize?state=abc",
}),
});
const client = new PostHogAPIClient(
"http://localhost:8000",
async () => "token",
async () => "token",
123,
);

(
client as unknown as {
api: { baseUrl: string; fetcher: { fetch: typeof fetch } };
}
).api = {
baseUrl: "http://localhost:8000",
fetcher: { fetch },
};

await expect(
client.authorizeMcpInstallation({
installation_id: "inst-123",
install_source: "posthog-code",
posthog_code_callback_url: "posthog-code://mcp-oauth-complete",
}),
).resolves.toEqual({
redirect_url: "https://auth.example.com/authorize?state=abc",
});

expect(fetch).toHaveBeenCalledWith(
expect.objectContaining({
method: "get",
path: "/api/environments/123/mcp_server_installations/authorize/",
}),
);
expect(fetch.mock.calls[0][0]).not.toHaveProperty("overrides");
});

describe("warmTask", () => {
function makeClient(fetch: ReturnType<typeof vi.fn>) {
const client = new PostHogAPIClient(
Expand Down
66 changes: 66 additions & 0 deletions packages/core/src/auth/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,72 @@ describe("AuthService", () => {
expect(service.getState().status).toBe("restoring");
expect(oauthFlow.refreshToken).toHaveBeenCalledTimes(3);
});

it("uses the current access token when a preemptive refresh fails before expiry", async () => {
vi.useFakeTimers();
try {
oauthFlow.startFlow.mockResolvedValue(
mockTokenResponse({
accessToken: "current-access-token",
refreshToken: "current-refresh-token",
}),
);
stubAuthFetch();

await service.initialize();
await service.login("us");

oauthFlow.refreshToken.mockReset();
oauthFlow.refreshToken.mockResolvedValue({
success: false,
error: "Token refresh failed: 500 Internal Server Error",
errorCode: "server_error",
});

await vi.advanceTimersByTimeAsync(3_599_500);

await expect(service.getValidAccessToken()).resolves.toMatchObject({
accessToken: "current-access-token",
});
expect(oauthFlow.refreshToken).toHaveBeenCalledTimes(3);
expect(service.getState().status).toBe("authenticated");
} finally {
vi.useRealTimers();
}
});

it("does not use the current access token when refresh token auth fails", async () => {
vi.useFakeTimers();
try {
oauthFlow.startFlow.mockResolvedValue(
mockTokenResponse({
accessToken: "current-access-token",
refreshToken: "current-refresh-token",
}),
);
stubAuthFetch();

await service.initialize();
await service.login("us");

oauthFlow.refreshToken.mockReset();
oauthFlow.refreshToken.mockResolvedValue({
success: false,
error: "Token revoked",
errorCode: "auth_error",
});

await vi.advanceTimersByTimeAsync(3_599_500);

await expect(service.getValidAccessToken()).rejects.toThrow(
"Token revoked",
);
expect(service.getState().status).toBe("anonymous");
expect(sessionPort.getCurrent()).toBeNull();
} finally {
vi.useRealTimers();
}
});
});

describe("transient org fetch failures", () => {
Expand Down
29 changes: 25 additions & 4 deletions packages/core/src/auth/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -487,12 +487,13 @@ export class AuthService extends TypedEventEmitter<AuthServiceEvents> {
private async ensureValidSession(
forceRefresh = false,
): Promise<InMemorySession> {
const currentSession = this.session;
if (
this.session &&
currentSession &&
!forceRefresh &&
!this.isSessionExpiring(this.session)
!this.isSessionExpiring(currentSession)
) {
return this.session;
return currentSession;
}

if (this.refreshPromise) {
Expand All @@ -502,7 +503,24 @@ export class AuthService extends TypedEventEmitter<AuthServiceEvents> {
const sessionInput = this.getSessionInputForRefresh();

const refreshAndSync = async (): Promise<InMemorySession> => {
const session = await this.refreshSession(sessionInput);
let session: InMemorySession;
try {
session = await this.refreshSession(sessionInput);
} catch (error) {
if (
currentSession &&
this.session === currentSession &&
!forceRefresh &&
!this.isSessionExpired(currentSession)
) {
this.logger.warn(
"Preemptive session refresh failed; using current access token",
{ error },
);
return currentSession;
}
throw error;
}
await this.syncAuthenticatedSession(session);
return session;
};
Expand Down Expand Up @@ -833,6 +851,9 @@ export class AuthService extends TypedEventEmitter<AuthServiceEvents> {
private isSessionExpiring(session: InMemorySession): boolean {
return session.accessTokenExpiresAt - Date.now() <= TOKEN_EXPIRY_SKEW_MS;
}
private isSessionExpired(session: InMemorySession): boolean {
return session.accessTokenExpiresAt <= Date.now();
}
private async fetchUserContext(
accessToken: string,
cloudRegion: CloudRegion,
Expand Down
112 changes: 112 additions & 0 deletions packages/ui/src/features/mcp-servers/hooks/useMcpServers.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import type { McpRecommendedServer } from "@posthog/api-client/posthog-client";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { act, renderHook, waitFor } from "@testing-library/react";
import type { ReactNode } from "react";
import { beforeEach, describe, expect, it, vi } from "vitest";

const mockClient = vi.hoisted(() => ({
getMcpServerInstallations: vi.fn(),
getMcpServers: vi.fn(),
installMcpTemplate: vi.fn(),
installCustomMcpServer: vi.fn(),
uninstallMcpServer: vi.fn(),
updateMcpServerInstallation: vi.fn(),
authorizeMcpInstallation: vi.fn(),
}));

const mockTrpcClient = vi.hoisted(() => ({
mcpCallback: {
getCallbackUrl: { query: vi.fn() },
openAndWaitForCallback: { mutate: vi.fn() },
},
}));

const mockTrpc = vi.hoisted(() => ({
mcpCallback: {
onOAuthComplete: {
subscriptionOptions: vi.fn(() => ({})),
},
},
}));

vi.mock("@posthog/ui/features/auth/authClient", () => ({
useOptionalAuthenticatedClient: () => mockClient,
}));

vi.mock("@posthog/host-router/react", () => ({
useHostTRPC: () => mockTrpc,
useHostTRPCClient: () => mockTrpcClient,
}));

vi.mock("@trpc/tanstack-react-query", () => ({
useSubscription: vi.fn(),
}));

vi.mock("sonner", () => ({
toast: {
error: vi.fn(),
success: vi.fn(),
},
}));

import { useMcpServers } from "./useMcpServers";

function wrapper({ children }: { children: ReactNode }) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
mutations: { retry: false },
},
});
return (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
}

const template = {
id: "granola",
name: "Granola",
auth_type: "oauth",
} as McpRecommendedServer;

describe("useMcpServers", () => {
beforeEach(() => {
vi.clearAllMocks();
mockClient.getMcpServerInstallations.mockResolvedValue([]);
mockClient.getMcpServers.mockResolvedValue([]);
mockTrpcClient.mcpCallback.getCallbackUrl.query.mockResolvedValue({
callbackUrl: "posthog-code://mcp-oauth-complete",
});
});

it("reverts template connect loading state after a failed install", async () => {
let rejectInstall!: (error: Error) => void;
mockClient.installMcpTemplate.mockReturnValue(
new Promise((_resolve, reject) => {
rejectInstall = reject;
}),
);

const { result } = renderHook(() => useMcpServers(), { wrapper });

act(() => {
result.current.installTemplate(template);
});

await waitFor(() => expect(result.current.installingId).toBe("granola"));
await waitFor(() =>
expect(mockClient.installMcpTemplate).toHaveBeenCalledWith({
template_id: "granola",
install_source: "posthog-code",
posthog_code_callback_url: "posthog-code://mcp-oauth-complete",
api_key: undefined,
}),
);

await act(async () => {
rejectInstall(new Error("Connection failed"));
});

await waitFor(() => expect(result.current.installingId).toBeNull());
});
});
10 changes: 5 additions & 5 deletions packages/ui/src/features/mcp-servers/hooks/useMcpServers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { useAuthenticatedMutation } from "@posthog/ui/hooks/useAuthenticatedMuta
import { useAuthenticatedQuery } from "@posthog/ui/hooks/useAuthenticatedQuery";
import { useQueryClient } from "@tanstack/react-query";
import { useSubscription } from "@trpc/tanstack-react-query";
import { useCallback, useMemo, useState } from "react";
import { useCallback, useMemo } from "react";
import { toast } from "sonner";

export const mcpKeys = {
Expand All @@ -38,7 +38,6 @@ export function useMcpServers() {
const trpc = useHostTRPC();
const trpcClient = useHostTRPCClient();
const oauth = useMemo(() => createOAuthCallback(trpcClient), [trpcClient]);
const [installingId, setInstallingId] = useState<string | null>(null);
const queryClient = useQueryClient();

const { data: installations, isLoading: installationsLoading } =
Expand Down Expand Up @@ -120,18 +119,15 @@ export function useMcpServers() {
toast.error(data.error);
}
invalidateInstallations();
setInstallingId(null);
},
onError: (error: Error) => {
toast.error(error.message || "Failed to connect server");
setInstallingId(null);
},
},
);

const installTemplate = useCallback(
(template: McpRecommendedServer, opts?: { api_key?: string }) => {
setInstallingId(template.id);
installTemplateMutation.mutate({
template_id: template.id,
api_key: opts?.api_key,
Expand All @@ -140,6 +136,10 @@ export function useMcpServers() {
[installTemplateMutation],
);

const installingId = installTemplateMutation.isPending
? (installTemplateMutation.variables?.template_id ?? null)
: null;

const installCustomMutation = useAuthenticatedMutation(
(
client,
Expand Down
Loading