diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java index 7e845de..f31cf4a 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java @@ -1,5 +1,7 @@ package io.github.timemachinelab.controller; +import io.github.timemachinelab.core.qatree.QaTree; +import io.github.timemachinelab.core.qatree.QaTreeDomain; import io.github.timemachinelab.core.session.application.ConversationService; import io.github.timemachinelab.core.session.application.MessageProcessingService; import io.github.timemachinelab.core.session.application.SessionManagementService; @@ -11,6 +13,7 @@ import io.github.timemachinelab.entity.resp.ApiResult; import io.github.timemachinelab.entity.resp.RetryResponse; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.validation.annotation.Validated; @@ -20,6 +23,7 @@ import javax.annotation.Resource; import javax.validation.Valid; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -44,7 +48,9 @@ public class UserInteractionController { @Resource private SessionManagementService sessionManagementService; private final Map sseEmitters = new ConcurrentHashMap<>(); - + @Autowired + private QaTreeDomain qaTreeDomain; + /** * 建立SSE连接 */ @@ -88,8 +94,8 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess // 根据会话状态返回nodeId if (isNewSession) { // 新会话返回根节点ID - connectionData.put("nodeId", "root"); - log.info("新会话返回根节点ID: root - 会话: {}", sessionId); + connectionData.put("nodeId", "1"); + log.info("新会话返回根节点ID: 1 - 会话: {}", sessionId); } else if (session.getQaTree() != null && session.getQaTree().getRoot() != null) { // 已存在会话,返回根节点ID(因为qaTree只有根节点) String rootNodeId = session.getQaTree().getRoot().getId(); @@ -105,8 +111,8 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess } } else { // 兜底情况,返回根节点ID - connectionData.put("nodeId", "root"); - log.info("兜底返回根节点ID: root - 会话: {}", sessionId); + connectionData.put("nodeId", "1"); + log.info("兜底返回根节点ID: 1 - 会话: {}", sessionId); } emitter.send(SseEmitter.event() @@ -219,8 +225,8 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe return ResponseEntity.badRequest().body("现有会话必须提供nodeId"); } log.info("新建会话的第一个问题 - 会话: {}", session.getSessionId()); - } else if ("root".equals(nodeId)) { - // nodeId为'root',表示这是根节点的回答 + } else if ("1".equals(nodeId)) { + // nodeId为'1',表示这是根节点的回答 if (session.getQaTree() == null || session.getQaTree().getRoot() == null) { log.info("根节点回答,但qaTree未初始化 - 会话: {}", session.getSessionId()); // 允许继续处理,后续会创建qaTree @@ -242,6 +248,29 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe return ResponseEntity.badRequest().body("答案格式不正确"); } + QaTree qaTree = session.getQaTree(); + + // 根据问题类型获取正确的答案数据 + Object answerData; + switch (request.getQuestionType().toLowerCase()) { + case "input": + answerData = request.getInputAnswer(); + break; + case "single": + case "multi": + answerData = request.getChoiceAnswer(); + break; + case "form": + answerData = request.getFormAnswer(); + break; + default: + log.warn("未知的问题类型: {}", request.getQuestionType()); + answerData = request.getAnswerString(); + break; + } + + qaTreeDomain.updateNodeAnswer(qaTree, request.getNodeId(), answerData); + // 4. 处理答案并转换为消息 String processedMessage = messageProcessingService.preprocessMessage( null, // 没有额外的原始消息 @@ -256,6 +285,7 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe response -> sendSseMessage(session.getSessionId(), response) ); + return ResponseEntity.ok("答案处理成功"); } catch (Exception e) { @@ -266,6 +296,7 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe /** * 通过SSE发送消息给客户端 + * 在AI回复时创建QA节点,填入question,answer留空等用户提交后再更新 * * @param sessionId 会话ID * @param response 消息响应对象 @@ -274,13 +305,63 @@ private void sendSseMessage(String sessionId, QuestionGenerationOperation.Questi SseEmitter emitter = sseEmitters.get(sessionId); if (emitter != null) { try { + String currentNodeId = null; + + // 1. 先将AI生成的新问题添加到QaTree(只填入question,answer留空) + ConversationSession session = sessionManagementService.getSessionById(sessionId); + if (session != null && session.getQaTree() != null && response.getQuestion() != null) { + // 使用QaTreeDomain添加新节点,answer字段会自动为空 + // appendNode方法内部会调用session.getNextNodeId()获取新节点ID + QaTree qaTree = qaTreeDomain.appendNode( + session.getQaTree(), + response.getParentId(), + response.getQuestion(), + session + ); + + // 获取刚刚创建的节点ID(当前计数器的值) + currentNodeId = String.valueOf(session.getNodeIdCounter().get()); + + log.info("AI问题已添加到QaTree - 会话: {}, 父节点: {}, 新节点ID: {}, 问题类型: {}", + sessionId, response.getParentId(), currentNodeId, response.getQuestion().getType()); + } else { + log.warn("无法添加问题到QaTree - 会话: {}, session存在: {}, qaTree存在: {}, question存在: {}", + sessionId, session != null, + session != null && session.getQaTree() != null, + response.getQuestion() != null); + } + + // 2. 创建修改后的响应对象,包含currentNodeId和parentNodeId + Map modifiedResponse = new HashMap<>(); + modifiedResponse.put("question", response.getQuestion()); + modifiedResponse.put("currentNodeId", currentNodeId != null ? currentNodeId : response.getParentId()); + modifiedResponse.put("parentNodeId", response.getParentId()); + + // 3. 发送SSE消息给前端 emitter.send(SseEmitter.event() .name("message") - .data(response)); - log.info("SSE消息发送成功 - 会话: {}, 消息: {}", sessionId, response); + .data(modifiedResponse)); + log.info("SSE消息发送成功 - 会话: {}, 当前节点ID: {}", sessionId, currentNodeId); } catch (IOException e) { log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage()); sseEmitters.remove(sessionId); + } catch (Exception e) { + log.error("添加问题到QaTree失败 - 会话: {}, 错误: {}", sessionId, e.getMessage()); + // 即使QaTree更新失败,仍然发送SSE消息给前端 + try { + Map fallbackResponse = new HashMap<>(); + fallbackResponse.put("question", response.getQuestion()); + fallbackResponse.put("currentNodeId", response.getParentId()); // 使用parentId作为fallback + fallbackResponse.put("parentNodeId", response.getParentId()); + + emitter.send(SseEmitter.event() + .name("message") + .data(fallbackResponse)); + log.info("SSE消息发送成功(QaTree更新失败但消息已发送) - 会话: {}", sessionId); + } catch (IOException ioException) { + log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, ioException.getMessage()); + sseEmitters.remove(sessionId); + } } } else { log.warn("SSE连接不存在 - 会话: {}", sessionId); diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java index 5ffa0be..1ecccf3 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java @@ -24,14 +24,13 @@ public QaTree createTree(String userStartQuestion) { /** * 使用ConversationSession的自增ID创建QaTree - * @param userStartQuestion 用户开始问题 + * @param question 用户开始问题 * @param session 会话对象,用于获取自增ID * @return 创建的QaTree */ - public QaTree createTree(String userStartQuestion, ConversationSession session) { + public QaTree createTree(String question, ConversationSession session) { InputQuestion startQA = new InputQuestion(); - startQA.setQuestion(userStartQuestion); - startQA.setAnswer(userStartQuestion); + startQA.setQuestion(question); // 使用会话的自增ID创建根节点 String rootNodeId = session.getNextNodeId(); QaTreeNode startNode = new QaTreeNode(startQA, rootNodeId); diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java index d00d7d4..ac390ff 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java @@ -119,7 +119,7 @@ public ConversationSession createNewSession(String userId) { ConversationSession session = new ConversationSession(userId, newSessionId, null); // 使用会话的自增ID创建QaTree,确保根节点ID=1 - QaTree tree = qaTreeDomain.createTree("default", session); + QaTree tree = qaTreeDomain.createTree("你好,我有什么可以帮你?", session); // 设置QaTree到会话中 session.setQaTree(tree); diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java index b132483..5b197d7 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java @@ -71,8 +71,8 @@ private void updateQaTreeWithAnswer(ConversationSession session, UnifiedAnswerRe } String nodeId = request.getNodeId(); - // 如果nodeId为'root',使用根节点ID - if ("root".equals(nodeId) && qaTree.getRoot() != null) { + // 如果nodeId为'1'(根节点),使用根节点ID + if ("1".equals(nodeId) && qaTree.getRoot() != null) { nodeId = qaTree.getRoot().getId(); } diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java index 18ba489..29fa0b2 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java @@ -1,17 +1,26 @@ package io.github.timemachinelab.core.session.domain.entity; import io.github.timemachinelab.core.qatree.*; +import lombok.Data; import lombok.Getter; +import lombok.Setter; import java.time.LocalDateTime; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; -@Getter +@Data public class ConversationSession { private final String sessionId; private final String userId; + /** + * -- SETTER -- + * 设置QaTree(仅用于初始化) + * + * @param qaTree QA树对象 + */ + @Setter private QaTree qaTree; // 移除final,允许后续设置 private final LocalDateTime createTime; private LocalDateTime updateTime; @@ -36,10 +45,10 @@ public String getNextNodeId() { } /** - * 设置QaTree(仅用于初始化) - * @param qaTree QA树对象 + * 获取节点ID计数器 + * @return 节点ID计数器 */ - public void setQaTree(QaTree qaTree) { - this.qaTree = qaTree; + public AtomicInteger getNodeIdCounter() { + return nodeIdCounter; } } \ No newline at end of file diff --git a/prompto-lab-ui/src/components/Chat/AIChatPage.vue b/prompto-lab-ui/src/components/Chat/AIChatPage.vue index c1fae86..69559b2 100644 --- a/prompto-lab-ui/src/components/Chat/AIChatPage.vue +++ b/prompto-lab-ui/src/components/Chat/AIChatPage.vue @@ -250,22 +250,22 @@ const handleSSEMessage = (response: any) => { console.log('会话已建立:', session.value) - // 后端总是会返回nodeId,新会话返回'root',已存在会话返回实际的nodeId + // 后端总是会返回nodeId,新会话返回'1',已存在会话返回实际的nodeId if (response.nodeId) { currentNodeId.value = response.nodeId console.log('会话节点ID:', response.nodeId) // 如果是根节点,初始化根节点 - if (response.nodeId === 'root') { + if (response.nodeId === '1') { const rootNode: ConversationNode = { - id: 'root', + id: '1', content: '您好!我是AI助手,有什么可以帮助您的吗?', type: 'assistant', timestamp: new Date(), children: [], isActive: true } - conversationTree.value.set('root', rootNode) + conversationTree.value.set('1', rootNode) } } @@ -295,14 +295,52 @@ const handleSSEMessage = (response: any) => { // 这是新的问题格式 currentQuestion.value = response.question - // 更新当前节点ID为问题的parentId - if (response.parentId) { - currentNodeId.value = response.parentId - console.log('更新当前节点ID为:', response.parentId) + // 更新当前节点ID为新创建的问题节点ID + if (response.currentNodeId) { + // 创建问题节点并添加到对话树 + const questionContent = `${response.question.question}${response.question.desc ? '\n' + response.question.desc : ''}` + + const questionNode: ConversationNode = { + id: response.currentNodeId, + content: questionContent, + type: 'assistant', + timestamp: new Date(), + parentId: response.parentNodeId, + children: [], + isActive: true + } + + // 更新父节点的children数组 + if (response.parentNodeId) { + const parentNode = conversationTree.value.get(response.parentNodeId) + if (parentNode) { + // 将父节点的其他子节点设为非活跃状态 + parentNode.children.forEach(childId => { + const childNode = conversationTree.value.get(childId) + if (childNode) { + setNodeAndDescendantsInactive(childId) + } + }) + parentNode.children.push(response.currentNodeId) + } + } + + // 添加新问题节点到对话树 + conversationTree.value.set(response.currentNodeId, questionNode) + currentNodeId.value = response.currentNodeId + console.log('更新当前节点ID为:', response.currentNodeId) + + // 在聊天界面显示问题内容 + addAIMessage(response.currentNodeId, questionContent) + } + + // 记录父节点ID,用于后续构建树形关系图 + if (response.parentNodeId) { + console.log('父节点ID:', response.parentNodeId) } isLoading.value = false - console.log('收到新格式问题:', response.question, '父节点ID:', response.parentId) + console.log('收到新格式问题:', response.question, '当前节点ID:', response.currentNodeId, '父节点ID:', response.parentNodeId) return } @@ -639,10 +677,13 @@ const handleSubmitAnswer = async (answerData: any) => { isLoading.value = true try { + // 保存当前问题节点ID,用于后端验证 + const questionNodeId = currentNodeId.value + // 构建统一答案请求,必须包含sessionId和正确的nodeId const request: UnifiedAnswerRequest = { sessionId: session.value.sessionId, // 必需的sessionId - nodeId: currentNodeId.value, // 当前节点ID,用于后端验证 + nodeId: questionNodeId, // 问题节点ID,用于后端验证 questionType: currentQuestion.value.type, answer: answerData, userId: session.value.userId // 必需的userId @@ -708,10 +749,11 @@ const handleSubmitAnswer = async (answerData: any) => { } conversationTree.value.set(userNodeId, userNode) - currentNodeId.value = userNodeId + // 不更新currentNodeId为用户节点ID,保持为问题节点ID直到收到新问题 + // currentNodeId.value = userNodeId - // 清除当前问题状态 - currentQuestion.value = null + // 不清除当前问题状态,保持显示直到收到新问题 + // currentQuestion.value = null toast.success({ title: '提交成功', @@ -799,13 +841,19 @@ const handleBranchDeleted = (nodeId: string) => { deleteNodeAndDescendants(nodeId) if (!conversationTree.value.has(currentNodeId.value)) { - let newCurrentId = 'root' + // 动态查找根节点(没有parentId的节点) + let rootNodeId = '' + let newCurrentId = '' conversationTree.value.forEach((node, id) => { - if (node.isActive && id !== 'root') { + if (!node.parentId) { + rootNodeId = id + } + if (node.isActive) { newCurrentId = id } }) - currentNodeId.value = newCurrentId + // 优先使用活跃节点,否则回退到根节点 + currentNodeId.value = newCurrentId || rootNodeId } } diff --git a/prompto-lab-ui/src/components/Chat/ChatTree.vue b/prompto-lab-ui/src/components/Chat/ChatTree.vue index e634c9f..4eef6f6 100644 --- a/prompto-lab-ui/src/components/Chat/ChatTree.vue +++ b/prompto-lab-ui/src/components/Chat/ChatTree.vue @@ -78,7 +78,13 @@ const emit = defineEmits<{ }>() const rootNode = computed(() => { - return props.conversationTree.get('root') + // 动态查找根节点(没有parentId的节点) + for (const [id, node] of props.conversationTree) { + if (!node.parentId) { + return node + } + } + return null }) diff --git a/prompto-lab-ui/src/components/Chat/MindMapTree.vue b/prompto-lab-ui/src/components/Chat/MindMapTree.vue index e7d9cf3..2f94b74 100644 --- a/prompto-lab-ui/src/components/Chat/MindMapTree.vue +++ b/prompto-lab-ui/src/components/Chat/MindMapTree.vue @@ -91,7 +91,7 @@ { }) } - const rootNode = props.conversationTree.get('root') - if (rootNode) { - calculateLayout('root', 0, 0) + // 动态查找根节点(没有parentId的节点) + let rootNodeId = '' + for (const [id, node] of props.conversationTree) { + if (!node.parentId) { + rootNodeId = id + break + } + } + if (rootNodeId) { + calculateLayout(rootNodeId, 0, 0) } return nodes diff --git a/prompto-lab-ui/src/components/Chat/TreeNode.vue b/prompto-lab-ui/src/components/Chat/TreeNode.vue index b19f614..cd0035a 100644 --- a/prompto-lab-ui/src/components/Chat/TreeNode.vue +++ b/prompto-lab-ui/src/components/Chat/TreeNode.vue @@ -18,7 +18,7 @@ {{ truncatedText }} -
+