Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ pip install -e .

## Usage
```bash
bash run.sh naive_rag/graph_rag/mm_rag
bash run.sh ToG3
```
or

```bash
python main.py --config examples/graphrag/config.yaml
python main.py --config examples/ToG3/config.yaml
```


Expand Down
3 changes: 2 additions & 1 deletion examples/ToG3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ rag:
insert_community_nodes: True
num_workers: 4 # 并行处理chunk的worker数
similarity_top_k: 10 # 检索到的top_k个节点
stages: ["inference","evaluation"]
stages: ["create","inference","evaluation"]
# graph_rag参数
max_paths_per_chunk: 2 # 每个chunk的最大path数, 也就是每个chunk抽取的max_knowledge_triplets
max_cluster_size: 5 # 对graph进行聚类以获得communities
num_steps: 3 # MACER 多步迭代的最大步数
1 change: 1 addition & 0 deletions rag_factory/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class RAGConfig:
stages: List[str] = field(default_factory=lambda: ["create", "inference", "evaluation"])
max_paths_per_chunk: int = 10
max_cluster_size: int = 50
num_steps: int = 3

@dataclass
class Query:
Expand Down
2 changes: 1 addition & 1 deletion rag_factory/llms/openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class OpenAICompatible(OpenAI):
description=LLMMetadata.model_fields["context_window"].description,
)
is_chat_model: bool = Field(
default=False,
default=True,
description=LLMMetadata.model_fields["is_chat_model"].description,
)
is_function_calling_model: bool = Field(
Expand Down
3 changes: 3 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ elif [ "$solution" == "graph_rag" ]; then
elif [ "$solution" == "mm_rag" ]; then
echo "Starting MultiModalRAG example..."
python main.py --config examples/multimodal_rag/config.yaml
elif [ "$solution" == "ToG3" ]; then
echo "Starting ToG3 example..."
python main.py --config examples/ToG3/config.yaml
else
echo "Unknown solution: $solution"
exit 1
Expand Down
13 changes: 11 additions & 2 deletions tog3_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

from rag_factory.graph_constructor import GraphRAGConstructor
from rag_factory.retrivers.graphrag_query_engine import GraphRAGQueryEngine
from rag_factory.retrivers.refinable_multistep_query_engine import RefinableMultiStepQueryEngine
from llama_index.core.indices.query.query_transform.base import StepDecomposeQueryTransform


def read_args(config_path: Union[str, Path]) -> Tuple[DatasetConfig, LLMConfig, EmbeddingConfig, StorageConfig, RAGConfig]:
Expand Down Expand Up @@ -332,8 +334,15 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") ->
elif rag_config.solution == "multi_modal_rag":
# TODO: Implement Multi-modal RAG solution
raise NotImplementedError("Multi-modal RAG solution is not implemented yet.")
elif rag_config.solution == "tog3":
query_engine = index.as_query_engine()
elif rag_config.solution == "ToG3":
# 使用 RefinableMultiStepQueryEngine 实现 MACER 子图精化
query_transform = StepDecomposeQueryTransform(llm=llm, verbose=True)
query_engine = RefinableMultiStepQueryEngine(
query_engine=index.as_query_engine(),
query_transform=query_transform,
index=index,
num_steps=rag_config.num_steps,
)
else:
raise ValueError(f"Unsupported RAG solution: {rag_config.solution}")

Expand Down