From ca06236c89375d9992e4508e2dada9a365777fd9 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 10:30:37 +0200 Subject: [PATCH 01/14] . --- lib/getModel.js | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/getModel.js b/lib/getModel.js index af0dab7..974ee65 100644 --- a/lib/getModel.js +++ b/lib/getModel.js @@ -1,6 +1,7 @@ import cds from '@sap/cds' import fs from 'fs' import path from 'path' +import calculateEmbeddings from './calculateEmbeddings.js' cds.log.Logger = () => { return { From 5138e8f7c66c18b7afd45c1b13decb85408773a9 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 10:31:04 +0200 Subject: [PATCH 02/14] . --- lib/getModel.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/getModel.js b/lib/getModel.js index 974ee65..201b5f1 100644 --- a/lib/getModel.js +++ b/lib/getModel.js @@ -112,6 +112,10 @@ async function compileModel(projectPath) { } }, intervalMs).unref() // Uses CDS_MCP_REFRESH_MS if set, otherwise defaults to 10x compile duration or 20s } + for (const key in compiled.definitions) { + const def = compiled.definitions[key] + Object.defineProperty(def, 'embeddings', { value: await calculateEmbeddings(key), enumerable: false }) + } return compiled } From a5b0263f579d77b1e8bd5c387110d87adfa6edec Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 10:31:38 +0200 Subject: [PATCH 03/14] . --- lib/tools.js | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/tools.js b/lib/tools.js index a70a71c..ac266be 100644 --- a/lib/tools.js +++ b/lib/tools.js @@ -17,15 +17,15 @@ const tools = { }, handler: async ({ projectPath, name, kind, topN, namesOnly }) => { const model = await getModel(projectPath) - const defNames = kind + const defs = kind ? Object.entries(model.definitions) // eslint-disable-next-line no-unused-vars .filter(([_k, v]) => v.kind === kind) - .map(([k]) => k) - : Object.keys(model.definitions) - const scores = name ? fuzzyTopN(name, defNames, topN) : fuzzyTopN('', defNames, topN) - if (namesOnly) return scores.map(s => s.item) - return scores.map(s => model.definitions[s.item]) + .map(([_k, v]) => v) + : Object.values(model.definitions) + const results = (await searchEmbeddings(name, defs)).slice(0, topN) + if (namesOnly) return results.map(r => r.name) + return results } }, search_docs: { From e6b8f1072b3497a26fee295bfecf180d6be0f1a3 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 10:32:35 +0200 Subject: [PATCH 04/14] . --- lib/tools.js | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/tools.js b/lib/tools.js index ac266be..726cd33 100644 --- a/lib/tools.js +++ b/lib/tools.js @@ -2,6 +2,7 @@ import { z } from 'zod' import getModel from './getModel.js' import fuzzyTopN from './fuzzyTopN.js' import searchMarkdownDocs from './searchMarkdownDocs.js' +import { searchEmbeddings } from './embeddings.js' const tools = { search_model: { From 95740c8c92082d81d35a7d6263f8c77fe6e1a57b Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 10:43:33 +0200 Subject: [PATCH 05/14] . --- lib/calculateEmbeddings.js | 1 + tests/tools.test.js | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/calculateEmbeddings.js b/lib/calculateEmbeddings.js index 98f85c2..3c4a331 100644 --- a/lib/calculateEmbeddings.js +++ b/lib/calculateEmbeddings.js @@ -130,6 +130,7 @@ async function initializeModelAndVocab() { */ function normalizeText(text) { // Convert to NFD normalization (decomposed) + if (!text) return '' text = text.normalize('NFD') // Remove control characters except whitespace diff --git a/tests/tools.test.js b/tests/tools.test.js index 8babbe9..a648633 100644 --- a/tests/tools.test.js +++ b/tests/tools.test.js @@ -36,7 +36,7 @@ test.describe('tools', () => { // Entity endpoints const books = await tools.search_model.handler({ projectPath: sampleProjectPath, - name: 'Books', + name: 'AdminService.Books', kind: 'entity', topN: 2 }) From 9bea98c6a44fcae9d1d1ac8575af588751b81337 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 11:06:13 +0200 Subject: [PATCH 06/14] README --- README.md | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 14e2657..d1aa074 100644 --- a/README.md +++ b/README.md @@ -118,24 +118,26 @@ cds-mcp search_docs "how to add columns to a select statement in CAP Node.js" 1 The server provides these tools for CAP development: +### Embeddings Search Technology + +Both tools leverage vector embeddings for intelligent search capabilities. This process works as follows: + +1. **Query processing:** Your search query is converted to an embedding vector. +2. **Similarity search:** The system finds content with the highest semantic similarity to your query. + +This semantic search approach enables you to find relevant content even when your query does not use the exact keywords, all locally on your machine. + ### `search_model` -This tool performs fuzzy searches against names of definitions from the compiled CDS model (Core Schema Notation). +This tool searches against definition names from the compiled CDS model (Core Schema Notation). CDS compiles all your `.cds` files into a unified model representation that includes: - All definitions and their relationships - Annotations - HTTP endpoints -The fuzzy search algorithm matches definition names and allows for partial matches, making it easy to find entities like "Books" even when searching for "book". - ### `search_docs` -This tool uses vector embeddings to locally search through preprocessed CAP documentation, stored as embeddings. The process works as follows: - -1. **Query processing:** Your search query is converted to an embedding vector. -2. **Similarity search:** The system finds documentation chunks with the highest semantic similarity to your query. - -This semantic search approach enables you to find relevant documentation even when your query does not use the exact keywords found in the docs, all locally on your machine. +This tool searches through preprocessed CAP documentation from capire (static), stored as embeddings. The embeddings are created from documentation chunks with a focus on code snippets. ## Support, Feedback, Contributing From 0e6aa6ad559044e24884d296bf3d888edc555212 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 11:12:41 +0200 Subject: [PATCH 07/14] . --- lib/embeddings.js | 11 +++++------ lib/fuzzyTopN.js | 35 ----------------------------------- lib/tools.js | 1 - tests/tools.test.js | 2 +- 4 files changed, 6 insertions(+), 43 deletions(-) delete mode 100644 lib/fuzzyTopN.js diff --git a/lib/embeddings.js b/lib/embeddings.js index dda0f98..1729eeb 100644 --- a/lib/embeddings.js +++ b/lib/embeddings.js @@ -106,13 +106,12 @@ export async function getEmbeddings(text) { export async function searchEmbeddings(query, chunks) { const search = await getEmbeddings(query) // Compute similarity for all chunks - const scoredChunks = chunks.map(chunk => ({ - ...chunk, - similarity: cosineSimilarity(search, chunk.embeddings) - })) + chunks.forEach(chunk => { + chunk.similarity = cosineSimilarity(search, chunk.embeddings) + }) // Sort by similarity descending - scoredChunks.sort((a, b) => b.similarity - a.similarity) - return scoredChunks + chunks.sort((a, b) => b.similarity - a.similarity) + return chunks } // Only to be used in scripts, not in production diff --git a/lib/fuzzyTopN.js b/lib/fuzzyTopN.js deleted file mode 100644 index 44f0f6e..0000000 --- a/lib/fuzzyTopN.js +++ /dev/null @@ -1,35 +0,0 @@ -export default function fuzzyTopN(searchTerm, list, n, min) { - function modifiedLevenshtein(a, b) { - const m = a.length - const n = b.length - const matrix = Array.from({ length: m + 1 }, () => Array(n + 1).fill(0)) - - for (let i = 0; i <= m; i++) matrix[i][0] = i * 0.5 - for (let j = 0; j <= n; j++) matrix[0][j] = j * 0.5 - - for (let i = 1; i <= m; i++) { - for (let j = 1; j <= n; j++) { - const cost = a[i - 1] === b[j - 1] ? 0 : 1 - matrix[i][j] = Math.min( - matrix[i - 1][j] + 0.5, // deletion - matrix[i][j - 1] + 0.5, // insertion - matrix[i - 1][j - 1] + cost // substitution - ) - } - } - - return matrix[m][n] - } - - function score(term, content) { - term = term.toLowerCase() - content = content.toLowerCase() - const distance = modifiedLevenshtein(term, content) - const maxLength = Math.max(term.length, content.length) - return maxLength === 0 ? 1 : 1 - distance / maxLength - } - - let result = list.map(item => ({ item, score: score(searchTerm, item) })) - if (min) result = result.filter(entry => entry.score >= min) - return result.sort((a, b) => b.score - a.score).slice(0, n) -} diff --git a/lib/tools.js b/lib/tools.js index 726cd33..5fd2d6e 100644 --- a/lib/tools.js +++ b/lib/tools.js @@ -1,6 +1,5 @@ import { z } from 'zod' import getModel from './getModel.js' -import fuzzyTopN from './fuzzyTopN.js' import searchMarkdownDocs from './searchMarkdownDocs.js' import { searchEmbeddings } from './embeddings.js' diff --git a/tests/tools.test.js b/tests/tools.test.js index a648633..7556e30 100644 --- a/tests/tools.test.js +++ b/tests/tools.test.js @@ -48,7 +48,7 @@ test.describe('tools', () => { test('search_model: fuzzy search for Books entity', async () => { const books = await tools.search_model.handler({ projectPath: sampleProjectPath, - name: 'Books', + name: 'AmiSrvice.Books', // intentional typo kind: 'entity', topN: 2 }) From 86660b6f1de1b6e38d9068239f8716018b6810c0 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 11:16:14 +0200 Subject: [PATCH 08/14] . --- lib/tools.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/tools.js b/lib/tools.js index 5fd2d6e..6382b0a 100644 --- a/lib/tools.js +++ b/lib/tools.js @@ -31,7 +31,7 @@ const tools = { search_docs: { title: 'Search in CAP Documentation', description: - "Searches code snippets of CAP documentation for the given query. You MUST use this tool if you're unsure about CAP APIs for CDS, Node.js or Java. Optionally returns only code blocks.", + "Searches code snippets of CAP documentation for the given query. You MUST use this tool if you're unsure about CAP APIs for CDS, Node.js or Java.", inputSchema: { query: z.string().describe('Search string'), maxResults: z.number().default(10).describe('Maximum number of results') From 52f01ce31ad38130b12e3b6783e03593b1f0db93 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 11:21:39 +0200 Subject: [PATCH 09/14] restructure --- CHANGELOG.md | 6 ++++++ CONTRIBUTING.md | 14 +++++++------- README.md | 23 ++++------------------- lib/searchModel.js | 10 ++++++++++ lib/tools.js | 14 ++------------ 5 files changed, 29 insertions(+), 38 deletions(-) create mode 100644 lib/searchModel.js diff --git a/CHANGELOG.md b/CHANGELOG.md index 6207bfc..d95afc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). The format is based on [Keep a Changelog](http://keepachangelog.com/). +## Version 0.0.4 - 2025-09-26 + +### Changed + +- Embeddings-based model search + ## Version 0.0.3 - 2025-09-22 ### Changed diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index aed3868..44d5e8a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,11 +10,11 @@ Instances of abusive, harassing, or otherwise unacceptable behavior may be repor We use GitHub to manage reviews of pull requests. -* If you are a new contributor, see: [Steps to Contribute](#steps-to-contribute) +- If you are a new contributor, see: [Steps to Contribute](#steps-to-contribute) -* Before implementing your change, create an issue that describes the problem you would like to solve or the code that should be enhanced. Please note that you are willing to work on that issue. +- Before implementing your change, create an issue that describes the problem you would like to solve or the code that should be enhanced. Please note that you are willing to work on that issue. -* The team will review the issue and decide whether it should be implemented as a pull request. In that case, they will assign the issue to you. If the team decides against picking up the issue, the team will post a comment with an explanation. +- The team will review the issue and decide whether it should be implemented as a pull request. In that case, they will assign the issue to you. If the team decides against picking up the issue, the team will post a comment with an explanation. ## Steps to Contribute @@ -28,11 +28,11 @@ You are welcome to contribute code in order to fix a bug or to implement a new f The following rule governs code contributions: -* Contributions must be licensed under the [Apache 2.0 License](./LICENSE) -* Due to legal reasons, contributors will be asked to accept a Developer Certificate of Origin (DCO) when they create the first pull request to this project. This happens in an automated fashion during the submission process. SAP uses [the standard DCO text of the Linux Foundation](https://developercertificate.org/). +- Contributions must be licensed under the [Apache 2.0 License](./LICENSE) +- Due to legal reasons, contributors will be asked to accept a Developer Certificate of Origin (DCO) when they create the first pull request to this project. This happens in an automated fashion during the submission process. SAP uses [the standard DCO text of the Linux Foundation](https://developercertificate.org/). ## Issues and Planning -* We use GitHub issues to track bugs and enhancement requests. +- We use GitHub issues to track bugs and enhancement requests. -* Please provide as much context as possible when you open an issue. The information you provide must be comprehensive enough to reproduce that issue for the assignee. +- Please provide as much context as possible when you open an issue. The information you provide must be comprehensive enough to reproduce that issue for the assignee. diff --git a/README.md b/README.md index d1aa074..c8ae572 100644 --- a/README.md +++ b/README.md @@ -2,25 +2,20 @@ [![REUSE status](https://api.reuse.software/badge/github.com/cap-js/mcp-server)](https://api.reuse.software/info/github.com/cap-js/mcp-server) - - > [!NOTE] > This project is in alpha state. - - ## About This Project A Model Context Protocol (MCP) server for the [SAP Cloud Application Programming Model (CAP)](https://cap.cloud.sap). Use it for AI-assisted development of CAP applications (_agentic coding_). The server helps AI models answer questions such as: + - _Which CDS services are in this project, and where are they served?_ - _What are the entities about and how do they relate?_ - _How do I add columns to a select statement in CAP Node.js?_ - - ## Table of Contents - [About This Project](#about-this-project) @@ -38,14 +33,10 @@ The server helps AI models answer questions such as: - [Licensing](#licensing) - [Acknowledgments](#acknowledgments) - - ## Requirements See [Getting Started](https://cap.cloud.sap/docs/get-started) on how to jumpstart your development and grow as you go with SAP Cloud Application Programming Model. - - ## Setup ```sh @@ -59,6 +50,7 @@ Configure your MCP client (Cline, opencode, Claude Code, GitHub Copilot, etc.) t ### Usage in VS Code Example for VS Code extension [Cline](https://marketplace.visualstudio.com/items?itemName=saoudrizwan.claude-dev): + ```json { "mcpServers": { @@ -76,6 +68,7 @@ See [VS Code Marketplace](https://marketplace.visualstudio.com/search?term=tag%3 ### Usage in opencode Example for [opencode](https://github.com/sst/opencode): + ```json { "mcp": { @@ -131,6 +124,7 @@ This semantic search approach enables you to find relevant content even when you This tool searches against definition names from the compiled CDS model (Core Schema Notation). CDS compiles all your `.cds` files into a unified model representation that includes: + - All definitions and their relationships - Annotations - HTTP endpoints @@ -139,31 +133,22 @@ CDS compiles all your `.cds` files into a unified model representation that incl This tool searches through preprocessed CAP documentation from capire (static), stored as embeddings. The embeddings are created from documentation chunks with a focus on code snippets. - ## Support, Feedback, Contributing This project is open to feature requests/suggestions, bug reports, and so on, via [GitHub issues](https://github.com/cap-js/mcp-server/issues). Contribution and feedback are encouraged and always welcome. For more information about how to contribute, the project structure, as well as additional contribution information, see our [Contribution Guidelines](CONTRIBUTING.md). - - ## Security / Disclosure If you find any bug that may be a security problem, please follow our instructions at [in our security policy](https://github.com/cap-js/mcp-server/security/policy) on how to report it. Please don't create GitHub issues for security-related doubts or problems. - - ## Code of Conduct We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone. By participating in this project, you agree to abide by its [Code of Conduct](https://github.com/cap-js/.github/blob/main/CODE_OF_CONDUCT.md) at all times. - - ## Licensing Copyright 2025 SAP SE or an SAP affiliate company and @cap-js/cds-mcp contributors. Please see our [LICENSE](LICENSE) for copyright and license information. Detailed information including third-party components and their licensing/copyright information is available [via the REUSE tool](https://api.reuse.software/info/github.com/cap-js/mcp-server). - - ## Acknowledgments - **onnxruntime-web** is used for creating embeddings locally. diff --git a/lib/searchModel.js b/lib/searchModel.js new file mode 100644 index 0000000..ff329a3 --- /dev/null +++ b/lib/searchModel.js @@ -0,0 +1,10 @@ +import getModel from './getModel.js' +import { searchEmbeddings } from './embeddings.js' + +export default async function searchModel(projectPath, name, kind, topN, namesOnly) { + const model = await getModel(projectPath) + const defs = kind ? Object.values(model.definitions).filter(v => v.kind === kind) : Object.values(model.definitions) + const results = (await searchEmbeddings(name, defs)).slice(0, topN) + if (namesOnly) return results.map(r => r.name) + return results +} diff --git a/lib/tools.js b/lib/tools.js index 6382b0a..633bb19 100644 --- a/lib/tools.js +++ b/lib/tools.js @@ -1,7 +1,6 @@ import { z } from 'zod' -import getModel from './getModel.js' import searchMarkdownDocs from './searchMarkdownDocs.js' -import { searchEmbeddings } from './embeddings.js' +import searchModel from './searchModel.js' const tools = { search_model: { @@ -16,16 +15,7 @@ const tools = { namesOnly: z.boolean().default(false).describe('If true, only return definition names (for overview)') }, handler: async ({ projectPath, name, kind, topN, namesOnly }) => { - const model = await getModel(projectPath) - const defs = kind - ? Object.entries(model.definitions) - // eslint-disable-next-line no-unused-vars - .filter(([_k, v]) => v.kind === kind) - .map(([_k, v]) => v) - : Object.values(model.definitions) - const results = (await searchEmbeddings(name, defs)).slice(0, topN) - if (namesOnly) return results.map(r => r.name) - return results + return await searchModel(projectPath, name, kind, topN, namesOnly) } }, search_docs: { From d40cb213eea1f772ae4a18edb63a91dea484deb5 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 11:24:07 +0200 Subject: [PATCH 10/14] . --- lib/searchModel.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/searchModel.js b/lib/searchModel.js index ff329a3..8fe639f 100644 --- a/lib/searchModel.js +++ b/lib/searchModel.js @@ -4,7 +4,7 @@ import { searchEmbeddings } from './embeddings.js' export default async function searchModel(projectPath, name, kind, topN, namesOnly) { const model = await getModel(projectPath) const defs = kind ? Object.values(model.definitions).filter(v => v.kind === kind) : Object.values(model.definitions) - const results = (await searchEmbeddings(name, defs)).slice(0, topN) + const results = name?.length ? (await searchEmbeddings(name, defs)).slice(0, topN) : defs.slice(0, topN) if (namesOnly) return results.map(r => r.name) return results } From bd712176ac37fda6f210b156128ed66838d3110d Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 11:25:21 +0200 Subject: [PATCH 11/14] . --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c8ae572..9191843 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ CDS compiles all your `.cds` files into a unified model representation that incl ### `search_docs` -This tool searches through preprocessed CAP documentation from capire (static), stored as embeddings. The embeddings are created from documentation chunks with a focus on code snippets. +This tool searches through preprocessed CAP documentation from capire with a focus on code snippets, stored as embeddings. ## Support, Feedback, Contributing From 88f67f54b5d4c770090042a0562755dbb5446938 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 15:25:47 +0200 Subject: [PATCH 12/14] disable logs --- CHANGELOG.md | 2 +- lib/calculateEmbeddings.js | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d95afc9..69295a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). The format is based on [Keep a Changelog](http://keepachangelog.com/). -## Version 0.0.4 - 2025-09-26 +## Version 0.0.4 - ? ### Changed diff --git a/lib/calculateEmbeddings.js b/lib/calculateEmbeddings.js index 3c4a331..08f57bb 100644 --- a/lib/calculateEmbeddings.js +++ b/lib/calculateEmbeddings.js @@ -4,6 +4,9 @@ import path from 'path' import { fileURLToPath } from 'url' import * as ort from 'onnxruntime-web' +ort.env.debug = false +ort.env.logLevel = 'error' + const __dirname = path.dirname(fileURLToPath(import.meta.url)) const MODEL_NAME = 'Xenova/all-MiniLM-L6-v2' From fc7ec6c6662a33fb2f81299b8afd6029428c0592 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 16:38:17 +0200 Subject: [PATCH 13/14] batch --- lib/calculateEmbeddings.js | 208 ++++++++++++++++--------- lib/embeddings.js | 11 +- tests/calculateEmbeddingsBatch.test.js | 167 ++++++++++++++++++++ 3 files changed, 302 insertions(+), 84 deletions(-) create mode 100644 tests/calculateEmbeddingsBatch.test.js diff --git a/lib/calculateEmbeddings.js b/lib/calculateEmbeddings.js index 08f57bb..5d14918 100644 --- a/lib/calculateEmbeddings.js +++ b/lib/calculateEmbeddings.js @@ -328,100 +328,138 @@ function wordPieceTokenizer(text, vocab, maxLength = 512) { return chunks } +let session = null +let vocab = null + +// Start downloading and initializing model when module loads +const modelInitPromise = (async () => { + try { + await downloadModelIfNeeded() + await initializeModelAndVocab() + } catch { + // Don't throw here - let the main function handle initialization + } +})() + +export function resetSession() { + session = null + vocab = null +} + /** - * Process embeddings for multiple chunks and combine them + * Process multiple texts in a single batch inference call */ -async function processChunkedEmbeddings(chunks, session) { - const embeddings = [] - - for (const chunk of chunks) { - const { ids } = chunk +async function processBatchEmbeddings(batchTokenData, session) { + const { inputIds, attentionMask, tokenTypeIds, batchSize, maxSeqLength, hiddenSize } = batchTokenData - // ONNX Runtime input tensors must be int64 (BigInt64Array) - // Add validation for token IDs before converting to BigInt - const validIds = ids.filter(id => { - const isValid = typeof id === 'number' && !isNaN(id) && isFinite(id) - if (!isValid) { - throw new Error(`Invalid token ID detected: ${id} (type: ${typeof id})`) - } - return isValid - }) + const inputTensor = new ort.Tensor('int64', inputIds, [batchSize, maxSeqLength]) + const attentionTensor = new ort.Tensor('int64', attentionMask, [batchSize, maxSeqLength]) + const tokenTypeTensor = new ort.Tensor('int64', tokenTypeIds, [batchSize, maxSeqLength]) - if (validIds.length !== ids.length) { - throw new Error(`Found ${ids.length - validIds.length} invalid token IDs`) - } + const feeds = { + input_ids: inputTensor, + attention_mask: attentionTensor, + token_type_ids: tokenTypeTensor + } - const inputIds = new BigInt64Array(validIds.map(i => BigInt(i))) - const attentionMask = new BigInt64Array(validIds.length).fill(BigInt(1)) - const tokenTypeIds = new BigInt64Array(validIds.length).fill(BigInt(0)) + const results = await session.run(feeds) + const lastHiddenState = results['last_hidden_state'] + const embeddingData = lastHiddenState.data - const inputTensor = new ort.Tensor('int64', inputIds, [1, validIds.length]) - const attentionTensor = new ort.Tensor('int64', attentionMask, [1, validIds.length]) - const tokenTypeTensor = new ort.Tensor('int64', tokenTypeIds, [1, validIds.length]) + // Extract embeddings for each item in batch + const embeddings = [] + for (let batchIdx = 0; batchIdx < batchSize; batchIdx++) { + const pooledEmbedding = new Float32Array(hiddenSize) - const feeds = { - input_ids: inputTensor, - attention_mask: attentionTensor, - token_type_ids: tokenTypeTensor + // Calculate valid sequence length for this batch item (excluding padding) + let validSeqLength = 0 + for (let seqIdx = 0; seqIdx < maxSeqLength; seqIdx++) { + if (attentionMask[batchIdx * maxSeqLength + seqIdx] === BigInt(1)) { + validSeqLength++ + } } - const results = await session.run(feeds) - const lastHiddenState = results['last_hidden_state'] - const [, sequenceLength, hiddenSize] = lastHiddenState.dims - const embeddingData = lastHiddenState.data - - // Apply mean pooling across the sequence dimension - const pooledEmbedding = new Float32Array(hiddenSize) - for (let i = 0; i < hiddenSize; i++) { + // Apply mean pooling across the valid sequence dimension + for (let hiddenIdx = 0; hiddenIdx < hiddenSize; hiddenIdx++) { let sum = 0 - for (let j = 0; j < sequenceLength; j++) { - sum += embeddingData[j * hiddenSize + i] + for (let seqIdx = 0; seqIdx < validSeqLength; seqIdx++) { + const dataIdx = batchIdx * maxSeqLength * hiddenSize + seqIdx * hiddenSize + hiddenIdx + sum += embeddingData[dataIdx] } - pooledEmbedding[i] = sum / sequenceLength + pooledEmbedding[hiddenIdx] = sum / validSeqLength } embeddings.push(pooledEmbedding) } - // If multiple chunks, average the embeddings - if (embeddings.length === 1) { - return embeddings[0] - } - - const hiddenSize = embeddings[0].length - const avgEmbedding = new Float32Array(hiddenSize) + return embeddings +} - // Average across all chunks - for (let i = 0; i < hiddenSize; i++) { - let sum = 0 - for (const embedding of embeddings) { - sum += embedding[i] +/** + * Prepare batch data for inference - handles padding and creates tensors + */ +function prepareBatchTokenData(allChunks) { + // Find the maximum sequence length across all chunks + let maxSeqLength = 0 + for (const chunks of allChunks) { + for (const chunk of chunks) { + maxSeqLength = Math.max(maxSeqLength, chunk.ids.length) } - avgEmbedding[i] = sum / embeddings.length } - return avgEmbedding -} + const batchSize = allChunks.length + const hiddenSize = 384 // MiniLM-L6-v2 hidden size -let session = null -let vocab = null + // Pre-allocate arrays for batch data + const inputIds = new BigInt64Array(batchSize * maxSeqLength) + const attentionMask = new BigInt64Array(batchSize * maxSeqLength) + const tokenTypeIds = new BigInt64Array(batchSize * maxSeqLength) -// Start downloading and initializing model when module loads -const modelInitPromise = (async () => { - try { - await downloadModelIfNeeded() - await initializeModelAndVocab() - } catch { - // Don't throw here - let the main function handle initialization + // Fill batch data + for (let batchIdx = 0; batchIdx < batchSize; batchIdx++) { + const chunks = allChunks[batchIdx] + + // For now, just use the first chunk (most texts will be single chunk) + // TODO: Handle multi-chunk texts properly + const chunk = chunks[0] + const ids = chunk.ids + + const baseOffset = batchIdx * maxSeqLength + + // Fill actual token data + for (let seqIdx = 0; seqIdx < ids.length && seqIdx < maxSeqLength; seqIdx++) { + const id = ids[seqIdx] + if (typeof id !== 'number' || isNaN(id) || !isFinite(id)) { + throw new Error(`Invalid token ID: ${id}`) + } + + inputIds[baseOffset + seqIdx] = BigInt(id) + attentionMask[baseOffset + seqIdx] = BigInt(1) + tokenTypeIds[baseOffset + seqIdx] = BigInt(0) + } + + // Padding is already zero-filled (BigInt64Array defaults to 0) + // Attention mask for padding positions remains 0 } -})() -export function resetSession() { - session = null - vocab = null + return { + inputIds, + attentionMask, + tokenTypeIds, + batchSize, + maxSeqLength, + hiddenSize + } } -export default async function calculateEmbeddings(text) { +/** + * Batch processing function for multiple texts + */ +export async function calculateEmbeddingsBatch(texts) { + if (!Array.isArray(texts) || texts.length === 0) { + throw new Error('Input must be a non-empty array of strings') + } + // Wait for the model to be preloaded, then ensure it's initialized await modelInitPromise @@ -429,7 +467,16 @@ export default async function calculateEmbeddings(text) { await initializeModelAndVocab() } - const chunks = wordPieceTokenizer(text, vocab) + // Tokenize all texts in parallel + const allChunks = await Promise.all(texts.map(text => Promise.resolve(wordPieceTokenizer(text, vocab)))) + + // Check for multi-chunk texts (not fully supported yet) + const hasMultiChunk = allChunks.some(chunks => chunks.length > 1) + if (hasMultiChunk) { + // Fall back to individual processing for multi-chunk texts + console.warn('Multi-chunk texts detected, falling back to individual processing') + return Promise.all(texts.map(text => calculateEmbeddings(text))) + } function normalizeEmbedding(embedding) { let norm = 0 @@ -446,16 +493,25 @@ export default async function calculateEmbeddings(text) { } try { - const pooledEmbedding = await processChunkedEmbeddings(chunks, session) - return normalizeEmbedding(pooledEmbedding) - } catch { - // If inference fails, it might be due to model corruption - // Try to recover by re-downloading and reinitializing + const batchTokenData = prepareBatchTokenData(allChunks) + const embeddings = await processBatchEmbeddings(batchTokenData, session) + + // Normalize all embeddings + return embeddings.map(embedding => normalizeEmbedding(embedding)) + } catch (error) { + // If inference fails, try to recover by re-downloading and reinitializing + console.warn('Batch inference failed, attempting recovery:', error.message) await forceRedownloadModel() await initializeModelAndVocab() - const retryPooledEmbedding = await processChunkedEmbeddings(chunks, session) - return normalizeEmbedding(retryPooledEmbedding) + const batchTokenData = prepareBatchTokenData(allChunks) + const retryEmbeddings = await processBatchEmbeddings(batchTokenData, session) + return retryEmbeddings.map(embedding => normalizeEmbedding(embedding)) } } + +export default async function calculateEmbeddings(text) { + const result = await calculateEmbeddingsBatch([text]) + return result[0] +} diff --git a/lib/embeddings.js b/lib/embeddings.js index 1729eeb..9794e98 100644 --- a/lib/embeddings.js +++ b/lib/embeddings.js @@ -1,7 +1,7 @@ import fs from 'fs/promises' import path from 'path' import { fileURLToPath } from 'url' -import calculateEmbeddings from './calculateEmbeddings.js' +import calculateEmbeddings, { calculateEmbeddingsBatch } from './calculateEmbeddings.js' const __dirname = path.dirname(fileURLToPath(import.meta.url)) export async function loadChunks(id, dir = path.join(__dirname, '..', 'embeddings')) { @@ -116,13 +116,8 @@ export async function searchEmbeddings(query, chunks) { // Only to be used in scripts, not in production export async function createEmbeddings(id, chunks, dir = path.join(__dirname, '..', 'embeddings')) { - const embeddings = [] - - for (let i = 0; i < chunks.length; i++) { - const embedding = await getEmbeddings(chunks[i]) - embeddings.push(embedding) - } - + // Use batch processing for better performance + const embeddings = await calculateEmbeddingsBatch(chunks) await saveEmbeddings(id, chunks, embeddings, dir) } diff --git a/tests/calculateEmbeddingsBatch.test.js b/tests/calculateEmbeddingsBatch.test.js new file mode 100644 index 0000000..e4c6a4a --- /dev/null +++ b/tests/calculateEmbeddingsBatch.test.js @@ -0,0 +1,167 @@ +import { test, describe } from 'node:test' +import assert from 'node:assert' +import calculateEmbeddings, { calculateEmbeddingsBatch } from '../lib/calculateEmbeddings.js' + +function arraysAlmostEqual(arr1, arr2, tolerance = 1e-6) { + if (arr1.length !== arr2.length) return false + + for (let i = 0; i < arr1.length; i++) { + if (Math.abs(arr1[i] - arr2[i]) > tolerance) { + return false + } + } + return true +} + +describe('calculateEmbeddingsBatch', () => { + test('should produce same results as individual calls for simple texts', async () => { + const texts = ['hello world', 'goodbye world', 'test string'] + + // Get individual embeddings + const individualEmbeddings = await Promise.all(texts.map(text => calculateEmbeddings(text))) + + // Get batch embeddings + const batchEmbeddings = await calculateEmbeddingsBatch(texts) + + // Verify same number of results + assert.strictEqual(batchEmbeddings.length, individualEmbeddings.length) + assert.strictEqual(batchEmbeddings.length, texts.length) + + // Verify each embedding matches + for (let i = 0; i < texts.length; i++) { + assert.ok( + arraysAlmostEqual(individualEmbeddings[i], batchEmbeddings[i]), + `Embedding ${i} for text "${texts[i]}" does not match between individual and batch processing` + ) + } + }) + + test('should handle single text input', async () => { + const text = 'single test string' + + const individual = await calculateEmbeddings(text) + const batch = await calculateEmbeddingsBatch([text]) + + assert.strictEqual(batch.length, 1) + assert.ok(arraysAlmostEqual(individual, batch[0]), 'Single text batch result should match individual result') + }) + + test('should handle empty array input', async () => { + await assert.rejects(() => calculateEmbeddingsBatch([]), /Input must be a non-empty array of strings/) + }) + + test('should handle different length texts', async () => { + const texts = ['short', 'this is a medium length sentence with some words', 'a'] + + const individualEmbeddings = await Promise.all(texts.map(text => calculateEmbeddings(text))) + + const batchEmbeddings = await calculateEmbeddingsBatch(texts) + + assert.strictEqual(batchEmbeddings.length, texts.length) + + for (let i = 0; i < texts.length; i++) { + assert.ok( + arraysAlmostEqual(individualEmbeddings[i], batchEmbeddings[i]), + `Variable length embedding ${i} does not match` + ) + } + }) + + test('should produce normalized embeddings', async () => { + const texts = ['test vector normalization', 'another test vector'] + const embeddings = await calculateEmbeddingsBatch(texts) + + for (let i = 0; i < embeddings.length; i++) { + // Calculate L2 norm + let norm = 0 + for (let j = 0; j < embeddings[i].length; j++) { + norm += embeddings[i][j] * embeddings[i][j] + } + norm = Math.sqrt(norm) + + // Should be approximately 1.0 (normalized) + assert.ok(Math.abs(norm - 1.0) < 1e-6, `Embedding ${i} is not normalized (norm: ${norm})`) + } + }) + + test('should handle special characters and punctuation', async () => { + const texts = ['Hello, world!', 'Test with "quotes" and symbols: @#$%', 'Unicode: café, naïve, résumé'] + + const individualEmbeddings = await Promise.all(texts.map(text => calculateEmbeddings(text))) + + const batchEmbeddings = await calculateEmbeddingsBatch(texts) + + for (let i = 0; i < texts.length; i++) { + assert.ok( + arraysAlmostEqual(individualEmbeddings[i], batchEmbeddings[i]), + `Special character embedding ${i} does not match` + ) + } + }) + + test('should be faster than individual processing for multiple texts', async () => { + const texts = Array.from({ length: 10 }, (_, i) => `test string number ${i} with some content`) + + // Time individual processing + const startIndividual = process.hrtime.bigint() + await Promise.all(texts.map(text => calculateEmbeddings(text))) + const endIndividual = process.hrtime.bigint() + const individualTime = Number(endIndividual - startIndividual) / 1e6 + + // Time batch processing + const startBatch = process.hrtime.bigint() + await calculateEmbeddingsBatch(texts) + const endBatch = process.hrtime.bigint() + const batchTime = Number(endBatch - startBatch) / 1e6 + + console.log(`Individual processing: ${individualTime.toFixed(2)}ms`) + console.log(`Batch processing: ${batchTime.toFixed(2)}ms`) + console.log(`Speedup: ${(individualTime / batchTime).toFixed(2)}x`) + + // Batch should be faster (or at least not significantly slower) + // Allow some tolerance since timing can be variable + assert.ok( + batchTime <= individualTime * 1.2, + `Batch processing (${batchTime}ms) should not be significantly slower than individual processing (${individualTime}ms)` + ) + }) + + test('performance comparison across different batch sizes', async () => { + const batchSizes = [1, 5, 10, 20, 50] + const results = [] + + for (const size of batchSizes) { + const texts = Array.from({ length: size }, (_, i) => `performance test string ${i} with varying content length`) + + // Time individual processing + const startIndividual = process.hrtime.bigint() + await Promise.all(texts.map(text => calculateEmbeddings(text))) + const endIndividual = process.hrtime.bigint() + const individualTime = Number(endIndividual - startIndividual) / 1e6 + + // Time batch processing + const startBatch = process.hrtime.bigint() + await calculateEmbeddingsBatch(texts) + const endBatch = process.hrtime.bigint() + const batchTime = Number(endBatch - startBatch) / 1e6 + + const speedup = individualTime / batchTime + results.push({ size, individualTime, batchTime, speedup }) + } + + // Verify that larger batch sizes generally show better speedup + const largeBatch = results.find(r => r.size >= 20) + const smallBatch = results.find(r => r.size <= 5) + + if (largeBatch && smallBatch) { + assert.ok( + largeBatch.speedup >= smallBatch.speedup * 0.8, // Allow some tolerance + `Larger batches should generally be more efficient. Large batch speedup: ${largeBatch.speedup}x, Small batch speedup: ${smallBatch.speedup}x` + ) + } + + // At least one batch size should show improvement + const bestSpeedup = Math.max(...results.map(r => r.speedup)) + assert.ok(bestSpeedup > 1.0, `Should show speedup for at least one batch size. Best: ${bestSpeedup.toFixed(2)}x`) + }) +}) From b3a3222e96ae60a0170c8984d7c8336a6ed335c1 Mon Sep 17 00:00:00 2001 From: "Dr. David A. Kunz" Date: Fri, 26 Sep 2025 17:25:46 +0200 Subject: [PATCH 14/14] . --- tests/calculateEmbeddingsBatch.test.js | 66 -------------------------- 1 file changed, 66 deletions(-) diff --git a/tests/calculateEmbeddingsBatch.test.js b/tests/calculateEmbeddingsBatch.test.js index e4c6a4a..f07ab53 100644 --- a/tests/calculateEmbeddingsBatch.test.js +++ b/tests/calculateEmbeddingsBatch.test.js @@ -98,70 +98,4 @@ describe('calculateEmbeddingsBatch', () => { ) } }) - - test('should be faster than individual processing for multiple texts', async () => { - const texts = Array.from({ length: 10 }, (_, i) => `test string number ${i} with some content`) - - // Time individual processing - const startIndividual = process.hrtime.bigint() - await Promise.all(texts.map(text => calculateEmbeddings(text))) - const endIndividual = process.hrtime.bigint() - const individualTime = Number(endIndividual - startIndividual) / 1e6 - - // Time batch processing - const startBatch = process.hrtime.bigint() - await calculateEmbeddingsBatch(texts) - const endBatch = process.hrtime.bigint() - const batchTime = Number(endBatch - startBatch) / 1e6 - - console.log(`Individual processing: ${individualTime.toFixed(2)}ms`) - console.log(`Batch processing: ${batchTime.toFixed(2)}ms`) - console.log(`Speedup: ${(individualTime / batchTime).toFixed(2)}x`) - - // Batch should be faster (or at least not significantly slower) - // Allow some tolerance since timing can be variable - assert.ok( - batchTime <= individualTime * 1.2, - `Batch processing (${batchTime}ms) should not be significantly slower than individual processing (${individualTime}ms)` - ) - }) - - test('performance comparison across different batch sizes', async () => { - const batchSizes = [1, 5, 10, 20, 50] - const results = [] - - for (const size of batchSizes) { - const texts = Array.from({ length: size }, (_, i) => `performance test string ${i} with varying content length`) - - // Time individual processing - const startIndividual = process.hrtime.bigint() - await Promise.all(texts.map(text => calculateEmbeddings(text))) - const endIndividual = process.hrtime.bigint() - const individualTime = Number(endIndividual - startIndividual) / 1e6 - - // Time batch processing - const startBatch = process.hrtime.bigint() - await calculateEmbeddingsBatch(texts) - const endBatch = process.hrtime.bigint() - const batchTime = Number(endBatch - startBatch) / 1e6 - - const speedup = individualTime / batchTime - results.push({ size, individualTime, batchTime, speedup }) - } - - // Verify that larger batch sizes generally show better speedup - const largeBatch = results.find(r => r.size >= 20) - const smallBatch = results.find(r => r.size <= 5) - - if (largeBatch && smallBatch) { - assert.ok( - largeBatch.speedup >= smallBatch.speedup * 0.8, // Allow some tolerance - `Larger batches should generally be more efficient. Large batch speedup: ${largeBatch.speedup}x, Small batch speedup: ${smallBatch.speedup}x` - ) - } - - // At least one batch size should show improvement - const bestSpeedup = Math.max(...results.map(r => r.speedup)) - assert.ok(bestSpeedup > 1.0, `Should show speedup for at least one batch size. Best: ${bestSpeedup.toFixed(2)}x`) - }) })