Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions server/langchain-tools.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@ import { DynamicStructuredTool } from '@langchain/core/tools';
import { z } from 'zod';
import axios from 'axios';

const BASE_URL = 'http://localhost:3000/api';
const BASE_URL = process.env.SWAPI_BASE_URL || 'http://localhost:3000/api';

// Helper function to create detailed API call information
// Constants for search functionality
const SEARCH_COLORS = ['red', 'blue', 'green', 'yellow', 'brown', 'black', 'white', 'orange', 'purple', 'pink', 'blond', 'blonde', 'fair', 'dark', 'light'];

/**
* Helper function to create detailed API call information for tracking and debugging
* @param {string} method - HTTP method used for the API call
* @param {string} url - The URL that was called
* @param {Object} params - Query parameters sent with the request
* @param {Object} data - Request body data (if any)
* @param {Error} error - Error object if the request failed
* @returns {Object} Structured API call information object
*/
function createApiCallInfo(method, url, params, data, error) {
return {
method,
Expand All @@ -16,7 +27,12 @@ function createApiCallInfo(method, url, params, data, error) {
};
}

// Wrapper function to capture API call details
/**
* Wrapper function to capture API call details and handle errors consistently
* @param {string} url - The API endpoint URL to call
* @param {Object} params - Query parameters to include (default: {})
* @returns {Promise<Object>} Object with success flag, data, and API call information
*/
async function makeApiCall(url, params = {}) {
const apiCallInfo = createApiCallInfo('GET', url, params);
try {
Expand Down Expand Up @@ -47,6 +63,14 @@ export const getCharacterTool = new DynamicStructuredTool({
id: z.number().describe('The ID of the character to retrieve')
}),
func: async ({ id }) => {
// Basic input validation
if (!Number.isInteger(id) || id <= 0) {
return JSON.stringify({
error: `Invalid character ID: ${id}. ID must be a positive integer.`,
apiCall: createApiCallInfo('GET', `${BASE_URL}/characters/${id}`, {}, null, new Error('Invalid ID'))
}, null, 2);
}

const result = await makeApiCall(`${BASE_URL}/characters/${id}`);
if (result.success) {
// Include API call information in the response for educational purposes
Expand Down Expand Up @@ -332,21 +356,21 @@ export const searchCharactersTool = new DynamicStructuredTool({
let matchedCharacters = [];

if (attributeLower.includes('eye')) {
const color = attributeLower.match(/(red|blue|green|yellow|brown|black|white|orange|purple|pink)/)?.[1];
const color = attributeLower.match(new RegExp(`(${SEARCH_COLORS.join('|')})`))?.[1];
if (color) {
matchedCharacters = characters.filter(char =>
char.eye_color && char.eye_color.toLowerCase().includes(color)
);
}
} else if (attributeLower.includes('hair')) {
const color = attributeLower.match(/(red|blue|green|yellow|brown|black|white|orange|purple|pink|blond|blonde)/)?.[1];
const color = attributeLower.match(new RegExp(`(${SEARCH_COLORS.join('|')})`))?.[1];
if (color) {
matchedCharacters = characters.filter(char =>
char.hair_color && char.hair_color.toLowerCase().includes(color)
);
}
} else if (attributeLower.includes('skin')) {
const color = attributeLower.match(/(red|blue|green|yellow|brown|black|white|orange|purple|pink|fair|dark|light)/)?.[1];
const color = attributeLower.match(new RegExp(`(${SEARCH_COLORS.join('|')})`))?.[1];
if (color) {
matchedCharacters = characters.filter(char =>
char.skin_color && char.skin_color.toLowerCase().includes(color)
Expand Down Expand Up @@ -402,7 +426,11 @@ export const swapiTools = [
searchCharactersTool
];

// Helper function to extract entity IDs from vector search results
/**
* Helper function to extract entity IDs from vector search results
* @param {Array} context - Array of vector search result objects with metadata
* @returns {Object} Object containing arrays of entity IDs grouped by entity type
*/
export function extractEntityIds(context) {
const entityIds = {
characters: [],
Expand Down
42 changes: 42 additions & 0 deletions server/tests/langchain-tools-validation.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { jest } from '@jest/globals';
import { getCharacterTool } from '../langchain-tools.js';

describe('LangChain Tools Input Validation', () => {
describe('getCharacterTool input validation', () => {
it('should reject invalid ID (negative number)', async () => {
const result = await getCharacterTool.func({ id: -1 });
const parsed = JSON.parse(result);

expect(parsed.error).toBeDefined();
expect(parsed.error).toContain('Invalid character ID');
expect(parsed.error).toContain('-1');
});

it('should reject invalid ID (zero)', async () => {
const result = await getCharacterTool.func({ id: 0 });
const parsed = JSON.parse(result);

expect(parsed.error).toBeDefined();
expect(parsed.error).toContain('Invalid character ID');
expect(parsed.error).toContain('0');
});

it('should reject invalid ID (non-integer)', async () => {
const result = await getCharacterTool.func({ id: 1.5 });
const parsed = JSON.parse(result);

expect(parsed.error).toBeDefined();
expect(parsed.error).toContain('Invalid character ID');
expect(parsed.error).toContain('1.5');
});

it('should include API call info for invalid IDs', async () => {
const result = await getCharacterTool.func({ id: -1 });
const parsed = JSON.parse(result);

expect(parsed.apiCall).toBeDefined();
expect(parsed.apiCall.error).toContain('Invalid ID');
expect(parsed.apiCall.method).toBe('GET');
});
});
});