Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,41 +49,100 @@ public class UserInteractionController {
* 建立SSE连接
*/
@GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter streamConversation(@RequestParam(required = false) String sessionId) {
log.info("建立SSE连接 - 会话ID: {}", sessionId);
public SseEmitter streamConversation(@RequestParam(required = false) String sessionId,
@RequestParam String userId) {
log.info("建立SSE连接 - 会话ID: {}, 用户ID: {}", sessionId, userId);

if(sessionId == null || sessionId.isEmpty()) {
sessionId = UUID.randomUUID().toString();
}
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
sseEmitters.put(sessionId, emitter);
boolean isNewSession = false;
ConversationSession session = null;

// 连接建立时发送欢迎消息
try {
if (sessionId == null || sessionId.isEmpty()) {
// 新建会话
session = sessionManagementService.createNewSession(userId);
sessionId = session.getSessionId();
isNewSession = true;
log.info("创建新会话 - 用户ID: {}, 会话ID: {}", userId, sessionId);
} else {
// 使用现有会话
session = sessionManagementService.validateAndGetSession(userId, sessionId);
if (session == null) {
log.warn("会话不存在或无效 - 用户ID: {}, 会话ID: {}", userId, sessionId);
// 创建新会话作为fallback
session = sessionManagementService.createNewSession(userId != null ? userId : "anonymous_" + UUID.randomUUID().toString().substring(0, 8));
sessionId = session.getSessionId();
isNewSession = true;
}
}

SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
sseEmitters.put(sessionId, emitter);

// 连接建立时发送会话信息
Map<String, Object> connectionData = new ConcurrentHashMap<>();
connectionData.put("sessionId", sessionId);
connectionData.put("userId", session.getUserId());
connectionData.put("isNewSession", isNewSession);
connectionData.put("timestamp", System.currentTimeMillis());

// 根据会话状态返回nodeId
if (isNewSession) {
// 新会话返回根节点ID
connectionData.put("nodeId", "root");
log.info("新会话返回根节点ID: root - 会话: {}", sessionId);
} else if (session.getQaTree() != null && session.getQaTree().getRoot() != null) {
// 已存在会话,返回根节点ID(因为qaTree只有根节点)
String rootNodeId = session.getQaTree().getRoot().getId();
connectionData.put("nodeId", rootNodeId);
log.info("已存在会话返回根节点ID: {} - 会话: {}", rootNodeId, sessionId);

// 返回qaTree
try {
String qaTreeJson = io.github.timemachinelab.util.QaTreeSerializeUtil.serialize(session.getQaTree());
connectionData.put("qaTree", qaTreeJson);
} catch (Exception e) {
log.error("序列化qaTree失败: {}", e.getMessage());
}
} else {
// 兜底情况,返回根节点ID
connectionData.put("nodeId", "root");
log.info("兜底返回根节点ID: root - 会话: {}", sessionId);
}

emitter.send(SseEmitter.event()
.name("connected")
.data("SSE连接已建立,会话ID: " + sessionId));
} catch (IOException e) {
log.error("发送欢迎消息失败: {}", e.getMessage());
}

// 设置连接事件处理
String finalSessionId = sessionId;
emitter.onCompletion(() -> {
log.info("SSE连接完成: {}", finalSessionId);
});

emitter.onTimeout(() -> {
log.info("SSE连接超时: {}", finalSessionId);
sseEmitters.remove(finalSessionId);
});

emitter.onError((ex) -> {
log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage());
sseEmitters.remove(finalSessionId);
});

return emitter;
.name("connected")
.data(connectionData));

// 设置连接事件处理
String finalSessionId = sessionId;
emitter.onCompletion(() -> {
log.info("SSE连接完成: {}", finalSessionId);
});

emitter.onTimeout(() -> {
log.info("SSE连接超时: {}", finalSessionId);
sseEmitters.remove(finalSessionId);
});

emitter.onError((ex) -> {
log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage());
sseEmitters.remove(finalSessionId);
});

return emitter;

} catch (Exception e) {
log.error("建立SSE连接失败: {}", e.getMessage());
SseEmitter errorEmitter = new SseEmitter(Long.MAX_VALUE);
try {
errorEmitter.send(SseEmitter.event()
.name("error")
.data("连接建立失败: " + e.getMessage()));
} catch (IOException ioException) {
log.error("发送错误消息失败: {}", ioException.getMessage());
}
return errorEmitter;
}
}

