From 294e866a11c04846df0ab42938adbabfd4fe4057 Mon Sep 17 00:00:00 2001
From: Neko1t <2023141520040@stu.scu.edu.cn>
Date: Wed, 22 Apr 2026 20:52:54 +0800
Subject: [PATCH] feat: enable MACER multi-step subgraph refinement and make
 num_steps configurable

- Add RefinableMultiStepQueryEngine for MACER subgraph refinement in ToG3 inference
- Add num_steps config field to RAGConfig (default: 3)
- Add num_steps to config.yaml template
- Fix OpenAICompatible.is_chat_model default to True for local runtime compatibility
- Update README and run.sh
---
 README.md                             |  4 ++--
 examples/ToG3/config.yaml             |  3 ++-
 rag_factory/args.py                   |  1 +
 rag_factory/llms/openai_compatible.py |  2 +-
 run.sh                                |  3 +++
 tog3_main.py                          | 13 +++++++++++--
 6 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index dbd783a..f33a59c 100644
--- a/README.md
+++ b/README.md
@@ -47,12 +47,12 @@ pip install -e .
 ## Usage
 
 ```bash
-bash run.sh naive_rag/graph_rag/mm_rag
+bash run.sh ToG3
 ```
 
 or
 
 ```bash
-python main.py --config examples/graphrag/config.yaml
+python main.py --config examples/ToG3/config.yaml
 ```
 
diff --git a/examples/ToG3/config.yaml b/examples/ToG3/config.yaml
index 025726e..56eb5b6 100644
--- a/examples/ToG3/config.yaml
+++ b/examples/ToG3/config.yaml
@@ -39,7 +39,8 @@ rag:
   insert_community_nodes: True
   num_workers: 4 # 并行处理chunk的worker数
   similarity_top_k: 10 # 检索到的top_k个节点
-  stages: ["inference","evaluation"]
+  stages: ["create","inference","evaluation"]
   # graph_rag参数
   max_paths_per_chunk: 2 # 每个chunk的最大path数, 也就是每个chunk抽取的max_knowledge_triplets
   max_cluster_size: 5 # 对graph进行聚类以获得commuities
+  num_steps: 3 # MACER 多步迭代的最大步数
diff --git a/rag_factory/args.py b/rag_factory/args.py
index 26d9382..5c1966e 100644
--- a/rag_factory/args.py
+++ b/rag_factory/args.py
@@ -52,6 +52,7 @@ class RAGConfig:
     stages: List[str] = field(default_factory=lambda: ["create", "inference", "evaluation"])
     max_paths_per_chunk: int = 10
     max_cluster_size: int = 50
+    num_steps: int = 3
 
 @dataclass
 class Query:
diff --git a/rag_factory/llms/openai_compatible.py b/rag_factory/llms/openai_compatible.py
index 62be3ec..4a6ead1 100644
--- a/rag_factory/llms/openai_compatible.py
+++ b/rag_factory/llms/openai_compatible.py
@@ -104,7 +104,7 @@ class OpenAICompatible(OpenAI):
         description=LLMMetadata.model_fields["context_window"].description,
     )
     is_chat_model: bool = Field(
-        default=False,
+        default=True,
         description=LLMMetadata.model_fields["is_chat_model"].description,
     )
     is_function_calling_model: bool = Field(
diff --git a/run.sh b/run.sh
index 517a4d8..38d8a8a 100644
--- a/run.sh
+++ b/run.sh
@@ -18,6 +18,9 @@ elif [ "$solution" == "graph_rag" ]; then
 elif [ "$solution" == "mm_rag" ]; then
     echo "Starting MultiModalRAG example..."
     python main.py --config examples/multimodal_rag/config.yaml
+elif [ "$solution" == "ToG3" ]; then
+    echo "Starting ToG3 example..."
+    python main.py --config examples/ToG3/config.yaml
 else
     echo "Unknown solution: $solution"
     exit 1
diff --git a/tog3_main.py b/tog3_main.py
index a5a2adc..f67e1a4 100644
--- a/tog3_main.py
+++ b/tog3_main.py
@@ -34,6 +34,8 @@
 from rag_factory.graph_constructor import GraphRAGConstructor
 
 from rag_factory.retrivers.graphrag_query_engine import GraphRAGQueryEngine
+from rag_factory.retrivers.refinable_multistep_query_engine import RefinableMultiStepQueryEngine
+from llama_index.core.indices.query.query_transform.base import StepDecomposeQueryTransform
 
 
 def read_args(config_path: Union[str, Path]) -> Tuple[DatasetConfig, LLMConfig, EmbeddingConfig, StorageConfig, RAGConfig]:
@@ -332,8 +334,15 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") ->
     elif rag_config.solution == "multi_modal_rag":
         # TODO: Implement Multi-modal RAG solution
         raise NotImplementedError("Multi-modal RAG solution is not implemented yet.")
-    elif rag_config.solution == "tog3":
-        query_engine = index.as_query_engine()
+    elif rag_config.solution == "ToG3":
+        # 使用 RefinableMultiStepQueryEngine 实现 MACER 子图精化
+        query_transform = StepDecomposeQueryTransform(llm=llm, verbose=True)
+        query_engine = RefinableMultiStepQueryEngine(
+            query_engine=index.as_query_engine(),
+            query_transform=query_transform,
+            index=index,
+            num_steps=rag_config.num_steps,
+        )
     else:
         raise ValueError(f"Unsupported RAG solution: {rag_config.solution}")
 