diff --git a/build.gradle.kts b/build.gradle.kts index b40ce41f..efc424ec 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,3 +1,6 @@ +import java.io.File +import java.net.URL + plugins { java // Must be compatible with Kotlin compiler bundled with IntelliJ IDEA 2025.3 (metadata 2.2.0) @@ -8,6 +11,8 @@ plugins { val ideaVersion = "2025.3" val javaVersion = 21 +val onnxRuntimeVersion = "1.20.0" +val djlTokenizersVersion = "0.33.0" group = "com.lsfusion" version = file("META-INF/plugin.xml").let { @@ -27,6 +32,15 @@ repositories { } } +configurations.all { + resolutionStrategy.eachDependency { + if (requested.group == "com.microsoft.onnxruntime") { + useVersion(onnxRuntimeVersion) + because("Keep all ONNX Runtime jars on one version") + } + } +} + java { // Compile with a stable JDK toolchain. // `runIde` in your environment was using javac 22.0.1 and crashed with StackOverflowError inside javac parser. @@ -111,6 +125,10 @@ dependencies { implementation("org.json:json:20180813") implementation("org.jsoup:jsoup:1.15.3") implementation("net.gcardone.junidecode:junidecode:0.5.2") + implementation("com.microsoft.onnxruntime:onnxruntime:$onnxRuntimeVersion") + implementation("ai.djl.huggingface:tokenizers:$djlTokenizersVersion") + implementation("org.apache.lucene:lucene-core:10.3.2") + implementation("org.apache.lucene:lucene-analysis-common:10.3.2") // Needed by MCP toolset DTOs; compiler plugin generates serializers. compileOnly("org.jetbrains.kotlinx:kotlinx-serialization-json:1.7.3") @@ -132,6 +150,78 @@ dependencies { } } +val modelDir = layout.projectDirectory.dir(".mcp-model") +val e5ModelUrl = "https://huggingface.co/intfloat/e5-small/resolve/main/model.onnx?download=true" +val e5TokenizerUrl = "https://huggingface.co/intfloat/e5-small/resolve/main/tokenizer.json?download=true" +val onnxNativeDir = layout.buildDirectory.dir("onnxruntime-native") +val onnxTempDir = layout.buildDirectory.dir("onnxruntime-tmp") +val onnxRuntimeDll = onnxNativeDir.map { it.file("onnxruntime.dll") } +val onnxRuntimeJniDll = onnxNativeDir.map { it.file("onnxruntime4j_jni.dll") } + +tasks.register("downloadE5Model") { + group = "mcp" + description = "Download e5-small ONNX model and tokenizer if missing" + notCompatibleWithConfigurationCache("Downloads model files to project directory") + doLast { + val dirFile = modelDir.asFile + if (!dirFile.exists()) { + dirFile.mkdirs() + } + val modelFile = modelDir.file("model.onnx").asFile + val tokenizerFile = modelDir.file("tokenizer.json").asFile + + fun downloadIfMissing(url: String, target: File) { + if (target.exists() && target.length() > 0) return + URL(url).openStream().use { input -> + target.outputStream().use { output -> input.copyTo(output) } + } + } + + downloadIfMissing(e5ModelUrl, modelFile) + downloadIfMissing(e5TokenizerUrl, tokenizerFile) + } +} + +val extractOnnxRuntimeNative by tasks.registering(Sync::class) { + group = "mcp" + description = "Extract ONNX Runtime native libs into build dir" + val runtimeClasspath = configurations.runtimeClasspath + from({ + runtimeClasspath.get() + .filter { it.name.startsWith("onnxruntime-") && it.extension == "jar" } + .map { zipTree(it) } + }) { + include("**/onnxruntime*.dll") + include("**/onnxruntime_providers_*.dll") + include("**/onnxruntime*.so") + include("**/onnxruntime*.dylib") + include("**/libonnxruntime*.dylib") + include("**/onnxruntime4j_jni*") + include("**/libonnxruntime4j_jni*") + } + into(onnxNativeDir) + includeEmptyDirs = false + duplicatesStrategy = org.gradle.api.file.DuplicatesStrategy.EXCLUDE + eachFile { path = name } +} + +val createOnnxTempDir by tasks.registering { + group = "mcp" + description = "Create ONNX Runtime temp dir" + doLast { + onnxTempDir.get().asFile.mkdirs() + } +} + +tasks.withType().configureEach { + dependsOn("downloadE5Model", extractOnnxRuntimeNative, createOnnxTempDir) + jvmArgs("-Dlsfusion.mcp.embedding.modelDir=${modelDir.asFile.absolutePath}") + jvmArgs("-Donnxruntime.native.path=${onnxNativeDir.get().asFile.absolutePath}") + jvmArgs("-Donnxruntime.native.onnxruntime.path=${onnxRuntimeDll.get().asFile.absolutePath}") + jvmArgs("-Donnxruntime.native.onnxruntime4j_jni.path=${onnxRuntimeJniDll.get().asFile.absolutePath}") + jvmArgs("-Djava.io.tmpdir=${onnxTempDir.get().asFile.absolutePath}") +} + intellijPlatform { pluginVerification { ides { diff --git a/src/com/lsfusion/LSFBaseStartupActivity.java b/src/com/lsfusion/LSFBaseStartupActivity.java index c5837e1f..14048ad3 100644 --- a/src/com/lsfusion/LSFBaseStartupActivity.java +++ b/src/com/lsfusion/LSFBaseStartupActivity.java @@ -7,6 +7,8 @@ import com.intellij.openapi.vfs.VirtualFileManager; import com.intellij.openapi.vfs.impl.BulkVirtualFileListenerAdapter; import com.lsfusion.actions.locale.LSFPropertiesFileListener; +import com.lsfusion.mcp.LSFMcpRagFileListener; +import com.lsfusion.mcp.LocalMcpRagService; import kotlin.Unit; import kotlin.coroutines.Continuation; import org.jetbrains.annotations.NotNull; @@ -17,6 +19,12 @@ public class LSFBaseStartupActivity implements ProjectActivity, DumbAware { public @Nullable Object execute(@NotNull Project project, @NotNull Continuation continuation) { project.getMessageBus().connect().subscribe(FileEditorManagerListener.FILE_EDITOR_MANAGER, new LSFFileEditorManagerListener()); project.getMessageBus().connect().subscribe(VirtualFileManager.VFS_CHANGES, new BulkVirtualFileListenerAdapter(new LSFPropertiesFileListener(project))); + project.getMessageBus().connect().subscribe(VirtualFileManager.VFS_CHANGES, new BulkVirtualFileListenerAdapter(new LSFMcpRagFileListener(project))); + try { + LocalMcpRagService.getInstance(project).indexProjectAsync(); + } catch (Exception ignored) { + // indexing service not available + } return Unit.INSTANCE; } } diff --git a/src/com/lsfusion/mcp/EmbeddingProvider.java b/src/com/lsfusion/mcp/EmbeddingProvider.java new file mode 100644 index 00000000..95193f26 --- /dev/null +++ b/src/com/lsfusion/mcp/EmbeddingProvider.java @@ -0,0 +1,11 @@ +package com.lsfusion.mcp; + +public interface EmbeddingProvider extends AutoCloseable { + float[] embed(String text) throws Exception; + int dimension(); + + @Override + default void close() throws Exception { + // no-op by default + } +} diff --git a/src/com/lsfusion/mcp/LSFMcpRagFileListener.java b/src/com/lsfusion/mcp/LSFMcpRagFileListener.java new file mode 100644 index 00000000..6e6944d5 --- /dev/null +++ b/src/com/lsfusion/mcp/LSFMcpRagFileListener.java @@ -0,0 +1,64 @@ +package com.lsfusion.mcp; + +import com.intellij.openapi.project.Project; +import com.intellij.openapi.vfs.VirtualFile; +import com.intellij.openapi.vfs.VirtualFileCopyEvent; +import com.intellij.openapi.vfs.VirtualFileEvent; +import com.intellij.openapi.vfs.VirtualFileListener; +import com.intellij.openapi.vfs.VirtualFileMoveEvent; +import com.intellij.openapi.vfs.VirtualFilePropertyEvent; +import org.jetbrains.annotations.NotNull; + +public class LSFMcpRagFileListener implements VirtualFileListener { + private final Project project; + + public LSFMcpRagFileListener(Project project) { + this.project = project; + } + + @Override + public void contentsChanged(@NotNull VirtualFileEvent event) { + update(event.getFile()); + } + + @Override + public void fileCreated(@NotNull VirtualFileEvent event) { + update(event.getFile()); + } + + @Override + public void fileDeleted(@NotNull VirtualFileEvent event) { + delete(event.getFile()); + } + + @Override + public void fileMoved(@NotNull VirtualFileMoveEvent event) { + update(event.getFile()); + } + + @Override + public void fileCopied(@NotNull VirtualFileCopyEvent event) { + update(event.getFile()); + } + + @Override + public void propertyChanged(@NotNull VirtualFilePropertyEvent event) { + update(event.getFile()); + } + + private void update(VirtualFile file) { + try { + LocalMcpRagService.getInstance(project).updateFile(file); + } catch (Exception ignored) { + // service not available + } + } + + private void delete(VirtualFile file) { + try { + LocalMcpRagService.getInstance(project).deleteFile(file); + } catch (Exception ignored) { + // service not available + } + } +} diff --git a/src/com/lsfusion/mcp/LocalMcpRagService.java b/src/com/lsfusion/mcp/LocalMcpRagService.java new file mode 100644 index 00000000..53b5ea54 --- /dev/null +++ b/src/com/lsfusion/mcp/LocalMcpRagService.java @@ -0,0 +1,340 @@ +package com.lsfusion.mcp; + +import com.intellij.openapi.application.ReadAction; +import com.intellij.openapi.diagnostic.Logger; +import com.intellij.openapi.project.Project; +import com.intellij.openapi.util.Key; +import com.intellij.openapi.vfs.VirtualFile; +import com.intellij.psi.PsiFile; +import com.intellij.psi.PsiManager; +import com.intellij.psi.search.ProjectScope; +import com.lsfusion.lang.psi.LSFFile; +import com.lsfusion.util.LSFFileUtils; +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; +import org.apache.lucene.document.TextField; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.BytesRef; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; + +public final class LocalMcpRagService { + private static final Logger LOG = Logger.getInstance(LocalMcpRagService.class); + private static final Key KEY = Key.create("lsfusion.mcp.localRagService"); + private static final String FIELD_ID = "elementId"; + private static final String FIELD_FILE = "filePath"; + private static final String FIELD_MODULE = "module"; + private static final String FIELD_TYPE = "type"; + private static final String FIELD_CODE = "code"; + private static final String FIELD_VECTOR = "vector"; + + private final Project project; + private final Directory directory; + private final IndexWriter writer; + private final EmbeddingProvider embeddingProvider; + private final AtomicBoolean indexing = new AtomicBoolean(false); + private static final String MODEL_DIR_PROPERTY = "lsfusion.mcp.embedding.modelDir"; + private static final String MODEL_DIR_ENV = "LSFUSION_MCP_EMBEDDING_MODEL_DIR"; + private static final String DEFAULT_MODEL_DIR_NAME = ".mcp-model"; + private static final AtomicBoolean ONNX_SELF_TEST_DONE = new AtomicBoolean(false); + + private LocalMcpRagService(Project project) throws Exception { + this.project = project; + Path base = Path.of(project.getBasePath()); + Path indexPath = base.resolve(".mcp-index"); + Files.createDirectories(indexPath); + this.directory = FSDirectory.open(indexPath); + this.writer = new IndexWriter(directory, new IndexWriterConfig(new StandardAnalyzer())); + this.embeddingProvider = createEmbeddingProvider(); + } + + public static @NotNull LocalMcpRagService getInstance(@NotNull Project project) throws Exception { + LocalMcpRagService service = project.getUserData(KEY); + if (service != null) return service; + service = new LocalMcpRagService(project); + project.putUserData(KEY, service); + return service; + } + + public void indexProjectAsync() { + if (indexing.getAndSet(true)) return; + if (com.intellij.openapi.project.DumbService.isDumb(project)) { + com.intellij.openapi.project.DumbService.getInstance(project) + .runWhenSmart(this::indexProjectAsync); + indexing.set(false); + return; + } + com.intellij.openapi.application.ApplicationManager.getApplication().executeOnPooledThread(() -> { + long start = System.currentTimeMillis(); + try { + if (embeddingProvider == null) { + LOG.warn("MCP RAG index skipped: embedding provider not available"); + return; + } + LOG.info("MCP RAG index started"); + indexProject(); + LOG.info("MCP RAG index built in " + (System.currentTimeMillis() - start) + "ms"); + } catch (Throwable t) { + LOG.warn("MCP RAG index build failed", t); + } finally { + indexing.set(false); + } + }); + } + + public void indexProject() throws Exception { + if (embeddingProvider == null) return; + if (com.intellij.openapi.project.DumbService.isDumb(project)) { + LOG.info("MCP RAG index skipped: IDE is in dumb mode"); + return; + } + ReadAction.run(() -> { + int fileCount = 0; + for (LSFFile lsfFile : LSFFileUtils.getLsfFiles(ProjectScope.getAllScope(project))) { + indexFile(lsfFile); + fileCount++; + } + LOG.info("MCP RAG index pass completed: " + fileCount + " files"); + }); + writer.commit(); + } + + public void updateFile(@NotNull VirtualFile file) { + if (!file.getName().endsWith(".lsf")) return; + if (embeddingProvider == null) return; + if (com.intellij.openapi.project.DumbService.isDumb(project)) return; + try { + long start = System.currentTimeMillis(); + LOG.info("MCP RAG reindex start: " + file.getPath()); + ReadAction.run(() -> { + PsiFile psi = PsiManager.getInstance(project).findFile(file); + if (psi instanceof LSFFile lsfFile) { + indexFile(lsfFile); + } + }); + writer.commit(); + LOG.info("MCP RAG reindex done: " + file.getPath() + " in " + (System.currentTimeMillis() - start) + "ms"); + } catch (Throwable t) { + LOG.warn("MCP RAG updateFile failed: " + file.getPath(), t); + } + } + + public void deleteFile(@NotNull VirtualFile file) { + if (!file.getName().endsWith(".lsf")) return; + try { + writer.deleteDocuments(new Term(FIELD_FILE, file.getPath())); + writer.commit(); + } catch (Throwable t) { + LOG.warn("MCP RAG deleteFile failed: " + file.getPath(), t); + } + } + + public @NotNull List search(@NotNull String queryText, int topK) { + if (queryText.isBlank()) return List.of(); + if (embeddingProvider == null) return List.of(); + try { + long start = System.currentTimeMillis(); + long vectorStart = System.currentTimeMillis(); + LOG.info("MCP RAG query vectorization start: length=" + queryText.length()); + float[] vector = embeddingProvider.embed(queryText); + LOG.info("MCP RAG query vectorization done: dim=" + vector.length + + ", " + (System.currentTimeMillis() - vectorStart) + "ms"); + if (vector.length == 0) return List.of(); + try (IndexReader reader = DirectoryReader.open(writer)) { + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs hits = searcher.search(new MatchAllDocsQuery(), Math.max(1, reader.numDocs())); + List out = new ArrayList<>(); + List scored = new ArrayList<>(hits.scoreDocs.length); + for (ScoreDoc hit : hits.scoreDocs) { + Document doc = searcher.storedFields().document(hit.doc); + String location = doc.get(FIELD_ID); + BytesRef br = doc.getBinaryValue(FIELD_VECTOR); + byte[] raw = br != null ? Arrays.copyOfRange(br.bytes, br.offset, br.offset + br.length) : null; + if (location == null || raw == null) continue; + float[] vec = decodeVector(raw); + float score = dot(vector, vec); + scored.add(new ScoredId(location, score)); + } + scored.sort((a, b) -> Float.compare(b.score, a.score)); + for (int i = 0; i < Math.min(topK, scored.size()); i++) { + out.add(scored.get(i).id); + } + LOG.info("MCP RAG search: " + out.size() + " hits in " + (System.currentTimeMillis() - start) + "ms"); + return out; + } + } catch (Throwable t) { + LOG.warn("MCP RAG search failed", t); + return List.of(); + } + } + + private void indexFile(@NotNull LSFFile file) { + VirtualFile vf = file.getVirtualFile(); + if (vf == null) return; + if (embeddingProvider == null) return; + try { + long start = System.currentTimeMillis(); + LOG.info("MCP RAG indexFile start: " + vf.getPath()); + writer.deleteDocuments(new Term(FIELD_FILE, vf.getPath())); + Collection decls = LSFMCPDeclaration.getMCPDeclarations(file); + for (LSFMCPDeclaration decl : decls) { + indexStatement(file, decl); + } + LOG.info("MCP RAG indexFile done: " + vf.getPath() + " (" + decls.size() + " items) in " + + (System.currentTimeMillis() - start) + "ms"); + } catch (Throwable t) { + LOG.warn("MCP RAG indexFile failed: " + vf.getPath(), t); + } + } + + private void indexStatement(@NotNull LSFFile file, @NotNull LSFMCPDeclaration decl) throws Exception { + String location = MCPSearchUtils.getLocationByStatement(decl); + if (location == null) return; + float[] vector = embeddingProvider.embed(buildText(decl)); + if (vector.length == 0) return; + Document doc = new Document(); + doc.add(new StringField(FIELD_ID, location, Field.Store.YES)); + doc.add(new StringField(FIELD_FILE, file.getVirtualFile().getPath(), Field.Store.YES)); + doc.add(new StringField(FIELD_MODULE, file.getModuleDeclaration() != null ? file.getModuleDeclaration().getDeclName() : "", Field.Store.YES)); + String typeName; + try { + typeName = decl.getMCPType().apiName; + } catch (Throwable t) { + typeName = "unknown"; + } + doc.add(new StringField(FIELD_TYPE, typeName, Field.Store.YES)); + doc.add(new TextField(FIELD_CODE, decl.getText(), Field.Store.NO)); + doc.add(new StoredField(FIELD_VECTOR, encodeVector(vector))); + writer.addDocument(doc); + } + + private static String buildText(LSFMCPDeclaration decl) { + StringBuilder sb = new StringBuilder(); + sb.append(decl.getMCPType().apiName).append(' '); + for (var d : LSFMCPDeclaration.getNameDeclarations(decl)) { + if (d.getDeclName() != null) sb.append(d.getDeclName()).append(' '); + if (d.getCanonicalName() != null) sb.append(d.getCanonicalName()).append(' '); + } + sb.append(decl.getText()); + return sb.toString(); + } + + private static byte[] encodeVector(float[] v) { + ByteBuffer buf = ByteBuffer.allocate(v.length * 4).order(ByteOrder.LITTLE_ENDIAN); + for (float x : v) buf.putFloat(x); + return buf.array(); + } + + private static float[] decodeVector(byte[] raw) { + ByteBuffer buf = ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN); + float[] out = new float[raw.length / 4]; + for (int i = 0; i < out.length; i++) out[i] = buf.getFloat(); + return out; + } + + private static float dot(float[] a, float[] b) { + int n = Math.min(a.length, b.length); + float sum = 0f; + for (int i = 0; i < n; i++) sum += a[i] * b[i]; + return sum; + } + + private static final class ScoredId { + final String id; + final float score; + + private ScoredId(String id, float score) { + this.id = id; + this.score = score; + } + } + + private @Nullable EmbeddingProvider createEmbeddingProvider() { + LOG.info("MCP ONNX env: os.name=" + System.getProperty("os.name") + + ", os.arch=" + System.getProperty("os.arch") + + ", java.version=" + System.getProperty("java.version") + + ", java.io.tmpdir=" + System.getProperty("java.io.tmpdir") + + ", onnxruntime.native.path=" + System.getProperty("onnxruntime.native.path") + + ", onnxruntime.native.onnxruntime.path=" + System.getProperty("onnxruntime.native.onnxruntime.path") + + ", onnxruntime.native.onnxruntime4j_jni.path=" + System.getProperty("onnxruntime.native.onnxruntime4j_jni.path")); + String modelDir = System.getProperty(MODEL_DIR_PROPERTY); + if (modelDir == null || modelDir.isBlank()) { + modelDir = System.getenv(MODEL_DIR_ENV); + } + if (modelDir == null || modelDir.isBlank()) { + String basePath = project.getBasePath(); + if (basePath != null) { + Path fallback = Path.of(basePath).resolve(DEFAULT_MODEL_DIR_NAME); + if (Files.isDirectory(fallback)) { + modelDir = fallback.toString(); + } + } + } + if (modelDir == null || modelDir.isBlank()) { + String cwd = System.getProperty("user.dir"); + if (cwd != null && !cwd.isBlank()) { + Path fallback = Path.of(cwd).resolve(DEFAULT_MODEL_DIR_NAME); + if (Files.isDirectory(fallback)) { + modelDir = fallback.toString(); + } + } + } + + if (modelDir == null || modelDir.isBlank()) { + LOG.warn("MCP embedding model dir not set: set -D" + MODEL_DIR_PROPERTY + + " or env " + MODEL_DIR_ENV + " or place model under /" + DEFAULT_MODEL_DIR_NAME + + " or /" + DEFAULT_MODEL_DIR_NAME); + return null; + } + try { + LOG.info("MCP embedding model dir: " + modelDir); + if (ONNX_SELF_TEST_DONE.compareAndSet(false, true)) { + try { + ai.onnxruntime.OrtEnvironment.getEnvironment(); + LOG.info("MCP ONNX self-test: OK"); + } catch (Throwable t) { + LOG.warn("MCP ONNX self-test failed: " + t.getMessage(), t); + } + } + EmbeddingProvider provider = new OnnxEmbeddingProvider(Path.of(modelDir)); + LOG.info("MCP embedding provider init: OK"); + return provider; + } catch (Throwable t) { + String message = t.getMessage(); + if (message != null && !message.isBlank()) { + LOG.warn("MCP embedding provider init failed: " + message); + } + if (t instanceof UnsatisfiedLinkError) { + LOG.warn("MCP embedding provider init failed: native ONNX runtime could not load. " + + "On Windows this is usually missing MSVC runtime. " + + "Install Microsoft Visual C++ 2015-2022 Redistributable (x64) and restart IDE."); + } + LOG.warn("MCP embedding provider init failed", t); + return null; + } + } +} diff --git a/src/com/lsfusion/mcp/MCPSearchUtils.java b/src/com/lsfusion/mcp/MCPSearchUtils.java index b9e9cae2..61b6ab49 100644 --- a/src/com/lsfusion/mcp/MCPSearchUtils.java +++ b/src/com/lsfusion/mcp/MCPSearchUtils.java @@ -14,8 +14,6 @@ import com.intellij.psi.PsiReference; import com.intellij.psi.search.GlobalSearchScope; import com.intellij.psi.search.ProjectScope; -import com.intellij.psi.search.PsiSearchHelper; -import com.intellij.psi.search.UsageSearchContext; import com.intellij.psi.util.PsiTreeUtil; import com.intellij.util.MergeQuery; import com.intellij.util.Processor; @@ -85,6 +83,7 @@ private static int getPriority(LSFMCPDeclaration.ElementType type) { public static final int DEFAULT_MAX_SYMBOLS = 200_000; // fail-safe public static final int DEFAULT_MIN_SYMBOLS = 1_000; // fail-safe public static final int DEFAULT_TIMEOUT_SECS = 10; // default 10 seconds + public static final int DEFAULT_RAG_TOP_K = 200; // TODO: tune private static final class SelectedStatement { private final @NotNull LSFMCPDeclaration decl; @@ -112,6 +111,8 @@ private SelectedStatement(@NotNull LSFMCPDeclaration decl, GlobalSearchScope sea * requiredModules: true, // optional (default: true); if true, include REQUIRE-d modules for modules in `modules` * name: "cust,Order", // optional, CSV; filters by element name * contains: "(?i)cust.*", // optional, CSV; filters by element code + * query: "cust,Order", // optional, CSV; semantic query for local RAG (vector search) + * useVectorSearch: false, // optional; if true, use local vector search for `query` * elementTypes: "class,property,action", // optional, CSV * classes: "MyNS.MyClass, MyOtherNS.OtherClass", // optional, CSV canonical names * relatedElements: "property:MyNS.myProp[MyNS.MyClass], MyModule(10:5)", // optional, CSV @@ -361,6 +362,26 @@ private static boolean submitRelatedTasks(Map rela return !related.isEmpty(); } + private static boolean submitRagTask(Project project, GlobalSearchScope searchScope, String queryText, Processor processor, TaskSubmitter submit) { + submit.submit(1, () -> { + try { + LocalMcpRagService rag = LocalMcpRagService.getInstance(project); + List locations = rag.search(queryText, DEFAULT_RAG_TOP_K); + if (locations.isEmpty()) return; + ReadAction.run(() -> { + for (String location : locations) { + LSFMCPDeclaration st = getStatementByLocation(project, searchScope, location); + if (st == null) continue; + if (!processor.process(st)) return; + } + }); + } catch (Exception e) { + LOG.warn("Error executing MCP RAG search", e); + } + }); + return true; + } + private static void submitFileTasks(GlobalSearchScope searchScope, Processor processor, TaskSubmitter submit) { for (LSFFile lsfFile : ReadAction.compute(() -> LSFFileUtils.getLsfFiles(searchScope))) { submit.submit(5, () -> ReadAction.run(() -> { @@ -371,19 +392,6 @@ private static void submitFileTasks(GlobalSearchScope searchScope, Processor filters, Processor processor, TaskSubmitter submit) { - if (filters.isEmpty()) return false; - boolean fullyStreamable = true; - for (NameFilter nf : filters) { - if (nf.isWordStreamable()) { - submit.submit(5, () -> ReadAction.run(() -> streamWord(project, nf, processor, searchScope))); - } else { - fullyStreamable = false; - } - } - return fullyStreamable; - } - private static @NotNull GlobalSearchScope run(@NotNull Project project, @NotNull JSONObject query, @NotNull Set seen, @@ -393,6 +401,7 @@ private static boolean submitWordTasks(Project project, GlobalSearchScope search GlobalSearchScope searchScope = ReadAction.compute(() -> buildSearchScope(project, query.optString("modules"), query.optString("scope"), query.optBoolean("requiredModules", true))); List nameFilters = parseMatchersCsv(query.optString("name")); List containsFilters = parseMatchersCsv(query.optString("contains")); + String rawQuery = query.optString("query", "").trim(); Set elementTypes = parseElementTypes(query.optString("elementTypes")); Set classDecls = ReadAction.compute(() -> parseClasses(project, searchScope, query.optString("classes"))); @@ -401,11 +410,14 @@ private static boolean submitWordTasks(Project project, GlobalSearchScope search // Shared processor that applies all filters and returns false to stop the current iteration final Processor processor = createSearchProcessor(state, seen, nameFilters, containsFilters, elementTypes, classDecls, related, searchScope, relatedCache); - // Track which blocks are fully streamable while assembling iterations. - // Name/code filters are considered fully streamable only if ALL matchers are "word-only" with length >= 3. - // Name/code-based iterations (only for word length >= 3) - boolean nameFullyStreamable = submitWordTasks(project, searchScope, nameFilters, processor, submit); - boolean containsFullyStreamable = submitWordTasks(project, searchScope, containsFilters, processor, submit); + boolean ragSubmitted = false; + if (query.optBoolean("useVectorSearch", false) && !rawQuery.isEmpty()) { + ragSubmitted = submitRagTask(project, searchScope, rawQuery, processor, submit); + } + + if (ragSubmitted) { + return searchScope; + } // Element-type iterations (only for index-backed types and only if elementTypes filter provided) boolean typesFullyStreamable = submitTypeTasks(project, searchScope, elementTypes, processor, submit); @@ -417,7 +429,7 @@ private static boolean submitWordTasks(Project project, GlobalSearchScope search boolean relatedFullyStreamable = submitRelatedTasks(related, searchScope, processor, submit); // Per-file iterations (fallback) — run only if no block is fully streamable - if (!nameFullyStreamable && !containsFullyStreamable && !classesFullyStreamable && !relatedFullyStreamable && !typesFullyStreamable) { + if (!classesFullyStreamable && !relatedFullyStreamable && !typesFullyStreamable) { submitFileTasks(searchScope, processor, submit); } @@ -432,15 +444,6 @@ private static boolean isOnlyPropertiesClassesActions(Set processor, GlobalSearchScope searchScope) { - PsiSearchHelper helper = PsiSearchHelper.getInstance(project); - helper.processElementsWithWord((element, offsetInElement) -> processStatement(element, processor), - searchScope, - nf.word, - (short)(UsageSearchContext.IN_CODE | UsageSearchContext.IN_FOREIGN_LANGUAGES | UsageSearchContext.IN_COMMENTS), - true); - } - private static boolean isTimedOut(long deadlineMillis) { return System.currentTimeMillis() > deadlineMillis; } @@ -774,7 +777,7 @@ private static Collection getStatementsFromJson(Project proje // Expected format: (:) e.g. MyModule(10:5) // line and symbolInLine are 1-based - private static LSFMCPDeclaration getStatementByLocation(Project project, GlobalSearchScope scope, String location) { + public static LSFMCPDeclaration getStatementByLocation(Project project, GlobalSearchScope scope, String location) { if (location == null) return null; String s = location.trim(); if (s.isEmpty()) return null; @@ -834,7 +837,7 @@ private static LSFMCPDeclaration getStatementByLocation(Project project, GlobalS return null; } - private static String getLocationByStatement(LSFMCPStatement stmt) { + public static String getLocationByStatement(LSFMCPStatement stmt) { if (stmt == null) return null; Pair location = LSFMCPStatement.getLocation(stmt); @@ -1262,10 +1265,6 @@ private static boolean matchesNameFilters(LSFMCPDeclaration stmt, List= 3 && regex == null; - } NameFilter(String word, Pattern regex) { this.word = word; diff --git a/src/com/lsfusion/mcp/McpServerService.java b/src/com/lsfusion/mcp/McpServerService.java index fb19b6bc..7fe1a7c9 100644 --- a/src/com/lsfusion/mcp/McpServerService.java +++ b/src/com/lsfusion/mcp/McpServerService.java @@ -254,6 +254,15 @@ private JSONObject buildFindElementsToolDescriptor() { .put("type", "string") .put("description", "Element name filter as CSV (comma-separated). Word if valid ID, else Java regex.")) + .put("query", new JSONObject() + .put("type", "string") + .put("description", + "Semantic query for local RAG (vector search). CSV allowed.")) + .put("useVectorSearch", new JSONObject() + .put("type", "boolean") + .put("default", false) + .put("description", + "If true, use local vector search for `query`. If false, use standard filters (names/contains).")) .put("contains", new JSONObject() .put("type", "string") .put("description", @@ -285,7 +294,7 @@ private JSONObject buildFindElementsToolDescriptor() { .put("moreFilters", new JSONObject() .put("type", "string") .put("description", - "Additional filter objects of the same structure as the root. JSON array string (e.g. `[{\"names\":\"Foo\", \"modules\" : \"MyModule\"},{\"names\":\"Bar\"}]`). Results are merged (OR).")) + "Additional filter objects of the same structure as the root. JSON array string (e.g. `[{\"name\":\"Foo\", \"contains\":\"bar\", \"modules\" : \"MyModule\"},{\"query\":\"Foo\", \"useVectorSearch\":true}]`). Results are merged (OR).")) .put("minSymbols", new JSONObject() .put("type", "integer") .put("minimum", 0) diff --git a/src/com/lsfusion/mcp/McpToolset.kt b/src/com/lsfusion/mcp/McpToolset.kt index 2797345b..098783a6 100644 --- a/src/com/lsfusion/mcp/McpToolset.kt +++ b/src/com/lsfusion/mcp/McpToolset.kt @@ -200,6 +200,14 @@ class McpToolset : com.intellij.mcpserver.McpToolset { description = "Element code filter as CSV. Word if valid ID, else Java regex." ) contains: String? = null, + @McpDescription( + description = "Semantic query for local RAG (vector search). CSV allowed." + ) + query: String? = null, + @McpDescription( + description = "If true, use local vector search for `query`. If false, use standard filters (names/contains). Default: false." + ) + useVectorSearch: Boolean = false, @McpDescription( description = "Element type filter as CSV. Allowed values: `module`, `metacode`, `class`, `property`, `action`, `form`, `navigatorElement`, `window`, `group`, `table`, `event`, `calculatedEvent`, `constraint`, `index`." ) @@ -215,7 +223,7 @@ class McpToolset : com.intellij.mcpserver.McpToolset { @McpDescription(description = "Direction for ALL `relatedElements` seeds. Allowed values: `both`, `uses`, `used`. Default: `both`.") relatedDirection: String? = null, @McpDescription( - description = "Additional filter objects of the same structure as the root. JSON array string (e.g. `[{\"names\":\"Foo\", \"modules\" : \"MyModule\"},{\"names\":\"Bar\"}]`). Results are merged (OR)." + description = "Additional filter objects of the same structure as the root. JSON array string (e.g. `[{\"name\":\"Foo\", \"contains\":\"bar\", \"modules\" : \"MyModule\"},{\"query\":\"Foo\", \"useVectorSearch\":true}]`). Results are merged (OR)." ) moreFilters: String? = null, @McpDescription(description = "Best-effort minimum output size in JSON chars; server may append neighboring elements if too small (>= 0). Default: ${MCPSearchUtils.DEFAULT_MIN_SYMBOLS}.") @@ -235,24 +243,26 @@ class McpToolset : com.intellij.mcpserver.McpToolset { } try { - val query = JSONObject() - if (modules != null) query.put("modules", modules) - if (scope != null) query.put("scope", scope) - query.put("requiredModules", requiredModules) - if (names != null) query.put("name", names) - if (contains != null) query.put("contains", contains) - if (elementTypes != null) query.put("elementTypes", elementTypes) - if (classes != null) query.put("classes", classes) - if (relatedElements != null) query.put("relatedElements", relatedElements) - if (relatedDirection != null) query.put("relatedDirection", relatedDirection) - query.put("minSymbols", minSymbols) - query.put("maxSymbols", maxSymbols) - query.put("timeoutSeconds", timeoutSeconds) + val payload = JSONObject() + if (modules != null) payload.put("modules", modules) + if (scope != null) payload.put("scope", scope) + payload.put("requiredModules", requiredModules) + if (names != null) payload.put("name", names) + if (contains != null) payload.put("contains", contains) + if (query != null) payload.put("query", query) + if (useVectorSearch) payload.put("useVectorSearch", true) + if (elementTypes != null) payload.put("elementTypes", elementTypes) + if (classes != null) payload.put("classes", classes) + if (relatedElements != null) payload.put("relatedElements", relatedElements) + if (relatedDirection != null) payload.put("relatedDirection", relatedDirection) + payload.put("minSymbols", minSymbols) + payload.put("maxSymbols", maxSymbols) + payload.put("timeoutSeconds", timeoutSeconds) if (moreFilters != null && !moreFilters.isEmpty()) { - query.put("moreFilters", JSONArray(moreFilters)) + payload.put("moreFilters", JSONArray(moreFilters)) } - val result = MCPSearchUtils.findElements(project, query) + val result = MCPSearchUtils.findElements(project, payload) val jsonElement = json.parseToJsonElement(result.toString()) return json.decodeFromJsonElement(jsonElement) } catch (e: McpExpectedError) { diff --git a/src/com/lsfusion/mcp/OnnxEmbeddingProvider.java b/src/com/lsfusion/mcp/OnnxEmbeddingProvider.java new file mode 100644 index 00000000..a1cff99e --- /dev/null +++ b/src/com/lsfusion/mcp/OnnxEmbeddingProvider.java @@ -0,0 +1,172 @@ +package com.lsfusion.mcp; + +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import com.intellij.openapi.diagnostic.Logger; +import org.jetbrains.annotations.NotNull; + +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; + +import java.nio.LongBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +/** + * ONNX Runtime embedding provider using e5-small. + * Model directory is expected to contain: + * - model.onnx + * - tokenizer.json (HuggingFace tokenizer) + */ +public final class OnnxEmbeddingProvider implements EmbeddingProvider { + private static final Logger LOG = Logger.getInstance(OnnxEmbeddingProvider.class); + + private final OrtEnvironment env; + private final OrtSession session; + private final HuggingFaceTokenizer tokenizer; + private final int dimension; + private final String inputIdsName; + private final String attentionMaskName; + private final String tokenTypeIdsName; + private final String outputName; + + public OnnxEmbeddingProvider(@NotNull Path modelDir) throws Exception { + Path modelPath = modelDir.resolve("model.onnx"); + Path tokenizerPath = modelDir.resolve("tokenizer.json"); + if (!Files.exists(modelPath) || !Files.exists(tokenizerPath)) { + throw new IllegalStateException("Missing model.onnx or tokenizer.json in " + modelDir); + } + + this.env = OrtEnvironment.getEnvironment(); + this.session = env.createSession(modelPath.toString(), new OrtSession.SessionOptions()); + // Ensure DJL can see its native resources via context classloader in IntelliJ. + ClassLoader prevCl = Thread.currentThread().getContextClassLoader(); + ClassLoader tokenizersCl = HuggingFaceTokenizer.class.getClassLoader(); + try { + if (tokenizersCl != null) { + Thread.currentThread().setContextClassLoader(tokenizersCl); + } + this.tokenizer = HuggingFaceTokenizer.newInstance(tokenizerPath); + } finally { + Thread.currentThread().setContextClassLoader(prevCl); + } + + this.inputIdsName = pickInputName("input_ids"); + this.attentionMaskName = pickInputName("attention_mask"); + this.tokenTypeIdsName = pickInputName("token_type_ids"); + this.outputName = pickOutputName(); + this.dimension = inferDimension(); + } + + @Override + public float[] embed(String text) throws Exception { + var encoding = tokenizer.encode(text); + long[] inputIds = encoding.getIds(); + long[] attentionMask = encoding.getAttentionMask(); + long[] tokenTypeIds = encoding.getTypeIds(); + + if (tokenTypeIds == null || tokenTypeIds.length == 0) { + tokenTypeIds = new long[inputIds.length]; + } + + long[] shape = new long[]{1, inputIds.length}; + try (OnnxTensor idsTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), shape); + OnnxTensor maskTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMask), shape); + OnnxTensor typeTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(tokenTypeIds), shape)) { + + Map inputs = new HashMap<>(); + if (inputIdsName != null) inputs.put(inputIdsName, idsTensor); + if (attentionMaskName != null) inputs.put(attentionMaskName, maskTensor); + if (tokenTypeIdsName != null) inputs.put(tokenTypeIdsName, typeTensor); + + try (OrtSession.Result result = session.run(inputs)) { + OnnxValue value = result.get(outputName).orElseThrow(); + float[] vector = extractEmbedding(value, attentionMask); + normalize(vector); + return vector; + } + } + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public void close() throws Exception { + session.close(); + env.close(); + } + + private String pickInputName(String preferred) { + if (session.getInputNames().contains(preferred)) { + return preferred; + } + return session.getInputNames().stream().findFirst().orElse(null); + } + + private String pickOutputName() { + if (session.getOutputNames().contains("sentence_embedding")) { + return "sentence_embedding"; + } + return session.getOutputNames().stream().findFirst().orElseThrow(); + } + + private int inferDimension() throws OrtException { + var outputInfo = session.getOutputInfo().get(outputName).getInfo(); + if (outputInfo instanceof ai.onnxruntime.TensorInfo ti) { + long[] shape = ti.getShape(); + if (shape.length == 2 && shape[1] > 0) { + return (int) shape[1]; + } + if (shape.length == 3 && shape[2] > 0) { + return (int) shape[2]; + } + } + return 384; // fallback for e5-small + } + + private float[] extractEmbedding(OnnxValue value, long[] attentionMask) throws OrtException { + Object raw = value.getValue(); + if (raw instanceof float[][] vec2d) { + return vec2d[0]; + } + if (raw instanceof float[][][] vec3d) { + return meanPool(vec3d[0], attentionMask); + } + LOG.warn("Unexpected ONNX output type: " + raw.getClass().getName()); + return new float[dimension]; + } + + private static float[] meanPool(float[][] tokenEmbeds, long[] attentionMask) { + int dim = tokenEmbeds[0].length; + float[] out = new float[dim]; + float denom = 0f; + for (int i = 0; i < tokenEmbeds.length; i++) { + float mask = (attentionMask != null && i < attentionMask.length) ? attentionMask[i] : 1f; + if (mask <= 0) continue; + denom += mask; + float[] t = tokenEmbeds[i]; + for (int d = 0; d < dim; d++) { + out[d] += t[d] * mask; + } + } + if (denom > 0) { + for (int d = 0; d < dim; d++) out[d] /= denom; + } + return out; + } + + private static void normalize(float[] v) { + double sum = 0; + for (float x : v) sum += x * x; + double norm = Math.sqrt(sum); + if (norm == 0) return; + for (int i = 0; i < v.length; i++) v[i] /= norm; + } +}