/**
Expand Down Expand Up @@ -131,15 +190,50 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
request.getNodeId(),
request.getQuestionType());

// 1. 会话管理和验证
// 1. 强制要求sessionId
if (request.getSessionId() == null || request.getSessionId().trim().isEmpty()) {
log.warn("缺少必需的sessionId参数");
return ResponseEntity.badRequest().body("sessionId参数是必需的");
}

// 2. 会话管理和验证
String userId = request.getUserId();
if (userId == null || userId.trim().isEmpty()) {
log.warn("缺少必需的userId参数");
return ResponseEntity.badRequest().body("userId参数是必需的");
}

ConversationSession session = sessionManagementService.getOrCreateSession(userId, request.getSessionId());
// 3. 验证会话是否存在
ConversationSession session = sessionManagementService.validateAndGetSession(userId, request.getSessionId());
if (session == null) {
log.warn("会话不存在或无效 - 用户ID: {}, 会话ID: {}", userId, request.getSessionId());
return ResponseEntity.badRequest().body("会话不存在或无效");
}

// 2. 验证nodeId是否属于该会话
if (request.getNodeId() != null && !sessionManagementService.validateNodeId(session.getSessionId(), request.getNodeId())) {
log.warn("无效的节点ID - 会话: {}, 节点: {}", session.getSessionId(), request.getNodeId());
return ResponseEntity.badRequest().body("无效的节点ID");
// 4. nodeId验证逻辑
String nodeId = request.getNodeId();
if (nodeId == null || nodeId.trim().isEmpty()) {
// nodeId为空,表示这是新建会话的第一个问题
if (session.getQaTree() != null && session.getQaTree().getRoot() != null) {
log.warn("会话已存在qaTree,但nodeId为空 - 会话: {}", session.getSessionId());
return ResponseEntity.badRequest().body("现有会话必须提供nodeId");
}
log.info("新建会话的第一个问题 - 会话: {}", session.getSessionId());
} else if ("root".equals(nodeId)) {
// nodeId为'root',表示这是根节点的回答
if (session.getQaTree() == null || session.getQaTree().getRoot() == null) {
log.info("根节点回答,但qaTree未初始化 - 会话: {}", session.getSessionId());
// 允许继续处理,后续会创建qaTree
} else {
log.info("根节点回答 - 会话: {}", session.getSessionId());
}
} else {
// nodeId不为空且不是'root',验证是否属于该会话
if (!sessionManagementService.validateNodeId(session.getSessionId(), nodeId)) {
log.warn("无效的节点ID - 会话: {}, 节点: {}", session.getSessionId(), nodeId);
return ResponseEntity.badRequest().body("无效的节点ID");
}
log.info("更新现有节点 - 会话: {}, 节点: {}", session.getSessionId(), nodeId);
}

// 3. 验证答案格式
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@

import io.github.timemachinelab.core.question.BaseQuestion;
import io.github.timemachinelab.core.question.InputQuestion;
import io.github.timemachinelab.core.question.SingleChoiceQuestion;
import io.github.timemachinelab.core.question.MultipleChoiceQuestion;
import io.github.timemachinelab.core.question.FormQuestion;
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
public class QaTreeDomain {

Expand All @@ -15,9 +21,72 @@ public QaTree createTree(String userStartQuestion) {

return new QaTree(startNode);
}

/**
* 使用ConversationSession的自增ID创建QaTree
* @param userStartQuestion 用户开始问题
* @param session 会话对象,用于获取自增ID
* @return 创建的QaTree
*/
public QaTree createTree(String userStartQuestion, ConversationSession session) {
InputQuestion startQA = new InputQuestion();
startQA.setQuestion(userStartQuestion);
startQA.setAnswer(userStartQuestion);
// 使用会话的自增ID创建根节点
String rootNodeId = session.getNextNodeId();
QaTreeNode startNode = new QaTreeNode(startQA, rootNodeId);

return new QaTree(startNode);
}

public QaTree appendNode(QaTree tree, String parentId, BaseQuestion qa) {
tree.addNode(parentId, new QaTreeNode(qa));
return tree;
}

/**
* 使用ConversationSession的自增ID向QaTree添加节点
* @param tree QA树
* @param parentId 父节点ID
* @param qa 问题对象
* @param session 会话对象,用于获取自增ID
* @return 更新后的QaTree
*/
public QaTree appendNode(QaTree tree, String parentId, BaseQuestion qa, ConversationSession session) {
String nodeId = session.getNextNodeId();
tree.addNode(parentId, new QaTreeNode(qa, nodeId));
return tree;
}

/**
* 更新指定节点的答案
* @param tree QA树
* @param nodeId 节点ID
* @param answer 新的答案内容
* @return 是否更新成功
*/
public boolean updateNodeAnswer(QaTree tree, String nodeId, Object answer) {
QaTreeNode node = tree.getNodeById(nodeId);
if (node == null) {
return false;
}

BaseQuestion qa = node.getQa();
if (qa == null) {
return false;
}

// 根据问题类型设置答案
if (qa instanceof InputQuestion) {
((InputQuestion) qa).setAnswer((String) answer);
} else if (qa instanceof SingleChoiceQuestion) {
((SingleChoiceQuestion) qa).setAnswer((List<String>) answer);
} else if (qa instanceof MultipleChoiceQuestion) {
((MultipleChoiceQuestion) qa).setAnswer((List<String>) answer);
} else if (qa instanceof FormQuestion) {
((FormQuestion) qa).setAnswer((List<FormQuestion.AnswerItem>) answer);
}

return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,23 @@ public QaTreeNode(BaseQuestion qa) {
this.children = new HashMap<>();
this.qa = qa;
}

/**
* 使用指定ID创建节点的构造函数
* @param qa 问题对象
* @param nodeId 指定的节点ID
*/
public QaTreeNode(BaseQuestion qa, String nodeId) {
this.id = nodeId;
this.children = new HashMap<>();
this.qa = qa;
}

/**
* @deprecated 此方法使用UUID生成节点ID,不符合自增ID规范。
* 请使用 QaTreeDomain.appendNode(tree, parentId, qa, session) 方法代替。
*/
@Deprecated
public void append(BaseQuestion qa) {
QaTreeNode node = new QaTreeNode(qa);
this.append(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class ConversationService {


public void processUserMessage(String userId, String userMessage, Consumer<QuestionGenerationOperation.QuestionGenerationResponse> sseCallback) {
ConversationSession session = sessionManagementService.getUserCurrentSession(userId);
ConversationSession session = sessionManagementService.getUserLatestSession(userId);
if (session == null) {
log.warn("会话不存在");
return;
Expand Down
Loading
Loading