diff --git a/server/langchain-tools.js b/server/langchain-tools.js index 4cb77313..3ca1128d 100644 --- a/server/langchain-tools.js +++ b/server/langchain-tools.js @@ -2,9 +2,20 @@ import { DynamicStructuredTool } from '@langchain/core/tools'; import { z } from 'zod'; import axios from 'axios'; -const BASE_URL = 'http://localhost:3000/api'; +const BASE_URL = process.env.SWAPI_BASE_URL || 'http://localhost:3000/api'; -// Helper function to create detailed API call information +// Constants for search functionality +const SEARCH_COLORS = ['red', 'blue', 'green', 'yellow', 'brown', 'black', 'white', 'orange', 'purple', 'pink', 'blond', 'blonde', 'fair', 'dark', 'light']; + +/** + * Helper function to create detailed API call information for tracking and debugging + * @param {string} method - HTTP method used for the API call + * @param {string} url - The URL that was called + * @param {Object} params - Query parameters sent with the request + * @param {Object} data - Request body data (if any) + * @param {Error} error - Error object if the request failed + * @returns {Object} Structured API call information object + */ function createApiCallInfo(method, url, params, data, error) { return { method, @@ -16,7 +27,12 @@ function createApiCallInfo(method, url, params, data, error) { }; } -// Wrapper function to capture API call details +/** + * Wrapper function to capture API call details and handle errors consistently + * @param {string} url - The API endpoint URL to call + * @param {Object} params - Query parameters to include (default: {}) + * @returns {Promise} Object with success flag, data, and API call information + */ async function makeApiCall(url, params = {}) { const apiCallInfo = createApiCallInfo('GET', url, params); try { @@ -47,6 +63,14 @@ export const getCharacterTool = new DynamicStructuredTool({ id: z.number().describe('The ID of the character to retrieve') }), func: async ({ id }) => { + // Basic input validation + if (!Number.isInteger(id) || id <= 0) { + return JSON.stringify({ + error: `Invalid character ID: ${id}. ID must be a positive integer.`, + apiCall: createApiCallInfo('GET', `${BASE_URL}/characters/${id}`, {}, null, new Error('Invalid ID')) + }, null, 2); + } + const result = await makeApiCall(`${BASE_URL}/characters/${id}`); if (result.success) { // Include API call information in the response for educational purposes @@ -332,21 +356,21 @@ export const searchCharactersTool = new DynamicStructuredTool({ let matchedCharacters = []; if (attributeLower.includes('eye')) { - const color = attributeLower.match(/(red|blue|green|yellow|brown|black|white|orange|purple|pink)/)?.[1]; + const color = attributeLower.match(new RegExp(`(${SEARCH_COLORS.join('|')})`))?.[1]; if (color) { matchedCharacters = characters.filter(char => char.eye_color && char.eye_color.toLowerCase().includes(color) ); } } else if (attributeLower.includes('hair')) { - const color = attributeLower.match(/(red|blue|green|yellow|brown|black|white|orange|purple|pink|blond|blonde)/)?.[1]; + const color = attributeLower.match(new RegExp(`(${SEARCH_COLORS.join('|')})`))?.[1]; if (color) { matchedCharacters = characters.filter(char => char.hair_color && char.hair_color.toLowerCase().includes(color) ); } } else if (attributeLower.includes('skin')) { - const color = attributeLower.match(/(red|blue|green|yellow|brown|black|white|orange|purple|pink|fair|dark|light)/)?.[1]; + const color = attributeLower.match(new RegExp(`(${SEARCH_COLORS.join('|')})`))?.[1]; if (color) { matchedCharacters = characters.filter(char => char.skin_color && char.skin_color.toLowerCase().includes(color) @@ -402,7 +426,11 @@ export const swapiTools = [ searchCharactersTool ]; -// Helper function to extract entity IDs from vector search results +/** + * Helper function to extract entity IDs from vector search results + * @param {Array} context - Array of vector search result objects with metadata + * @returns {Object} Object containing arrays of entity IDs grouped by entity type + */ export function extractEntityIds(context) { const entityIds = { characters: [], diff --git a/server/tests/langchain-tools-validation.test.js b/server/tests/langchain-tools-validation.test.js new file mode 100644 index 00000000..6d580340 --- /dev/null +++ b/server/tests/langchain-tools-validation.test.js @@ -0,0 +1,42 @@ +import { jest } from '@jest/globals'; +import { getCharacterTool } from '../langchain-tools.js'; + +describe('LangChain Tools Input Validation', () => { + describe('getCharacterTool input validation', () => { + it('should reject invalid ID (negative number)', async () => { + const result = await getCharacterTool.func({ id: -1 }); + const parsed = JSON.parse(result); + + expect(parsed.error).toBeDefined(); + expect(parsed.error).toContain('Invalid character ID'); + expect(parsed.error).toContain('-1'); + }); + + it('should reject invalid ID (zero)', async () => { + const result = await getCharacterTool.func({ id: 0 }); + const parsed = JSON.parse(result); + + expect(parsed.error).toBeDefined(); + expect(parsed.error).toContain('Invalid character ID'); + expect(parsed.error).toContain('0'); + }); + + it('should reject invalid ID (non-integer)', async () => { + const result = await getCharacterTool.func({ id: 1.5 }); + const parsed = JSON.parse(result); + + expect(parsed.error).toBeDefined(); + expect(parsed.error).toContain('Invalid character ID'); + expect(parsed.error).toContain('1.5'); + }); + + it('should include API call info for invalid IDs', async () => { + const result = await getCharacterTool.func({ id: -1 }); + const parsed = JSON.parse(result); + + expect(parsed.apiCall).toBeDefined(); + expect(parsed.apiCall.error).toContain('Invalid ID'); + expect(parsed.apiCall.method).toBe('GET'); + }); + }); +}); \ No newline at end of file