diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java index 0417800..7e845de 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java @@ -49,41 +49,100 @@ public class UserInteractionController { * 建立SSE连接 */ @GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public SseEmitter streamConversation(@RequestParam(required = false) String sessionId) { - log.info("建立SSE连接 - 会话ID: {}", sessionId); + public SseEmitter streamConversation(@RequestParam(required = false) String sessionId, + @RequestParam String userId) { + log.info("建立SSE连接 - 会话ID: {}, 用户ID: {}", sessionId, userId); - if(sessionId == null || sessionId.isEmpty()) { - sessionId = UUID.randomUUID().toString(); - } - SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); - sseEmitters.put(sessionId, emitter); + boolean isNewSession = false; + ConversationSession session = null; - // 连接建立时发送欢迎消息 try { + if (sessionId == null || sessionId.isEmpty()) { + // 新建会话 + session = sessionManagementService.createNewSession(userId); + sessionId = session.getSessionId(); + isNewSession = true; + log.info("创建新会话 - 用户ID: {}, 会话ID: {}", userId, sessionId); + } else { + // 使用现有会话 + session = sessionManagementService.validateAndGetSession(userId, sessionId); + if (session == null) { + log.warn("会话不存在或无效 - 用户ID: {}, 会话ID: {}", userId, sessionId); + // 创建新会话作为fallback + session = sessionManagementService.createNewSession(userId != null ? userId : "anonymous_" + UUID.randomUUID().toString().substring(0, 8)); + sessionId = session.getSessionId(); + isNewSession = true; + } + } + + SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); + sseEmitters.put(sessionId, emitter); + + // 连接建立时发送会话信息 + Map connectionData = new ConcurrentHashMap<>(); + connectionData.put("sessionId", sessionId); + connectionData.put("userId", session.getUserId()); + connectionData.put("isNewSession", isNewSession); + connectionData.put("timestamp", System.currentTimeMillis()); + + // 根据会话状态返回nodeId + if (isNewSession) { + // 新会话返回根节点ID + connectionData.put("nodeId", "root"); + log.info("新会话返回根节点ID: root - 会话: {}", sessionId); + } else if (session.getQaTree() != null && session.getQaTree().getRoot() != null) { + // 已存在会话,返回根节点ID(因为qaTree只有根节点) + String rootNodeId = session.getQaTree().getRoot().getId(); + connectionData.put("nodeId", rootNodeId); + log.info("已存在会话返回根节点ID: {} - 会话: {}", rootNodeId, sessionId); + + // 返回qaTree + try { + String qaTreeJson = io.github.timemachinelab.util.QaTreeSerializeUtil.serialize(session.getQaTree()); + connectionData.put("qaTree", qaTreeJson); + } catch (Exception e) { + log.error("序列化qaTree失败: {}", e.getMessage()); + } + } else { + // 兜底情况,返回根节点ID + connectionData.put("nodeId", "root"); + log.info("兜底返回根节点ID: root - 会话: {}", sessionId); + } + emitter.send(SseEmitter.event() - .name("connected") - .data("SSE连接已建立,会话ID: " + sessionId)); - } catch (IOException e) { - log.error("发送欢迎消息失败: {}", e.getMessage()); - } - - // 设置连接事件处理 - String finalSessionId = sessionId; - emitter.onCompletion(() -> { - log.info("SSE连接完成: {}", finalSessionId); - }); - - emitter.onTimeout(() -> { - log.info("SSE连接超时: {}", finalSessionId); - sseEmitters.remove(finalSessionId); - }); - - emitter.onError((ex) -> { - log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage()); - sseEmitters.remove(finalSessionId); - }); - - return emitter; + .name("connected") + .data(connectionData)); + + // 设置连接事件处理 + String finalSessionId = sessionId; + emitter.onCompletion(() -> { + log.info("SSE连接完成: {}", finalSessionId); + }); + + emitter.onTimeout(() -> { + log.info("SSE连接超时: {}", finalSessionId); + sseEmitters.remove(finalSessionId); + }); + + emitter.onError((ex) -> { + log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage()); + sseEmitters.remove(finalSessionId); + }); + + return emitter; + + } catch (Exception e) { + log.error("建立SSE连接失败: {}", e.getMessage()); + SseEmitter errorEmitter = new SseEmitter(Long.MAX_VALUE); + try { + errorEmitter.send(SseEmitter.event() + .name("error") + .data("连接建立失败: " + e.getMessage())); + } catch (IOException ioException) { + log.error("发送错误消息失败: {}", ioException.getMessage()); + } + return errorEmitter; + } } /** @@ -131,15 +190,50 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe request.getNodeId(), request.getQuestionType()); - // 1. 会话管理和验证 + // 1. 强制要求sessionId + if (request.getSessionId() == null || request.getSessionId().trim().isEmpty()) { + log.warn("缺少必需的sessionId参数"); + return ResponseEntity.badRequest().body("sessionId参数是必需的"); + } + + // 2. 会话管理和验证 String userId = request.getUserId(); + if (userId == null || userId.trim().isEmpty()) { + log.warn("缺少必需的userId参数"); + return ResponseEntity.badRequest().body("userId参数是必需的"); + } - ConversationSession session = sessionManagementService.getOrCreateSession(userId, request.getSessionId()); + // 3. 验证会话是否存在 + ConversationSession session = sessionManagementService.validateAndGetSession(userId, request.getSessionId()); + if (session == null) { + log.warn("会话不存在或无效 - 用户ID: {}, 会话ID: {}", userId, request.getSessionId()); + return ResponseEntity.badRequest().body("会话不存在或无效"); + } - // 2. 验证nodeId是否属于该会话 - if (request.getNodeId() != null && !sessionManagementService.validateNodeId(session.getSessionId(), request.getNodeId())) { - log.warn("无效的节点ID - 会话: {}, 节点: {}", session.getSessionId(), request.getNodeId()); - return ResponseEntity.badRequest().body("无效的节点ID"); + // 4. nodeId验证逻辑 + String nodeId = request.getNodeId(); + if (nodeId == null || nodeId.trim().isEmpty()) { + // nodeId为空,表示这是新建会话的第一个问题 + if (session.getQaTree() != null && session.getQaTree().getRoot() != null) { + log.warn("会话已存在qaTree,但nodeId为空 - 会话: {}", session.getSessionId()); + return ResponseEntity.badRequest().body("现有会话必须提供nodeId"); + } + log.info("新建会话的第一个问题 - 会话: {}", session.getSessionId()); + } else if ("root".equals(nodeId)) { + // nodeId为'root',表示这是根节点的回答 + if (session.getQaTree() == null || session.getQaTree().getRoot() == null) { + log.info("根节点回答,但qaTree未初始化 - 会话: {}", session.getSessionId()); + // 允许继续处理,后续会创建qaTree + } else { + log.info("根节点回答 - 会话: {}", session.getSessionId()); + } + } else { + // nodeId不为空且不是'root',验证是否属于该会话 + if (!sessionManagementService.validateNodeId(session.getSessionId(), nodeId)) { + log.warn("无效的节点ID - 会话: {}, 节点: {}", session.getSessionId(), nodeId); + return ResponseEntity.badRequest().body("无效的节点ID"); + } + log.info("更新现有节点 - 会话: {}, 节点: {}", session.getSessionId(), nodeId); } // 3. 验证答案格式 diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java index f54d2f7..5ffa0be 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java @@ -2,8 +2,14 @@ import io.github.timemachinelab.core.question.BaseQuestion; import io.github.timemachinelab.core.question.InputQuestion; +import io.github.timemachinelab.core.question.SingleChoiceQuestion; +import io.github.timemachinelab.core.question.MultipleChoiceQuestion; +import io.github.timemachinelab.core.question.FormQuestion; +import io.github.timemachinelab.core.session.domain.entity.ConversationSession; import org.springframework.stereotype.Component; +import java.util.List; + @Component public class QaTreeDomain { @@ -15,9 +21,72 @@ public QaTree createTree(String userStartQuestion) { return new QaTree(startNode); } + + /** + * 使用ConversationSession的自增ID创建QaTree + * @param userStartQuestion 用户开始问题 + * @param session 会话对象,用于获取自增ID + * @return 创建的QaTree + */ + public QaTree createTree(String userStartQuestion, ConversationSession session) { + InputQuestion startQA = new InputQuestion(); + startQA.setQuestion(userStartQuestion); + startQA.setAnswer(userStartQuestion); + // 使用会话的自增ID创建根节点 + String rootNodeId = session.getNextNodeId(); + QaTreeNode startNode = new QaTreeNode(startQA, rootNodeId); + + return new QaTree(startNode); + } public QaTree appendNode(QaTree tree, String parentId, BaseQuestion qa) { tree.addNode(parentId, new QaTreeNode(qa)); return tree; } + + /** + * 使用ConversationSession的自增ID向QaTree添加节点 + * @param tree QA树 + * @param parentId 父节点ID + * @param qa 问题对象 + * @param session 会话对象,用于获取自增ID + * @return 更新后的QaTree + */ + public QaTree appendNode(QaTree tree, String parentId, BaseQuestion qa, ConversationSession session) { + String nodeId = session.getNextNodeId(); + tree.addNode(parentId, new QaTreeNode(qa, nodeId)); + return tree; + } + + /** + * 更新指定节点的答案 + * @param tree QA树 + * @param nodeId 节点ID + * @param answer 新的答案内容 + * @return 是否更新成功 + */ + public boolean updateNodeAnswer(QaTree tree, String nodeId, Object answer) { + QaTreeNode node = tree.getNodeById(nodeId); + if (node == null) { + return false; + } + + BaseQuestion qa = node.getQa(); + if (qa == null) { + return false; + } + + // 根据问题类型设置答案 + if (qa instanceof InputQuestion) { + ((InputQuestion) qa).setAnswer((String) answer); + } else if (qa instanceof SingleChoiceQuestion) { + ((SingleChoiceQuestion) qa).setAnswer((List) answer); + } else if (qa instanceof MultipleChoiceQuestion) { + ((MultipleChoiceQuestion) qa).setAnswer((List) answer); + } else if (qa instanceof FormQuestion) { + ((FormQuestion) qa).setAnswer((List) answer); + } + + return true; + } } diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeNode.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeNode.java index ebaa84b..de429c3 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeNode.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeNode.java @@ -20,7 +20,23 @@ public QaTreeNode(BaseQuestion qa) { this.children = new HashMap<>(); this.qa = qa; } + + /** + * 使用指定ID创建节点的构造函数 + * @param qa 问题对象 + * @param nodeId 指定的节点ID + */ + public QaTreeNode(BaseQuestion qa, String nodeId) { + this.id = nodeId; + this.children = new HashMap<>(); + this.qa = qa; + } + /** + * @deprecated 此方法使用UUID生成节点ID,不符合自增ID规范。 + * 请使用 QaTreeDomain.appendNode(tree, parentId, qa, session) 方法代替。 + */ + @Deprecated public void append(BaseQuestion qa) { QaTreeNode node = new QaTreeNode(qa); this.append(node); diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java index 3bc82d3..7ca17e5 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java @@ -35,7 +35,7 @@ public class ConversationService { public void processUserMessage(String userId, String userMessage, Consumer sseCallback) { - ConversationSession session = sessionManagementService.getUserCurrentSession(userId); + ConversationSession session = sessionManagementService.getUserLatestSession(userId); if (session == null) { log.warn("会话不存在"); return; diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java index b498aef..d00d7d4 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java @@ -8,9 +8,12 @@ import org.springframework.util.StringUtils; import javax.annotation.Resource; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.Map; import java.util.UUID; +import java.util.ArrayList; +import java.util.stream.Collectors; /** * 会话管理服务 @@ -23,8 +26,8 @@ @Slf4j public class SessionManagementService { - // 用户ID到会话ID的映射 - private final Map userSessionMap = new ConcurrentHashMap<>(); + // 用户ID到会话ID列表的映射(一对多关系) + private final Map> userSessionMap = new ConcurrentHashMap<>(); // 会话存储 private final Map sessions = new ConcurrentHashMap<>(); @@ -73,80 +76,149 @@ public boolean validateNodeId(String sessionId, String nodeId) { } /** - * 获取用户当前会话 + * 获取用户的所有会话 * * @param userId 用户ID - * @return 会话对象 + * @return 会话列表 + */ + public List getUserSessions(String userId) { + List sessionIds = userSessionMap.get(userId); + if (sessionIds == null || sessionIds.isEmpty()) { + return new ArrayList<>(); + } + return sessionIds.stream() + .map(sessions::get) + .filter(session -> session != null) + .collect(Collectors.toList()); + } + + /** + * 获取用户最新的会话(最后创建的会话) + * + * @param userId 用户ID + * @return 最新的会话对象,如果没有会话则返回null */ - public ConversationSession getUserCurrentSession(String userId) { - String sessionId = userSessionMap.get(userId); - return sessionId != null ? sessions.get(sessionId) : null; + public ConversationSession getUserLatestSession(String userId) { + List sessionIds = userSessionMap.get(userId); + if (sessionIds == null || sessionIds.isEmpty()) { + return null; + } + // 返回列表中最后一个会话(最新创建的) + String latestSessionId = sessionIds.get(sessionIds.size() - 1); + return sessions.get(latestSessionId); } /** * 创建新会话 */ - private ConversationSession createNewSession(String userId) { - // 如果用户已有会话,先清理旧的映射 - String oldSessionId = userSessionMap.get(userId); - if (oldSessionId != null) { - sessions.remove(oldSessionId); - log.info("清理用户旧会话 - 用户: {}, 旧会话: {}", userId, oldSessionId); - } - + public ConversationSession createNewSession(String userId) { // 生成新的sessionId String newSessionId = UUID.randomUUID().toString(); - QaTree tree = createDefaultQaTree(); - // 创建新会话 - ConversationSession session = new ConversationSession(userId, newSessionId, tree); + // 先创建会话对象(qaTree为null) + ConversationSession session = new ConversationSession(userId, newSessionId, null); + + // 使用会话的自增ID创建QaTree,确保根节点ID=1 + QaTree tree = qaTreeDomain.createTree("default", session); + + // 设置QaTree到会话中 + session.setQaTree(tree); - // 建立映射关系 - userSessionMap.put(userId, session.getSessionId()); + // 建立映射关系 - 添加到用户的会话列表中 + userSessionMap.computeIfAbsent(userId, k -> new ArrayList<>()).add(session.getSessionId()); sessions.put(session.getSessionId(), session); - log.info("创建新会话 - 用户: {}, 会话: {}", userId, session.getSessionId()); + log.info("创建新会话 - 用户: {}, 会话: {}, 根节点ID: 1", userId, session.getSessionId()); return session; } /** * 验证并获取现有会话 + * 如果会话不存在或不属于该用户,返回null */ public ConversationSession validateAndGetSession(String userId, String sessionId) { ConversationSession session = sessions.get(sessionId); if (session == null) { - log.warn("会话不存在,创建新会话 - 用户: {}, 请求会话: {}", userId, sessionId); - return createNewSession(userId); + log.warn("会话不存在 - 用户: {}, 请求会话: {}", userId, sessionId); + return null; } // 验证会话是否属于该用户 if (!session.getUserId().equals(userId)) { - log.warn("会话不属于该用户,创建新会话 - 用户: {}, 会话: {}, 会话所有者: {}", + log.warn("会话不属于该用户 - 用户: {}, 会话: {}, 会话所有者: {}", userId, sessionId, session.getUserId()); - return createNewSession(userId); + return null; } - // 更新用户会话映射(防止映射不一致) - userSessionMap.put(userId, sessionId); + // 确保用户会话映射中包含该会话ID(防止映射不一致) + List userSessions = userSessionMap.computeIfAbsent(userId, k -> new ArrayList<>()); + if (!userSessions.contains(sessionId)) { + userSessions.add(sessionId); + } log.debug("验证会话成功 - 用户: {}, 会话: {}", userId, sessionId); return session; } /** - * 清理会话 + * 清理指定会话 * * @param sessionId 会话ID */ public void removeSession(String sessionId) { ConversationSession session = sessions.remove(sessionId); if (session != null) { - userSessionMap.remove(session.getUserId()); + // 从用户会话列表中移除该会话ID + List userSessions = userSessionMap.get(session.getUserId()); + if (userSessions != null) { + userSessions.remove(sessionId); + // 如果用户没有其他会话了,移除整个映射 + if (userSessions.isEmpty()) { + userSessionMap.remove(session.getUserId()); + } + } log.info("清理会话 - 用户: {}, 会话: {}", session.getUserId(), sessionId); } } + /** + * 清理用户的所有会话 + * + * @param userId 用户ID + */ + public void removeAllUserSessions(String userId) { + List sessionIds = userSessionMap.remove(userId); + if (sessionIds != null) { + for (String sessionId : sessionIds) { + sessions.remove(sessionId); + } + log.info("清理用户所有会话 - 用户: {}, 会话数量: {}", userId, sessionIds.size()); + } + } + + /** + * 根据会话ID获取会话对象 + * + * @param sessionId 会话ID + * @return 会话对象,如果不存在则返回null + */ + public ConversationSession getSessionById(String sessionId) { + return sessions.get(sessionId); + } + + /** + * 检查用户是否拥有指定的会话 + * + * @param userId 用户ID + * @param sessionId 会话ID + * @return 是否拥有该会话 + */ + public boolean userOwnsSession(String userId, String sessionId) { + List userSessions = userSessionMap.get(userId); + return userSessions != null && userSessions.contains(sessionId); + } + /** * 获取会话统计信息 */ @@ -154,6 +226,18 @@ public Map getSessionStats() { Map stats = new ConcurrentHashMap<>(); stats.put("totalSessions", sessions.size()); stats.put("activeUsers", userSessionMap.size()); + + // 计算每个用户的会话数量分布 + Map userSessionCounts = new ConcurrentHashMap<>(); + int totalUserSessions = 0; + for (Map.Entry> entry : userSessionMap.entrySet()) { + int sessionCount = entry.getValue().size(); + userSessionCounts.put(entry.getKey(), sessionCount); + totalUserSessions += sessionCount; + } + + stats.put("userSessionCounts", userSessionCounts); + stats.put("averageSessionsPerUser", userSessionMap.isEmpty() ? 0 : (double) totalUserSessions / userSessionMap.size()); stats.put("timestamp", System.currentTimeMillis()); return stats; } diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java index e39aef2..b132483 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java @@ -40,6 +40,12 @@ public String processAnswer(UnifiedAnswerRequest request) { return "无效的回答格式"; } + // 获取会话并更新qaTree + ConversationSession session = sessionManagementService.validateAndGetSession(request.getUserId(), request.getSessionId()); + if (session != null) { + updateQaTreeWithAnswer(session, request); + } + // 将答案转换为可读文本 String readableText = request.toReadableText(); @@ -53,6 +59,56 @@ public String processAnswer(UnifiedAnswerRequest request) { return message.toString(); } + /** + * 更新qaTree中的答案 + */ + private void updateQaTreeWithAnswer(ConversationSession session, UnifiedAnswerRequest request) { + try { + QaTree qaTree = session.getQaTree(); + if (qaTree == null) { + log.warn("会话的qaTree为空,无法更新答案 - 会话ID: {}", session.getSessionId()); + return; + } + + String nodeId = request.getNodeId(); + // 如果nodeId为'root',使用根节点ID + if ("root".equals(nodeId) && qaTree.getRoot() != null) { + nodeId = qaTree.getRoot().getId(); + } + + // 根据问题类型准备答案数据 + Object answerData = prepareAnswerData(request); + + // 更新节点答案 + boolean updated = qaTreeDomain.updateNodeAnswer(qaTree, nodeId, answerData); + if (updated) { + log.info("成功更新qaTree节点答案 - 会话ID: {}, 节点ID: {}", session.getSessionId(), nodeId); + // 会话数据已在内存中更新,无需额外保存操作 + } else { + log.warn("更新qaTree节点答案失败 - 会话ID: {}, 节点ID: {}", session.getSessionId(), nodeId); + } + } catch (Exception e) { + log.error("更新qaTree答案时发生异常 - 会话ID: {}, 错误: {}", session.getSessionId(), e.getMessage(), e); + } + } + + /** + * 根据问题类型准备答案数据 + */ + private Object prepareAnswerData(UnifiedAnswerRequest request) { + switch (request.getQuestionType().toLowerCase()) { + case "input": + return request.getInputAnswer(); + case "single": + case "multi": + return request.getChoiceAnswer(); + case "form": + return request.getFormAnswer(); + default: + return request.getAnswerString(); + } + } + @Override public String preprocessMessage(String originalMessage, UnifiedAnswerRequest answerRequest,ConversationSession conversationSession) { try { diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java index 33d4642..18ba489 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/domain/entity/ConversationSession.java @@ -5,16 +5,20 @@ import java.time.LocalDateTime; import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; @Getter public class ConversationSession { private final String sessionId; private final String userId; - private final QaTree qaTree; + private QaTree qaTree; // 移除final,允许后续设置 private final LocalDateTime createTime; private LocalDateTime updateTime; + // 节点ID自增计数器,从1开始 + private final AtomicInteger nodeIdCounter = new AtomicInteger(0); + public ConversationSession(String userId, String sessionId, QaTree qaTree) { this.qaTree = qaTree; this.sessionId = sessionId; @@ -22,4 +26,20 @@ public ConversationSession(String userId, String sessionId, QaTree qaTree) { this.createTime = LocalDateTime.now(); this.updateTime = LocalDateTime.now(); } + + /** + * 获取下一个节点ID(自增) + * @return 自增的节点ID字符串 + */ + public String getNextNodeId() { + return String.valueOf(nodeIdCounter.incrementAndGet()); + } + + /** + * 设置QaTree(仅用于初始化) + * @param qaTree QA树对象 + */ + public void setQaTree(QaTree qaTree) { + this.qaTree = qaTree; + } } \ No newline at end of file diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/infrastructure/web/ConversationController.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/infrastructure/web/ConversationController.java deleted file mode 100644 index c8f6bef..0000000 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/infrastructure/web/ConversationController.java +++ /dev/null @@ -1,79 +0,0 @@ -package io.github.timemachinelab.core.session.infrastructure.web; - -import io.github.timemachinelab.core.session.application.ConversationService; -import io.github.timemachinelab.core.session.application.MessageProcessingService; -import io.github.timemachinelab.core.session.application.SessionManagementService; -import io.github.timemachinelab.core.session.domain.entity.ConversationSession; -import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageRequest; -import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest; -import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.springframework.http.MediaType; -import org.springframework.http.ResponseEntity; -import org.springframework.validation.annotation.Validated; -import org.springframework.web.bind.annotation.*; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; - -import java.io.IOException; -import java.util.concurrent.ConcurrentHashMap; -import java.util.Map; - -@RestController -@RequestMapping("/api/conversation") -@RequiredArgsConstructor -@Slf4j -public class ConversationController { - - private final ConversationService conversationService; - private final MessageProcessingService messageProcessingService; - private final SessionManagementService sessionManagementService; - private final Map sseEmitters = new ConcurrentHashMap<>(); - - @GetMapping(value = "/sse/{sessionId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public SseEmitter streamConversation(@PathVariable String sessionId) { - SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); - sseEmitters.put(sessionId, emitter); - - emitter.onCompletion(() -> sseEmitters.remove(sessionId)); - emitter.onTimeout(() -> sseEmitters.remove(sessionId)); - emitter.onError((ex) -> sseEmitters.remove(sessionId)); - - return emitter; - } - - - - /** - * 从MessageRequest中提取用户ID - */ - private String extractUserIdFromMessageRequest(MessageRequest request) { - // 临时实现:使用会话ID作为用户ID(仅用于演示) - return request.getSessionId() != null ? "user_" + request.getSessionId().hashCode() : "anonymous_" + System.currentTimeMillis(); - } - - /** - * 从请求中提取用户ID - * TODO: 根据您的认证机制实现此方法 - */ - private String extractUserIdFromRequest(UnifiedAnswerRequest request) { - // 这里需要根据您的认证机制来实现 - // 可能从JWT token、session、或其他方式获取用户ID - - // 临时实现:使用会话ID作为用户ID(仅用于演示) - return request.getSessionId() != null ? "user_" + request.getSessionId().hashCode() : "anonymous_" + System.currentTimeMillis(); - } - - private void sendSseMessage(String sessionId, MessageResponse response) { - SseEmitter emitter = sseEmitters.get(sessionId); - if (emitter != null) { - try { - emitter.send(SseEmitter.event() - .name("message") - .data(response)); - } catch (IOException e) { - sseEmitters.remove(sessionId); - } - } - } -} \ No newline at end of file diff --git a/prompto-lab-ui/src/components/Chat/AIChatPage.vue b/prompto-lab-ui/src/components/Chat/AIChatPage.vue index 605a681..c1fae86 100644 --- a/prompto-lab-ui/src/components/Chat/AIChatPage.vue +++ b/prompto-lab-ui/src/components/Chat/AIChatPage.vue @@ -120,6 +120,12 @@ const eventSource = ref(null) const isConnected = ref(false) const isInitializing = ref(false) +// SSE连接管理 +const connectionTimeout = ref(null) +const activityTimeout = ref(null) +const ACTIVITY_TIMEOUT = 5 * 60 * 1000 // 5分钟不活跃超时 +const lastActivityTime = ref(Date.now()) + // 对话树存储所有节点 const conversationTree = ref>(new Map()) const currentNodeId = ref('') @@ -128,6 +134,64 @@ const isLoading = ref(false) // 问题状态管理 const currentQuestion = ref(null) +// 确保SSE连接唯一性 +const ensureUniqueConnection = () => { + if (eventSource.value) { + console.log('关闭现有SSE连接以确保唯一性') + closeSSE(eventSource.value) + eventSource.value = null + isConnected.value = false + } + + // 清理定时器 + if (connectionTimeout.value) { + clearTimeout(connectionTimeout.value) + connectionTimeout.value = null + } + if (activityTimeout.value) { + clearTimeout(activityTimeout.value) + activityTimeout.value = null + } +} + +// 更新活跃时间 +const updateActivity = () => { + lastActivityTime.value = Date.now() + + // 重置活跃超时定时器 + if (activityTimeout.value) { + clearTimeout(activityTimeout.value) + } + + activityTimeout.value = setTimeout(() => { + console.log('SSE连接因不活跃超时,自动关闭') + closeConnection() + toast.info({ + title: '连接已关闭', + message: '由于长时间无活动,连接已自动关闭', + duration: 3000 + }) + }, ACTIVITY_TIMEOUT) +} + +// 关闭连接 +const closeConnection = () => { + if (eventSource.value) { + closeSSE(eventSource.value) + eventSource.value = null + } + isConnected.value = false + + if (connectionTimeout.value) { + clearTimeout(connectionTimeout.value) + connectionTimeout.value = null + } + if (activityTimeout.value) { + clearTimeout(activityTimeout.value) + activityTimeout.value = null + } +} + // 初始化会话 const initializeSession = async () => { if (isInitializing.value) return @@ -135,40 +199,28 @@ const initializeSession = async () => { isInitializing.value = true try { - // 创建新会话 + // 确保连接唯一性 + ensureUniqueConnection() + + // 生成用户ID(如果没有的话) const userId = 'demo-user-' + Date.now() // 临时用户ID - session.value = await startConversation(userId) - - // 建立SSE连接 + + // 建立SSE连接(不传sessionId,让后端创建新会话) eventSource.value = connectUserInteractionSSE( - session.value.sessionId, + null, // sessionId为null,后端会创建新会话 + userId, handleSSEMessage, handleSSEError ) - isConnected.value = true - - // 初始化根节点 - const rootNode: ConversationNode = { - id: 'root', - content: '您好!我是AI助手,有什么可以帮助您的吗?', - type: 'assistant', - timestamp: new Date(), - children: [], - isActive: true - } - - conversationTree.value.set('root', rootNode) - currentNodeId.value = 'root' - - toast.success({ - title: '会话已建立', - message: '已成功连接到AI助手', - duration: 2000 - }) + // 启动活跃监控 + updateActivity() + + console.log('SSE连接已建立,等待后端返回会话信息...') } catch (error: any) { console.error('初始化会话失败:', error) + isConnected.value = false toast.error({ title: '连接失败', message: '无法连接到AI服务,请刷新页面重试', @@ -180,15 +232,86 @@ const initializeSession = async () => { } // 处理SSE消息 -const handleSSEMessage = (response: MessageResponse) => { +const handleSSEMessage = (response: any) => { console.log('收到SSE消息:', response) + + // 更新活跃时间 + updateActivity() + + // 处理连接建立消息 + if (response.type === 'connected' || response.sessionId) { + // 这是连接建立时的会话信息 + if (response.sessionId) { + session.value = { + sessionId: response.sessionId, + userId: response.userId || 'demo-user-' + Date.now() + } + isConnected.value = true + + console.log('会话已建立:', session.value) + + // 后端总是会返回nodeId,新会话返回'root',已存在会话返回实际的nodeId + if (response.nodeId) { + currentNodeId.value = response.nodeId + console.log('会话节点ID:', response.nodeId) + + // 如果是根节点,初始化根节点 + if (response.nodeId === 'root') { + const rootNode: ConversationNode = { + id: 'root', + content: '您好!我是AI助手,有什么可以帮助您的吗?', + type: 'assistant', + timestamp: new Date(), + children: [], + isActive: true + } + conversationTree.value.set('root', rootNode) + } + } + + // 如果有现有的qaTree,恢复对话树 + if (response.qaTree) { + try { + // 这里需要根据实际的qaTree格式来解析和恢复对话树 + console.log('恢复现有对话树:', response.qaTree) + // TODO: 实现qaTree的解析和恢复逻辑 + } catch (error) { + console.error('恢复对话树失败:', error) + } + } + + toast.success({ + title: '会话已建立', + message: response.isNewSession ? '已创建新会话' : '已连接到现有会话', + duration: 2000 + }) + } + return + } - // 根据消息类型处理 - switch (response.type) { + // 处理其他类型的消息 + // 检查是否是新的问题格式(包含question对象) + if (response.question && response.question.type) { + // 这是新的问题格式 + currentQuestion.value = response.question + + // 更新当前节点ID为问题的parentId + if (response.parentId) { + currentNodeId.value = response.parentId + console.log('更新当前节点ID为:', response.parentId) + } + + isLoading.value = false + console.log('收到新格式问题:', response.question, '父节点ID:', response.parentId) + return + } + + const messageResponse = response as MessageResponse + switch (messageResponse.type) { case 'AI_QUESTION': // 尝试解析问题内容为问题对象 try { - const questionData = JSON.parse(response.content) + const questionData = JSON.parse(messageResponse.content) if (questionData.type && ['input', 'single', 'multi', 'form'].includes(questionData.type)) { currentQuestion.value = questionData isLoading.value = false @@ -198,13 +321,13 @@ const handleSSEMessage = (response: MessageResponse) => { console.log('非JSON格式的问题,作为普通消息处理') } // 如果不是问题格式,作为普通消息处理 - addAIMessage(response.nodeId, response.content) + addAIMessage(messageResponse.nodeId, messageResponse.content) break case 'AI_ANSWER': // 添加对AI_ANSWER类型的处理 - addAIMessage(response.nodeId, response.content) + addAIMessage(messageResponse.nodeId, messageResponse.content) break case 'AI_SELECTION_QUESTION': - addAISelectionMessage(response.nodeId, response.content, response.options || []) + addAISelectionMessage(messageResponse.nodeId, messageResponse.content, messageResponse.options || []) break case 'USER_ANSWER': // 用户消息确认,通常不需要特殊处理 @@ -212,12 +335,19 @@ const handleSSEMessage = (response: MessageResponse) => { case 'SYSTEM_INFO': toast.info({ title: '系统消息', - message: response.content, + message: messageResponse.content, duration: 3000 }) break + case undefined: + // 处理没有type字段的消息 + console.log('收到没有type字段的消息,尝试作为普通消息处理:', response) + if (response.content) { + addAIMessage(response.nodeId || `ai_${Date.now()}`, response.content) + } + break default: - console.warn('未知的消息类型:', response.type, response) + console.warn('未知的消息类型:', messageResponse.type, messageResponse) break } } @@ -226,6 +356,12 @@ const handleSSEMessage = (response: MessageResponse) => { const handleSSEError = (error: Event) => { console.error('SSE连接错误:', error) isConnected.value = false + + // 清理定时器 + if (activityTimeout.value) { + clearTimeout(activityTimeout.value) + activityTimeout.value = null + } toast.error({ title: '连接中断', @@ -233,16 +369,25 @@ const handleSSEError = (error: Event) => { duration: 3000 }) - // 尝试重连 - setTimeout(() => { - if (session.value && !isConnected.value) { - eventSource.value = connectUserInteractionSSE( - session.value.sessionId, - handleSSEMessage, - handleSSEError - ) - } - }, 3000) + // 尝试重连(如果有会话信息) + if (session.value && !isInitializing.value) { + setTimeout(() => { + if (!isConnected.value && !isInitializing.value) { + console.log('尝试重连到现有会话:', session.value?.sessionId) + ensureUniqueConnection() // 确保连接唯一性 + + eventSource.value = connectUserInteractionSSE( + session.value?.sessionId || null, + session.value?.userId || userId, + handleSSEMessage, + handleSSEError + ) + + // 重新启动活跃监控 + updateActivity() + } + }, 3000) + } } // 添加AI消息到对话树 @@ -393,9 +538,15 @@ const handleSendMessage = async (content: string) => { return } + // 更新活跃时间 + updateActivity() + // 重置当前问题状态,进入新的对话 currentQuestion.value = null + // 后端总是返回nodeId,前端也总是传递nodeId + const nodeIdToSend = currentNodeId.value + const userNodeId = `user_${Date.now()}` const userNode: ConversationNode = { id: userNodeId, @@ -426,14 +577,14 @@ const handleSendMessage = async (content: string) => { isLoading.value = true try { - // 发送消息到后端 + // 发送消息到后端,必须包含sessionId const messageRequest: MessageRequest = { - sessionId: session.value.sessionId, + sessionId: session.value.sessionId, // 必需的sessionId content, type: 'USER_TEXT' } - await sendUserMessage(messageRequest) + await sendUserMessage(messageRequest, session.value.userId, nodeIdToSend) // 消息发送成功,等待SSE返回AI回复 console.log('消息已发送,等待AI回复...') @@ -442,11 +593,23 @@ const handleSendMessage = async (content: string) => { console.error('发送消息失败:', error) isLoading.value = false - toast.error({ - title: '发送失败', - message: error.message || '消息发送失败,请重试', - duration: 4000 - }) + // 检查是否是会话相关错误 + if (error.message && (error.message.includes('sessionId') || error.message.includes('会话'))) { + toast.error({ + title: '会话异常', + message: '会话已失效,请刷新页面重新建立连接', + duration: 5000 + }) + // 清理当前会话状态 + session.value = null + closeConnection() + } else { + toast.error({ + title: '发送失败', + message: error.message || '消息发送失败,请重试', + duration: 4000 + }) + } // 发送失败时移除用户消息节点 conversationTree.value.delete(userNodeId) @@ -470,16 +633,19 @@ const handleSubmitAnswer = async (answerData: any) => { return } + // 更新活跃时间 + updateActivity() + isLoading.value = true try { - // 构建统一答案请求 + // 构建统一答案请求,必须包含sessionId和正确的nodeId const request: UnifiedAnswerRequest = { - sessionId: session.value.sessionId, - nodeId: currentNodeId.value, + sessionId: session.value.sessionId, // 必需的sessionId + nodeId: currentNodeId.value, // 当前节点ID,用于后端验证 questionType: currentQuestion.value.type, answer: answerData, - userId: session.value.userId + userId: session.value.userId // 必需的userId } // 调用新的processAnswer接口 @@ -557,11 +723,23 @@ const handleSubmitAnswer = async (answerData: any) => { console.error('提交答案失败:', error) isLoading.value = false - toast.error({ - title: '提交失败', - message: error.message || '答案提交失败,请重试', - duration: 4000 - }) + // 检查是否是会话或节点相关错误 + if (error.message && (error.message.includes('sessionId') || error.message.includes('nodeId') || error.message.includes('会话') || error.message.includes('节点'))) { + toast.error({ + title: '会话异常', + message: '会话或节点状态异常,请刷新页面重新建立连接', + duration: 5000 + }) + // 清理当前会话状态 + session.value = null + closeConnection() + } else { + toast.error({ + title: '提交失败', + message: error.message || '答案提交失败,请重试', + duration: 4000 + }) + } } } diff --git a/prompto-lab-ui/src/services/conversationApi.ts b/prompto-lab-ui/src/services/conversationApi.ts index 491d5c2..648514b 100644 --- a/prompto-lab-ui/src/services/conversationApi.ts +++ b/prompto-lab-ui/src/services/conversationApi.ts @@ -172,13 +172,14 @@ export const connectSSE = (sessionId: string, onMessage: (response: MessageRespo * 发送用户消息到用户交互接口 * 对接后端的用户交互消息接口 */ -export const sendUserMessage = async (request: MessageRequest): Promise => { +export const sendUserMessage = async (request: MessageRequest, userId: string, nodeId?: string): Promise => { // 构建统一答案请求格式 const unifiedRequest: UnifiedAnswerRequest = { sessionId: request.sessionId, + nodeId: nodeId || undefined, // 添加nodeId参数 questionType: 'input', // 普通文本消息作为input类型 answer: request.content, - userId: request.sessionId // 使用sessionId作为userId + userId: userId // 使用正确的userId } const url = `${USER_INTERACTION_BASE}/message` @@ -212,13 +213,24 @@ export const processAnswer = async (request: UnifiedAnswerRequest): Promise void, onError?: (error: Event) => void): EventSource => { - const url = `${USER_INTERACTION_BASE}/sse/${sessionId}` +export const connectUserInteractionSSE = (sessionId: string | null, userId: string, onMessage: (response: MessageResponse) => void, onError?: (error: Event) => void): EventSource => { + const params = new URLSearchParams() + if (sessionId) { + params.append('sessionId', sessionId) + } + params.append('userId', userId) + const url = `${USER_INTERACTION_BASE}/sse?${params.toString()}` const eventSource = new EventSource(url) // 监听连接建立事件 eventSource.addEventListener('connected', (event: MessageEvent) => { console.log('用户交互SSE连接已建立:', event.data) + try { + const response = JSON.parse(event.data) + onMessage(response) + } catch (error) { + console.error('解析连接建立消息失败:', error) + } }) // 监听消息事件 diff --git a/prompto-lab-ui/src/services/userInteractionApi.ts b/prompto-lab-ui/src/services/userInteractionApi.ts index 85b72cd..6c844dd 100644 --- a/prompto-lab-ui/src/services/userInteractionApi.ts +++ b/prompto-lab-ui/src/services/userInteractionApi.ts @@ -76,10 +76,19 @@ export const sendAnswer = async (request: UnifiedAnswerRequest): Promise * 建立SSE连接 */ export const connectUserInteractionSSE = ( + sessionId: string | null, + userId: string, onMessage: (response: MessageResponse) => void, onError?: (error: Event) => void ): EventSource => { - const url = `${API_BASE}/sse` + // 构建查询参数 + const params = new URLSearchParams() + if (sessionId) { + params.append('sessionId', sessionId) + } + params.append('userId', userId) + + const url = `${API_BASE}/sse?${params.toString()}` const eventSource = new EventSource(url) // 监听连接建立事件 diff --git a/prompto-lab-ui/src/views/ApiConfigView.vue b/prompto-lab-ui/src/views/ApiConfigView.vue index 5c791a0..d835e77 100644 --- a/prompto-lab-ui/src/views/ApiConfigView.vue +++ b/prompto-lab-ui/src/views/ApiConfigView.vue @@ -131,9 +131,9 @@