diff --git a/README.md b/README.md
index 74c42296..b66a84b5 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,7 @@ Create a `.env` file and add OpenCommit config variables there like this:
 OCO_AI_PROVIDER=
 OCO_API_KEY= // or other LLM provider API token
 OCO_API_URL=
+OCO_API_CUSTOM_HEADERS=
 OCO_TOKENS_MAX_INPUT=
 OCO_TOKENS_MAX_OUTPUT=
 OCO_DESCRIPTION=
diff --git a/src/commands/config.ts b/src/commands/config.ts
index 7e30cd59..e2dcb4f2 100644
--- a/src/commands/config.ts
+++ b/src/commands/config.ts
@@ -25,6 +25,7 @@ export enum CONFIG_KEYS {
   OCO_ONE_LINE_COMMIT = 'OCO_ONE_LINE_COMMIT',
   OCO_TEST_MOCK_TYPE = 'OCO_TEST_MOCK_TYPE',
   OCO_API_URL = 'OCO_API_URL',
+  OCO_API_CUSTOM_HEADERS = 'OCO_API_CUSTOM_HEADERS',
   OCO_OMIT_SCOPE = 'OCO_OMIT_SCOPE',
   OCO_GITPUSH = 'OCO_GITPUSH' // todo: deprecate
 }
@@ -204,6 +205,22 @@ export const configValidators = {
     return value;
   },
 
+  [CONFIG_KEYS.OCO_API_CUSTOM_HEADERS](value) {
+    try {
+      // Custom headers must be a valid JSON string
+      if (typeof value === 'string') {
+        JSON.parse(value);
+      }
+      return value;
+    } catch (error) {
+      validateConfig(
+        CONFIG_KEYS.OCO_API_CUSTOM_HEADERS,
+        false,
+        'Must be a valid JSON string of headers'
+      );
+    }
+  },
+
   [CONFIG_KEYS.OCO_TOKENS_MAX_INPUT](value: any) {
     value = parseInt(value);
     validateConfig(
@@ -380,6 +397,7 @@ export type ConfigType = {
   [CONFIG_KEYS.OCO_TOKENS_MAX_INPUT]: number;
   [CONFIG_KEYS.OCO_TOKENS_MAX_OUTPUT]: number;
   [CONFIG_KEYS.OCO_API_URL]?: string;
+  [CONFIG_KEYS.OCO_API_CUSTOM_HEADERS]?: string;
   [CONFIG_KEYS.OCO_DESCRIPTION]: boolean;
   [CONFIG_KEYS.OCO_EMOJI]: boolean;
   [CONFIG_KEYS.OCO_WHY]: boolean;
@@ -462,6 +480,7 @@ const getEnvConfig = (envPath: string) => {
     OCO_MODEL: process.env.OCO_MODEL,
     OCO_API_URL: process.env.OCO_API_URL,
     OCO_API_KEY: process.env.OCO_API_KEY,
+    OCO_API_CUSTOM_HEADERS: process.env.OCO_API_CUSTOM_HEADERS,
     OCO_AI_PROVIDER: process.env.OCO_AI_PROVIDER as OCO_AI_PROVIDER_ENUM,
 
     OCO_TOKENS_MAX_INPUT: parseConfigVarValue(process.env.OCO_TOKENS_MAX_INPUT),
diff --git a/src/engine/Engine.ts b/src/engine/Engine.ts
index 19562271..c5bd2e4b 100644
--- a/src/engine/Engine.ts
+++ b/src/engine/Engine.ts
@@ -11,6 +11,7 @@ export interface AiEngineConfig {
   maxTokensOutput: number;
   maxTokensInput: number;
   baseURL?: string;
+  customHeaders?: Record<string, string>;
 }
 
 type Client =
diff --git a/src/engine/ollama.ts b/src/engine/ollama.ts
index 2d21d637..7d0355b0 100644
--- a/src/engine/ollama.ts
+++ b/src/engine/ollama.ts
@@ -11,11 +11,18 @@ export class OllamaEngine implements AiEngine {
 
   constructor(config) {
     this.config = config;
+
+    // Combine base headers with custom headers
+    const headers = {
+      'Content-Type': 'application/json',
+      ...config.customHeaders
+    };
+
     this.client = axios.create({
       url: config.baseURL
         ? `${config.baseURL}/${config.apiKey}`
         : 'http://localhost:11434/api/chat',
-      headers: { 'Content-Type': 'application/json' }
+      headers
     });
   }
diff --git a/src/engine/openAi.ts b/src/engine/openAi.ts
index 4e1c6a99..22a9b37e 100644
--- a/src/engine/openAi.ts
+++ b/src/engine/openAi.ts
@@ -1,6 +1,7 @@
 import axios from 'axios';
 import { OpenAI } from 'openai';
 import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff';
+import { parseCustomHeaders } from '../utils/engine';
 import { removeContentTags } from '../utils/removeContentTags';
 import { tokenCount } from '../utils/tokenCount';
 import { AiEngine, AiEngineConfig } from './Engine';
@@ -14,11 +15,22 @@ export class OpenAiEngine implements AiEngine {
 
   constructor(config: OpenAiConfig) {
     this.config = config;
-    if (!config.baseURL) {
-      this.client = new OpenAI({ apiKey: config.apiKey });
-    } else {
-      this.client = new OpenAI({ apiKey: config.apiKey, baseURL: config.baseURL });
+    const clientOptions: OpenAI.ClientOptions = {
+      apiKey: config.apiKey
+    };
+
+    if (config.baseURL) {
+      clientOptions.baseURL = config.baseURL;
+    }
+
+    if (config.customHeaders) {
+      const headers = parseCustomHeaders(config.customHeaders);
+      if (Object.keys(headers).length > 0) {
+        clientOptions.defaultHeaders = headers;
+      }
     }
+
+    this.client = new OpenAI(clientOptions);
   }
 
   public generateCommitMessage = async (
@@ -42,7 +54,7 @@ export class OpenAiEngine implements AiEngine {
         this.config.maxTokensInput - this.config.maxTokensOutput
       )
         throw new Error(GenerateCommitMessageErrorEnum.tooMuchTokens);
-      
+
       const completion = await this.client.chat.completions.create(params);
 
       const message = completion.choices[0].message;
diff --git a/src/utils/engine.ts b/src/utils/engine.ts
index 3137a05f..dbc45a00 100644
--- a/src/utils/engine.ts
+++ b/src/utils/engine.ts
@@ -12,16 +12,39 @@ import { GroqEngine } from '../engine/groq';
 import { MLXEngine } from '../engine/mlx';
 import { DeepseekEngine } from '../engine/deepseek';
 
+export function parseCustomHeaders(headers: any): Record<string, string> {
+  let parsedHeaders = {};
+
+  if (!headers) {
+    return parsedHeaders;
+  }
+
+  try {
+    if (typeof headers === 'object' && !Array.isArray(headers)) {
+      parsedHeaders = headers;
+    } else {
+      parsedHeaders = JSON.parse(headers);
+    }
+  } catch (error) {
+    console.warn('Invalid OCO_API_CUSTOM_HEADERS format, ignoring custom headers');
+  }
+
+  return parsedHeaders;
+}
+
 export function getEngine(): AiEngine {
   const config = getConfig();
   const provider = config.OCO_AI_PROVIDER;
 
+  const customHeaders = parseCustomHeaders(config.OCO_API_CUSTOM_HEADERS);
+
   const DEFAULT_CONFIG = {
     model: config.OCO_MODEL!,
     maxTokensOutput: config.OCO_TOKENS_MAX_OUTPUT!,
     maxTokensInput: config.OCO_TOKENS_MAX_INPUT!,
     baseURL: config.OCO_API_URL!,
-    apiKey: config.OCO_API_KEY!
+    apiKey: config.OCO_API_KEY!,
+    customHeaders
   };
 
   switch (provider) {
diff --git a/test/unit/config.test.ts b/test/unit/config.test.ts
index 89ffc7e4..fc4709dc 100644
--- a/test/unit/config.test.ts
+++ b/test/unit/config.test.ts
@@ -122,6 +122,30 @@ describe('config', () => {
     expect(config.OCO_ONE_LINE_COMMIT).toEqual(false);
     expect(config.OCO_OMIT_SCOPE).toEqual(true);
   });
+
+  it('should handle custom HTTP headers correctly', async () => {
+    globalConfigFile = await generateConfig('.opencommit', {
+      OCO_API_CUSTOM_HEADERS: '{"X-Global-Header": "global-value"}'
+    });
+
+    envConfigFile = await generateConfig('.env', {
+      OCO_API_CUSTOM_HEADERS: '{"Authorization": "Bearer token123", "X-Custom-Header": "test-value"}'
+    });
+
+    const config = getConfig({
+      globalPath: globalConfigFile.filePath,
+      envPath: envConfigFile.filePath
+    });
+
+    expect(config).not.toEqual(null);
+    expect(config.OCO_API_CUSTOM_HEADERS).toEqual({"Authorization": "Bearer token123", "X-Custom-Header": "test-value"});
+
+    // No need to parse JSON again since it's already an object
+    const parsedHeaders = config.OCO_API_CUSTOM_HEADERS;
+    expect(parsedHeaders).toHaveProperty('Authorization', 'Bearer token123');
+    expect(parsedHeaders).toHaveProperty('X-Custom-Header', 'test-value');
+    expect(parsedHeaders).not.toHaveProperty('X-Global-Header');
+  });
 
   it('should handle empty local config correctly', async () => {
     globalConfigFile = await generateConfig('.opencommit', {
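For reference, the new `OCO_API_CUSTOM_HEADERS` variable is set the same way as the other `.env` entries in the README excerpt above; per the validator added in `src/commands/config.ts`, the value must be a JSON object serialized as a string. A minimal sketch of such a configuration, with an illustrative proxy URL and the header values used in the test above:

```env
OCO_AI_PROVIDER=openai
OCO_API_KEY= // or other LLM provider API token
OCO_API_URL=https://my-llm-proxy.example.com/v1
OCO_API_CUSTOM_HEADERS={"Authorization": "Bearer token123", "X-Custom-Header": "test-value"}
```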
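The `parseCustomHeaders` helper added in `src/utils/engine.ts` accepts either an already-parsed object or a JSON string and degrades to an empty object on malformed input, so a bad value means "no custom headers are applied" rather than a crash. A minimal TypeScript sketch of that behavior (the relative import path is assumed):

```ts
import { parseCustomHeaders } from './src/utils/engine';

// A JSON string, e.g. from .env, is parsed into a plain object.
parseCustomHeaders('{"X-Custom-Header": "test-value"}');
// -> { 'X-Custom-Header': 'test-value' }

// A value that is already an object passes through unchanged.
parseCustomHeaders({ 'X-Custom-Header': 'test-value' });
// -> { 'X-Custom-Header': 'test-value' }

// Malformed JSON logs a console warning and returns {}, so no headers are added.
parseCustomHeaders('not-json');
// -> {}
```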