Skip to content

Commit 2acab4c

Browse files
Merge pull request #97 from CHERRY-ui8/fix/quiz-refactor
refactor: implement QuizGenerator and refactor quiz_and_judge to standard operator
2 parents 6f5fb9c + c0faf3b commit 2acab4c

File tree

12 files changed

+175
-129
lines changed

12 files changed

+175
-129
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ source-roots=
100100

101101
# When enabled, pylint would attempt to guess common misconfiguration and emit
102102
# user-friendly hints instead of false-positive error messages.
103-
suggestion-mode=yes
103+
# suggestion-mode=yes
104104

105105
# Allow loading of arbitrary C extensions. Extensions are imported into the
106106
# active Python interpreter and may run arbitrary code.

graphgen/graphgen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
221221
self.graph_storage,
222222
self.rephrase_storage,
223223
max_samples,
224+
progress_bar=self.progress_bar,
224225
)
225226

226227
# TODO: assert trainee_llm_client is valid before judge

graphgen/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
AtomicGenerator,
55
CoTGenerator,
66
MultiHopGenerator,
7+
QuizGenerator,
78
VQAGenerator,
89
)
910
from .kg_builder import LightRAGKGBuilder, MMKGBuilder

graphgen/models/generator/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from .atomic_generator import AtomicGenerator
33
from .cot_generator import CoTGenerator
44
from .multi_hop_generator import MultiHopGenerator
5+
from .quiz_generator import QuizGenerator
56
from .vqa_generator import VQAGenerator
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Any
2+
3+
from graphgen.bases import BaseGenerator
4+
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
5+
from graphgen.utils import detect_main_language, logger
6+
7+
8+
class QuizGenerator(BaseGenerator):
9+
"""
10+
Quiz Generator rephrases given descriptions to create quiz questions.
11+
"""
12+
13+
@staticmethod
14+
def build_prompt(
15+
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
16+
) -> str:
17+
"""
18+
Build prompt for rephrasing the description.
19+
:param batch: A tuple containing (nodes, edges) where nodes/edges
20+
contain description information
21+
:return: Prompt string
22+
"""
23+
# Extract description from batch
24+
# For quiz generator, we expect a special format where
25+
# the description is passed as the first node's description
26+
nodes, edges = batch
27+
if nodes:
28+
description = nodes[0][1].get("description", "")
29+
template_type = nodes[0][1].get("template_type", "TEMPLATE")
30+
elif edges:
31+
description = edges[0][2].get("description", "")
32+
template_type = edges[0][2].get("template_type", "TEMPLATE")
33+
else:
34+
raise ValueError("Batch must contain at least one node or edge with description")
35+
36+
return QuizGenerator.build_prompt_for_description(description, template_type)
37+
38+
@staticmethod
39+
def build_prompt_for_description(description: str, template_type: str = "TEMPLATE") -> str:
40+
"""
41+
Build prompt for rephrasing a single description.
42+
:param description: The description to rephrase
43+
:param template_type: Either "TEMPLATE" (same meaning) or "ANTI_TEMPLATE" (opposite meaning)
44+
:return: Prompt string
45+
"""
46+
language = detect_main_language(description)
47+
prompt = DESCRIPTION_REPHRASING_PROMPT[language][template_type].format(
48+
input_sentence=description
49+
)
50+
return prompt
51+
52+
@staticmethod
53+
def parse_rephrased_text(response: str) -> str:
54+
"""
55+
Parse the rephrased text from the response.
56+
:param response:
57+
:return:
58+
"""
59+
rephrased_text = response.strip().strip('"')
60+
logger.debug("Rephrased Text: %s", rephrased_text)
61+
return rephrased_text
62+
63+
@staticmethod
64+
def parse_response(response: str) -> Any:
65+
"""
66+
Parse the LLM response. For quiz generator, this returns the rephrased text.
67+
:param response: LLM response
68+
:return: Rephrased text
69+
"""
70+
return QuizGenerator.parse_rephrased_text(response)

graphgen/operators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
from .extract import extract_info
33
from .generate import generate_qas
44
from .init import init_llm
5-
from .judge import judge_statement
65
from .partition import partition_kg
7-
from .quiz import quiz
6+
from .quiz_and_judge import judge_statement, quiz
87
from .read import read_files
98
from .search import search_all
109
from .split import chunk_documents

graphgen/operators/generate/generate_qas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any
22

3+
import gradio as gr
4+
35
from graphgen.bases import BaseLLMWrapper
46
from graphgen.models import (
57
AggregatedGenerator,
@@ -19,7 +21,7 @@ async def generate_qas(
1921
]
2022
],
2123
generation_config: dict,
22-
progress_bar=None,
24+
progress_bar: gr.Progress = None,
2325
) -> list[dict[str, Any]]:
2426
"""
2527
Generate question-answer pairs based on nodes and edges.

graphgen/operators/quiz.py

Lines changed: 0 additions & 123 deletions
This file was deleted.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .judge import judge_statement
2+
from .quiz import quiz
File renamed without changes.

0 commit comments

Comments
 (0)