diff --git a/app.py b/app.py index c471351..7e63e52 100644 --- a/app.py +++ b/app.py @@ -8,21 +8,93 @@ import numpy as np from typing import Dict, Any, List, Optional from scipy import stats - +import re from flask import Flask, request, jsonify, render_template, send_from_directory from flask_cors import CORS # 导入配置和核心功能模块 -from config import KNOWLEDGE_BASE_DIR, AI_STUDIO_API_KEY +from config import ( + KNOWLEDGE_BASE_DIR, AI_STUDIO_API_KEY, + ML_KEYWORDS as APP_ML_KEYWORDS, # 使用别名避免与局部变量冲突 + ML_OPS_KEYWORDS as APP_ML_OPS_KEYWORDS, + UNCERTAINTY_PHRASES, RAG_SCORE_THRESHOLD, RAG_ANSWER_MIN_LENGTH +) from rag_core import query_rag, initialize_rag_system, direct_query_llm from ml_agents import query_ml_agent - +import logging +logger = logging.getLogger(__name__) # 导入增强版RAG和ML集成功能 from rag_core_enhanced import enhanced_query_rag, enhanced_direct_query_llm from ml_agents_enhanced import enhanced_query_ml_agent from advanced_feature_analysis import integrate_ml_with_rag - +from werkzeug.utils import secure_filename # 新增导入 # Helper functions moved to the top +def extract_and_parse_json_from_llm(llm_response_str: str, endpoint_name: str = "LLM_JSON_Parser") -> tuple[ + Optional[dict], Optional[str]]: + """ + 从LLM的响应字符串中提取并解析JSON。 + + Args: + llm_response_str: LLM返回的原始字符串。 + endpoint_name: 调用此函数的端点名称,用于日志记录。 + + Returns: + A tuple (parsed_json, error_message). + If successful, parsed_json is the dict, error_message is None. + If failed, parsed_json is None, error_message contains the error. + """ + if not llm_response_str: + logger.warning(f"[{endpoint_name}] LLM响应为空字符串。") + return None, "LLM响应为空。" + + logger.debug(f"[{endpoint_name}] 原始LLM响应 (前500字符): {llm_response_str[:500]}") + + extracted_json_str = None + + # 1. 尝试匹配Markdown代码块 ```json ... ``` + match = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", llm_response_str, re.DOTALL) + if match: + extracted_json_str = match.group(1) + logger.debug(f"[{endpoint_name}] 从Markdown代码块中提取到JSON字符串。") + else: + # 2. 如果没有Markdown块,尝试查找最外层的 '{' 和 '}' + # 这需要更小心,因为LLM的文本中可能包含其他花括号 + # 一个稍微健壮一点的方法是找到第一个 '{' 和最后一个 '}' + # 但这仍然不完美,如果JSON本身被包裹在更多文本中且文本中也有花括号 + json_start = llm_response_str.find('{') + json_end = llm_response_str.rfind('}') + if json_start != -1 and json_end != -1 and json_end > json_start: + # 尝试验证括号是否匹配,这比较复杂,这里简化处理 + # 我们可以先假设这个提取是初步的 + potential_json = llm_response_str[json_start: json_end + 1] + # 尝试直接解析这个初步提取的部分 + try: + json.loads(potential_json) # 尝试解析,如果成功,就用它 + extracted_json_str = potential_json + logger.debug(f"[{endpoint_name}] 通过查找首尾花括号提取到潜在JSON字符串。") + except json.JSONDecodeError: + # 如果初步提取的无法解析,回退到使用整个字符串,寄希望于它本身就是JSON + logger.warning( + f"[{endpoint_name}] 初步提取的JSON '{potential_json[:100]}...' 无法解析,将尝试解析整个LLM响应。") + extracted_json_str = llm_response_str # Fallback + else: + # 如果连首尾花括号都找不到,直接用原始字符串尝试 + extracted_json_str = llm_response_str + logger.debug(f"[{endpoint_name}] 未找到明确的JSON结构标记,将尝试解析整个LLM响应。") + + if not extracted_json_str: # 应该不会到这里,因为上面总会给 extracted_json_str 赋值 + logger.error(f"[{endpoint_name}] 无法提取任何JSON候选字符串。") + return None, "无法从LLM响应中提取JSON内容。" + + try: + parsed_json = json.loads(extracted_json_str) + logger.info(f"[{endpoint_name}] 成功解析JSON。") + return parsed_json, None + except json.JSONDecodeError as e: + err_msg = f"LLM返回的内容无法解析为有效的JSON。错误: {e}. 内容 (前500字符): {extracted_json_str[:500]}" + logger.error(f"[{endpoint_name}] {err_msg}") + # 注意:在实际返回给前端的错误信息中,可能不需要包含具体的解码错误 e,以免泄露过多细节 + return None, "大模型未能返回有效的JSON格式。请检查Prompt或重试。" # 返回给前端的通用错误 def is_rag_result_poor(query, rag_result): """ 评估RAG结果质量是否较差 @@ -62,11 +134,19 @@ def is_rag_result_poor(query, rag_result): return True return False - +UPLOADS_DIR = os.path.join(os.getcwd(), 'uploads') +MODELS_DIR = os.path.join(os.getcwd(), 'ml_models') # 如果您有模型存储目录 +ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'json'} # 定义允许的文件扩展名 +def allowed_file(filename): + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS app = Flask(__name__) # Flask会自动查找同级的 'templates' 文件夹 CORS(app) +app.config['UPLOADS_DIR'] = UPLOADS_DIR +app.config['MODELS_DIR'] = MODELS_DIR +app.config['MAX_CONTENT_LENGTH'] = 32 * 1024 * 1024 # --- 日志配置 --- # 基本配置,确保在 app.run() 之前设置,或者由 Flask 的 debug 模式自动处理 @@ -92,73 +172,6 @@ def index(): """渲染主HTML页面。""" return render_template('index.html') -@app.route('/query', methods=['POST']) -def query_endpoint(): - """处理用户查询的端点""" - try: - data = request.json - query = data.get('query', '') - - if not query: - return jsonify({"error": "请提供查询文本"}), 400 - - app.logger.info(f"接收到查询请求: {query}") - - # 机器学习相关查询检测 - ml_keywords = [ - '机器学习', '模型', '训练', '预测', '分类', '回归', '聚类', - '随机森林', '决策树', '线性回归', '逻辑回归', 'KNN', 'SVM', - '朴素贝叶斯', 'K-Means', '数据', '特征', '准确率', 'MSE', 'RMSE' - ] - # 操作类关键词 - ml_ops_keywords = ['训练', '预测', '比较', '评估', '构建', '解释', '自动', '集成', '版本', '分析', '推荐'] - - is_ml_query = any(keyword.lower() in query.lower() for keyword in ml_keywords) - is_ml_ops = any(op in query for op in ml_ops_keywords) - - # 1. 操作类问题优先走增强版ML Agent - if is_ml_query and is_ml_ops: - try: - app.logger.info("检测到机器学习操作类查询,使用增强版ML Agent处理") - result = enhanced_query_ml_agent(query, use_existing_model=True) - return jsonify(result) - except Exception as e: - app.logger.error(f"增强版ML Agent处理时出错,回退到RAG: {str(e)}") - # 尝试使用标准ML Agent - try: - app.logger.info("尝试使用标准ML Agent处理") - result = query_ml_agent(query) - return jsonify(result) - except Exception as e2: - app.logger.error(f"标准ML Agent处理时出错,回退到RAG: {str(e2)}") - # 机器学习处理失败时回退到RAG系统 - - # 2. 专业知识问答优先走增强版RAG - app.logger.info("使用增强版RAG系统处理常规/知识类查询") - try: - # 尝试使用增强版RAG处理 - result = enhanced_query_rag(query) - except Exception as e: - app.logger.warning(f"增强版RAG处理失败,回退到标准RAG: {str(e)}") - # 回退到标准RAG - result = query_rag(query) - - # 3. RAG效果不佳时兜底增强版LLM - if is_rag_result_poor(query, result): - app.logger.info("RAG结果质量不佳,切换到直接大模型回答") - try: - direct_llm_response = enhanced_direct_query_llm(query) - except Exception as e: - app.logger.warning(f"增强版LLM处理失败,回退到标准LLM: {str(e)}") - direct_llm_response = direct_query_llm(query) - - result["answer"] = direct_llm_response["answer"] - result["is_direct_answer"] = direct_llm_response.get("is_direct_answer", True) - return jsonify(result) - except Exception as e: - app.logger.error(f"处理查询时出错: {str(e)}") - return jsonify({"error": f"服务器错误: {str(e)}"}), 500 - @app.route('/api/models/ml_models', methods=['GET']) def get_ml_models(): """ @@ -198,177 +211,162 @@ def get_ml_models(): except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route('/api/chat', methods=['POST']) def chat_endpoint(): - """处理聊天请求的API端点。""" data = request.get_json() if not data or 'query' not in data: app.logger.warning("API请求缺少 'query' 字段。请求体: %s", data) return jsonify({"error": "请求体中缺少 'query' 字段"}), 400 - user_query = data.get('query') # 使用 .get() 更安全 - use_existing_model = data.get('use_existing_model', True) # 默认为True,优先使用现有模型 - if not isinstance(user_query, str) or not user_query.strip(): - app.logger.warning(f"API接收到无效查询: '{user_query}' (类型: {type(user_query)})") - return jsonify({"error": "查询必须是非空字符串"}), 400 - - app.logger.info(f"API接收到查询: '{user_query}'") - ml_keywords = ['机器学习', '模型', '训练', '预测', '回归', '分类', 'ML', '决策树', '随机森林', - '线性回归', '逻辑回归', '数据分析', '特征', '权重', '参数', '准确率', 'accuracy', - 'precision', 'recall'] - ml_ops_keywords = ['训练', '预测', '比较', '评估', '构建', '解释', '自动', '集成', '版本', '分析', '推荐'] - is_ml_query = any(keyword in user_query for keyword in ml_keywords) - is_ml_ops = any(op in user_query for op in ml_ops_keywords) + user_query = data.get('query', '').strip() + mode = data.get('mode') # 'data_analysis' 或 'general_llm' + + # 从请求中获取前端可能传递的上下文信息 + data_preview = data.get('data_preview') # 前端传来的数据预览 (list of dicts) + target_column = data.get('target_column') + selected_model_name = data.get('model_name') # 前端选择的模型 + data_path = data.get('data_path') # 如果需要完整数据路径 + + if not user_query: + app.logger.warning("API接收到空查询") + return jsonify({"error": "查询不能为空"}), 400 + + app.logger.info(f"API接收到查询: '{user_query[:100]}...', 模式: {mode}") + try: - # 优先处理通用大模型回答模式 - # 前端实际传递的通用大模型模式的 mode 值为 'general_llm' - if data.get('mode') == 'general_llm': - app.logger.info("检测到通用大模型回答模式,直接调用LLM API") - try: - direct_llm_response = enhanced_direct_query_llm(user_query) - return jsonify({ - "answer": direct_llm_response.get("answer", "未能获取回答。"), - "source_documents": direct_llm_response.get("source_documents", []), - "is_ml_query": False, - "is_direct_answer": True, - "model_used": direct_llm_response.get("model_name", "General LLM (Enhanced)") - }) - except Exception as e_enhanced_llm: - app.logger.error(f"增强版通用大模型LLM调用失败: {str(e_enhanced_llm)},尝试标准LLM", exc_info=True) - try: - direct_llm_response = direct_query_llm(user_query) - return jsonify({ - "answer": direct_llm_response.get("answer", "未能获取回答。"), - "source_documents": direct_llm_response.get("source_documents", []), - "is_ml_query": False, - "is_direct_answer": True, - "model_used": "General LLM (Standard)" - }) - except Exception as e_standard_llm: - app.logger.error(f"标准通用大模型LLM调用也失败: {str(e_standard_llm)}", exc_info=True) - return jsonify({"error": f"通用大模型处理时出错: {str(e_standard_llm)}"}), 500 - - # 检查是否为教程生成请求 - elif (data.get('mode') == 'data_analysis' and - data.get('data_preview') and - data.get('model_name') and - data.get('target_column')): - - app.logger.info(f"检测到教程生成请求: 模型 '{data.get('model_name')}', 目标列 '{data.get('target_column')}'") - llm_ml_context = { - 'data_preview': data.get('data_preview'), - 'model_name': data.get('model_name'), - 'target_column': data.get('target_column'), - 'generate_tutorial': True - } - + # 功能 4: 用户使用通用大模型模式 + if mode == 'general_llm': + app.logger.info("处理通用大模型模式查询") + # 直接调用LLM,Prompt可以简单包装或直接使用用户查询 + prompt = f"请回答以下问题:\n{user_query}" + llm_response = enhanced_direct_query_llm(prompt) # 假设它返回 {"answer": "...", ...} + return jsonify({ + "answer": llm_response.get("answer", "未能获取回答。"), + "source_documents": llm_response.get("source_documents", []), # RAG核心的LLM也可能返回空 + "is_ml_query": False, # 通常通用查询不是特定ML操作 + "is_direct_answer": True, + "model_used": llm_response.get("model_name", "General LLM") + }), 200 + + # --- 数据分析模式 (mode == 'data_analysis') --- + # 功能 3: 用户上传数据集和选择目标列和模型,然后为其生成具体的包含代码的教程 + # 我们通过一个关键词或前端特定标记来识别教程生成请求 + # 假设前端会在query中包含“生成教程”或通过一个额外参数标记 + is_tutorial_request = ("教程" in user_query.lower() or "generate_tutorial" in data) and \ + data_preview and selected_model_name and target_column + + if is_tutorial_request: + app.logger.info(f"处理教程生成请求: 模型 '{selected_model_name}', 目标列 '{target_column}'") + prompt_parts = [ + f"请为以下场景生成一个详细的Python机器学习教程,使用语言为中文:", + f"用户原始问题(供参考,主要按以下要求生成教程): {user_query}", + f"要使用的模型: {selected_model_name}", + f"目标预测列: {target_column}", + "教程应包含清晰的步骤、Python代码示例和必要的解释。代码应尽可能通用,并使用常见的库如 pandas 和 scikit-learn。", + "步骤应包括(但不限于):" + " 1. 简介和目标说明。", + " 2. 数据加载与探索性数据分析 (EDA):假设用户已有一个Pandas DataFrame,其列名和数据类型可参考以下数据预览。", + " 3. 数据预处理:根据数据预览和常见场景,讨论可能需要的预处理步骤(如处理缺失值、分类特征编码、数值特征缩放)。请提供代码片段作为示例。", + " 4. 特征工程(如果适用)。", + " 5. 将数据集拆分为特征 (X) 和目标 (y),然后划分为训练集和测试集。", + f" 6. 模型初始化与训练:实例化 '{selected_model_name}' 模型并进行训练。", + " 7. 使用模型进行预测。", + " 8. 模型评估:选择适合该模型和任务(分类/回归)的评估指标,并解释如何解读它们。", + " 9. 总结和后续步骤建议。", + "数据预览(前几行,用于理解数据结构和生成相关代码示例):", + json.dumps(data_preview, indent=2, ensure_ascii=False), + "\n请确保代码块使用Markdown格式正确标识。" + ] + tutorial_prompt = "\n".join(prompt_parts) + app.logger.debug(f"教程生成Prompt (部分): {tutorial_prompt[:500]}...") + + llm_response = enhanced_direct_query_llm(tutorial_prompt) + return jsonify({ + "answer": llm_response.get("answer", "未能生成教程内容。"), + "source_documents": [], # 教程通常不依赖RAG源文档 + "is_ml_query": True, + "is_tutorial": True, + "ml_model_used": selected_model_name, + "target_column_for_tutorial": target_column + }), 200 + + # 功能 1 & 2: 用户基于上传的数据进行提问 + if data_preview: # 必须有数据预览才能进行这类分析 + app.logger.info("处理基于数据的分析查询") + prompt_parts = [f"用户问题: {user_query}\n"] + prompt_parts.append("请根据以下提供的数据预览信息来回答用户的问题。") + prompt_parts.append("数据预览 (前几行):") + prompt_parts.append(json.dumps(data_preview, indent=2, ensure_ascii=False) + "\n") + + if target_column: + prompt_parts.append(f"用户已指定的目标列 (用于预测或分析相关性): '{target_column}'\n") + if selected_model_name: + prompt_parts.append(f"用户当前选择或提及的模型是: '{selected_model_name}'\n") + + # 根据问题类型调整指示 + if "适合用什么模型进行分析" in user_query or "推荐模型" in user_query: + prompt_parts.append( + "请基于数据的特点(如列名、数据类型暗示、值的分布等)推荐合适的机器学习模型,并解释推荐理由。") + elif "哪些特征最重要" in user_query and target_column: + prompt_parts.append( + f"请分析对于预测目标列 '{target_column}',数据预览中的哪些其他列(特征)可能最重要,并说明判断依据。") + else: + prompt_parts.append("请针对用户的问题,结合数据预览给出专业的分析和回答。") + + data_context_prompt = "".join(prompt_parts) + app.logger.debug(f"数据上下文Prompt (部分): {data_context_prompt[:500]}...") + + llm_response = enhanced_direct_query_llm(data_context_prompt) + return jsonify({ + "answer": llm_response.get("answer", "未能基于数据分析回答。"), + "source_documents": [], + "is_ml_query": True, # 假设这类问题与ML相关 + "data_context_used": True, + "model_used": llm_response.get("model_name", "Contextual LLM") + }), 200 + + # 如果是数据分析模式,但没有足够上下文(如数据预览),则可能无法很好回答 + # 可以尝试RAG或通用ML Agent,或者提示用户上传数据 + app.logger.warning(f"数据分析模式查询 '{user_query}',但缺少足够的数据上下文 (如data_preview)。") + # 尝试使用ML Agent (如果问题看起来是操作性的) + is_ml_query = any(keyword.lower() in user_query.lower() for keyword in APP_ML_KEYWORDS) + is_ml_ops = any(op.lower() in user_query.lower() for op in APP_ML_OPS_KEYWORDS) + if is_ml_query and is_ml_ops and enhanced_query_ml_agent: try: - # user_query 也传递给LLM,以便它了解用户的原始意图 - direct_llm_response = enhanced_direct_query_llm(user_query, llm_ml_context) + app.logger.info("尝试使用ML Agent处理(无数据上下文的操作类查询)") + # 注意:这里的ML Agent可能无法执行需要数据的操作 + agent_result = enhanced_query_ml_agent(user_query, + use_existing_model=data.get('use_existing_model', True)) return jsonify({ - "answer": direct_llm_response.get("answer", "未能生成教程内容。"), - "source_documents": [], - "is_ml_query": True, - "is_tutorial": True, - "ml_model_used": data.get('model_name') - }) - except Exception as e: - app.logger.error(f"教程生成LLM调用失败: {str(e)}", exc_info=True) - return jsonify({"error": f"生成教程时出错: {str(e)}"}), 500 - - elif is_ml_query and is_ml_ops: - app.logger.info(f"检测到机器学习操作类查询,将使用增强版ML Agent处理") - try: - # 尝试使用增强版ML代理 - result = enhanced_query_ml_agent(user_query, use_existing_model=use_existing_model) - except Exception as e: - app.logger.warning(f"增强版ML代理处理失败,回退到标准ML代理: {str(e)}") - # 回退到标准ML代理 - result = query_ml_agent(user_query, use_existing_model=use_existing_model) - - # 返回结果,保留特征分析数据和预测结果 - response_data = { - "answer": result["answer"], - "source_documents": [], - "is_ml_query": True, - "feature_analysis": result.get("feature_analysis", {}), - "ml_model_used": result.get("model_used", "未知模型") - } - # 如果结果中包含预测,添加到响应中 - if "prediction" in result: - response_data["prediction"] = result["prediction"] - return jsonify(response_data) + "answer": agent_result.get("answer", "ML Agent未能处理此请求。"), + "is_ml_query": True, + "ml_model_used": agent_result.get("model_used", "ML Agent") + # 其他 agent_result 中的字段,如 feature_analysis, prediction 等 + }), 200 + except Exception as e_agent: + app.logger.error(f"ML Agent处理失败: {str(e_agent)}", exc_info=True) + # 继续执行后续的RAG/LLM兜底 + + # 默认兜底:尝试RAG,如果RAG不好再用纯LLM (主要用于知识性问答) + app.logger.info(f"数据分析模式查询 '{user_query}' 无特定处理路径,尝试RAG") + rag_response = enhanced_query_rag(user_query) + if not is_rag_result_poor(user_query, rag_response): # 使用您已有的 is_rag_result_poor + app.logger.info("RAG结果尚可") + return jsonify(rag_response), 200 else: - app.logger.info(f"使用增强版RAG系统处理常规/知识类查询") - try: - # 尝试使用增强版RAG处理,启用机器学习集成 - result = enhanced_query_rag(user_query, ml_integration=True) - except Exception as e: - app.logger.warning(f"增强版RAG处理失败,回退到标准RAG: {str(e)}") - # 回退到标准RAG - result = query_rag(user_query) - - result["is_ml_query"] = False - - # 检查是否需要进行机器学习集成 - if is_ml_query and not is_ml_ops and "预测" in user_query: - app.logger.info("检测到预测类查询,尝试集成机器学习模型结果") - try: - # 提取可能的预测目标和特征 - from rag_core_enhanced import extract_prediction_info, find_suitable_model, make_prediction_with_model - prediction_target, features = extract_prediction_info(user_query) - - if prediction_target and features: - # 查找适合该预测任务的模型 - model_name = find_suitable_model(prediction_target) - - if model_name: - # 加载模型并进行预测 - model_result = make_prediction_with_model(model_name, features) - - # 将模型预测结果与RAG结果集成 - result = integrate_ml_with_rag(result, model_name, { - "prediction": model_result.get("predictions"), - "feature_importance": model_result.get("feature_importance", {}), - "model_metrics": model_result.get("metrics", {}) - }) - except Exception as e: - app.logger.warning(f"机器学习集成失败: {str(e)}") - - # 如果RAG结果质量不佳,使用增强版LLM - if is_rag_result_poor(user_query, result): - app.logger.info("RAG结果质量不佳,切换到直接大模型回答") - try: - # 如果有机器学习相关信息,将其传递给增强版LLM - ml_context = None - if result.get("ml_enhanced") or result.get("feature_analysis"): - ml_context = { - "model_name": result.get("ml_model_used", "未知模型"), - "prediction": result.get("prediction"), - "feature_importance": result.get("feature_analysis", {}).get("feature_importance", {}), - "model_metrics": result.get("model_metrics", {}) - } - - direct_llm_response = enhanced_direct_query_llm(user_query, ml_context) - except Exception as e: - app.logger.warning(f"增强版LLM处理失败,回退到标准LLM: {str(e)}") - direct_llm_response = direct_query_llm(user_query) - - result["answer"] = direct_llm_response["answer"] - result["is_direct_answer"] = direct_llm_response.get("is_direct_answer", True) - result["ml_enhanced_llm"] = direct_llm_response.get("ml_enhanced", False) - - # 如果结果中包含预测、模型指标或特征重要性,添加到响应中 - if "prediction" in result: - result["prediction"] = result["prediction"] - if "model_metrics" in result: - result["model_metrics"] = result["model_metrics"] - if "feature_importance" in result: - result["feature_importance"] = result["feature_importance"] - - return jsonify(result) + app.logger.info("RAG结果不佳,转通用LLM处理") + prompt = f"请回答以下问题:\n{user_query}" + llm_response = enhanced_direct_query_llm(prompt) + return jsonify({ + "answer": llm_response.get("answer", "未能获取回答。"), + "source_documents": llm_response.get("source_documents", []), + "is_direct_answer": True, + "model_used": llm_response.get("model_name", "Fallback LLM") + }), 200 + except Exception as e: - app.logger.error(f"/api/chat 接口发生错误: {e}", exc_info=True) + app.logger.error(f"/api/chat 接口发生错误: {str(e)}", exc_info=True) return jsonify({"error": f"服务器内部错误,请稍后重试或联系管理员。"}), 500 @app.route('/api/rebuild_vector_store', methods=['POST']) @@ -442,6 +440,231 @@ def train_model_endpoint(): app.logger.error(f"/api/ml/train 接口发生错误: {e}", exc_info=True) return jsonify({"error": f"训练模型时发生错误: {str(e)}"}), 500 + +# 在 app.py 的路由部分添加以下新端点 + +# 功能 5.2: 模型比较 (模拟) +@app.route('/api/simulate_model_comparison', methods=['POST']) +def simulate_model_comparison_endpoint(): + data = request.get_json() + if not data or not all(k in data for k in ['model_names', 'test_data_identifier', 'target_column']): + app.logger.warning(f"模拟模型比较请求缺少参数: {data}") + return jsonify({"error": "缺少必要参数 (model_names, test_data_identifier, target_column)"}), 400 + + model_names_list = data['model_names'] + if not isinstance(model_names_list, list) or len(model_names_list) < 2: + return jsonify({"error": "model_names 必须是至少包含两个模型的列表"}), 400 + + prompt_parts = [ + "请模拟以下机器学习模型的比较过程,并以中文回答:", + f"模型列表:{', '.join(model_names_list)}", + f"测试数据集标识:'{data['test_data_identifier']}' (例如:'当前上传的关于客户流失预测的数据' 或 '公开的鸢尾花分类数据集')", + f"目标列:'{data['target_column']}'\n", + "要求:", + "1. 为列表中的每个模型生成一组合理的、符合其典型应用场景的模拟评估指标。", + " - 如果目标列暗示分类任务(例如,目标列是字符型或少数唯一值),请使用分类指标如:准确率 (Accuracy), 精确率 (Precision), 召回率 (Recall), F1分数 (F1-score)。数值请在0.6到0.98之间随机模拟。", + " - 如果目标列暗示回归任务(例如,目标列是连续数值型),请使用回归指标如:R²分数 (R2 Score), 均方误差 (MSE), 平均绝对误差 (MAE)。R2分数在0.5到0.95之间,MSE/MAE根据常见场景模拟。", + "2. 对模拟结果进行简短总结,指出哪个模型在模拟中可能表现更优,并给出推测的理由。", + "3. 以严格的JSON格式返回结果,结构如下:", + """ +{ + "comparison_results": [ + { + "model_name": "模型A的名称", + "metrics": {"指标1": "模拟值1", "指标2": "模拟值2"} + }, + { + "model_name": "模型B的名称", + "metrics": {"指标1": "模拟值1", "指标2": "模拟值2"} + } + // ... 更多模型 + ], + "summary": "这里是模拟的总结文本...", + "test_data_info": { + "identifier": "{data['test_data_identifier']}", + "simulated_rows": "例如:约1000行", + "simulated_features": "例如:约10个特征" + } +} +""", + "请确保'metrics'对象中的值是数值类型(如果适用,如准确率)或字符串。请务必确保您的整个输出就是一个单一的、完整且严格符合上述结构的JSON对象。不要在JSON对象之前或之后包含任何其他文本、注释、解释或Markdown的```json ```标记,直接输出JSON本身。" + ] + comparison_prompt = "\n".join(prompt_parts) + app.logger.info(f"模拟模型比较Prompt (部分): {comparison_prompt[:300]}...") + + try: + # enhanced_direct_query_llm 应该返回一个字典,其中 "answer" 键包含LLM的原始文本输出 + llm_raw_output_dict = enhanced_direct_query_llm(comparison_prompt) + if not llm_raw_output_dict or "answer" not in llm_raw_output_dict: + app.logger.error("enhanced_direct_query_llm 未返回预期的包含 'answer' 的字典。") + return jsonify({"error": "调用大模型时发生内部错误 (无响应)。"}), 500 + + llm_response_str = llm_raw_output_dict.get("answer", "") + + simulated_results, error_msg = extract_and_parse_json_from_llm(llm_response_str, "ModelComparison") + + if error_msg: # 如果解析失败 + return jsonify({ + "error": error_msg, # 使用辅助函数返回的错误信息 + "raw_llm_response": llm_response_str # 仍然返回原始响应 + }), 500 + + # 基本的结构验证 (simulated_results 不为 None) + if not isinstance(simulated_results, dict) or \ + "comparison_results" not in simulated_results or \ + not isinstance(simulated_results["comparison_results"], list) or \ + "summary" not in simulated_results: + app.logger.error(f"LLM返回的模拟比较结果JSON结构不符合预期: {simulated_results}") + return jsonify({ + "error": "大模型返回的模拟比较结果JSON结构不正确。", + "parsed_response": simulated_results, # 返回已解析(但结构错误)的内容 + "raw_llm_response": llm_response_str + }), 500 + + # 可以添加更细致的结构验证,例如检查 comparison_results 列表中的元素 + for item in simulated_results["comparison_results"]: + if not isinstance(item, dict) or "model_name" not in item or "metrics" not in item: + app.logger.error(f"模拟比较结果中 'comparison_results' 列表内元素结构错误: {item}") + return jsonify({ + "error": "模拟比较结果内部数据结构不正确。", + "parsed_response": simulated_results, + "raw_llm_response": llm_response_str + }), 500 + + return jsonify(simulated_results), 200 + except Exception as e: + app.logger.error(f"模拟模型比较过程中发生错误: {str(e)}", exc_info=True) + return jsonify({"error": f"模拟模型比较时发生内部错误: {str(e)}"}), 500 + + +# 功能 5.3: 集成模型构建 (模拟) +@app.route('/api/simulate_ensemble_building', methods=['POST']) +def simulate_ensemble_building_endpoint(): + data = request.get_json() + if not data or not all(k in data for k in ['base_models', 'ensemble_type', 'ensemble_name']): + app.logger.warning(f"模拟集成构建请求缺少参数: {data}") + return jsonify({"error": "缺少必要参数 (base_models, ensemble_type, ensemble_name)"}), 400 + + base_models_list = data['base_models'] + ensemble_type = data['ensemble_type'] + ensemble_name = data['ensemble_name'] + + if not isinstance(base_models_list, list) or len(base_models_list) < 2: + return jsonify({"error": "base_models 必须是至少包含两个模型的列表"}), 400 + if not ensemble_name.strip(): + return jsonify({"error": "ensemble_name 不能为空"}), 400 + + prompt_parts = [ + f"请模拟构建一个名为 '{ensemble_name}' 的集成学习模型,并以中文回答:", + f"基础模型列表:{', '.join(base_models_list)}", + f"集成类型:{ensemble_type} (例如:Voting Classifier, Stacking Regressor, BaggingClassifier等)\n", + "请在回答中包含以下内容:", + "1. 一个模拟的“构建成功”或“已创建”的消息。", + f"2. 对这个名为 '{ensemble_name}' 的模拟集成模型的工作原理进行简要描述(根据其类型和基础模型)。", + "3. 列出这个模拟集成模型相对于其基础模型可能的潜在优势。", + "4. 简要描述它可能适用于什么样的数据集或问题场景。", + "5. 提供一些关于这个模拟集成模型的假设性元数据,例如模拟的创建时间戳、组合方式等。", + "请以严格的JSON格式返回结果,结构如下:", + """ +{ + "success": true, + "message": "模拟的构建成功消息,例如:集成模型 '[ensemble_name]' 已成功模拟创建!", + "ensemble_name": "{ensemble_name}", + "ensemble_type": "{ensemble_type}", + "base_models_used": {base_models_list}, + "description": "这里是集成模型工作原理的模拟描述...", + "potential_advantages": "这里是模拟的潜在优势列表或描述...", + "suitable_scenarios": "这里是模拟的适用场景描述...", + "model_info": { + "simulated_created_at": "例如:一个ISO格式的时间戳,如 YYYY-MM-DDTHH:MM:SSZ", + "simulated_combination_method": "例如:对投票分类器是'多数投票'或'加权投票',对Stacking是'使用元学习器组合预测'等" + } +} +""", + "请确保整个响应是单一的、格式正确的JSON对象。" + ] + ensemble_prompt = "\n".join(prompt_parts) + app.logger.info(f"模拟集成构建Prompt (部分): {ensemble_prompt[:300]}...") + + try: + llm_raw_output_dict = enhanced_direct_query_llm(ensemble_prompt) # √ 使用正确的 ensemble_prompt + if not llm_raw_output_dict or "answer" not in llm_raw_output_dict: + app.logger.error("enhanced_direct_query_llm 未返回预期的包含 'answer' 的字典。") + return jsonify({"error": "调用大模型时发生内部错误 (无响应)。"}), 500 + + llm_response_str = llm_raw_output_dict.get("answer", "") + + # --- 修改这里的日志标记 --- + simulated_results, error_msg = extract_and_parse_json_from_llm(llm_response_str, + "EnsembleBuilding") # 修改为 "EnsembleBuilding" + + if error_msg: + return jsonify({ + "error": error_msg, + "raw_llm_response": llm_response_str + }), 500 + + # --- 修改这里的JSON结构验证 --- + expected_keys = ["success", "ensemble_name", "ensemble_type", "base_models_used", + "description", "potential_advantages", "suitable_scenarios", "model_info"] + + if not isinstance(simulated_results, dict): + app.logger.error(f"LLM返回的模拟集成结果不是一个字典: {simulated_results}") # √ 日志文本正确 + return jsonify({ + "error": "大模型返回的模拟集成结果格式不正确 (非字典)。", # √ 错误信息文本正确 + "parsed_response": simulated_results, + "raw_llm_response": llm_response_str + }), 500 + + missing_keys = [key for key in expected_keys if key not in simulated_results] + if missing_keys: + app.logger.error( + f"LLM返回的模拟集成结果JSON缺少关键字段: {missing_keys}. 结果: {simulated_results}") # √ 日志文本正确 + return jsonify({ + "error": f"大模型返回的模拟集成结果JSON缺少必要字段: {', '.join(missing_keys)}。", # √ 错误信息文本正确 + "parsed_response": simulated_results, + "raw_llm_response": llm_response_str + }), 500 + + # 可以选择性地添加对特定字段类型的进一步验证 + if not isinstance(simulated_results.get("success"), bool): + app.logger.error(f"模拟集成结果中 'success' 字段类型错误. 结果: {simulated_results}") + return jsonify({ + "error": "模拟集成结果中 'success' 字段类型非布尔值。", + "parsed_response": simulated_results, "raw_llm_response": llm_response_str + }), 500 + if not isinstance(simulated_results.get("base_models_used"), list): + app.logger.error(f"模拟集成结果中 'base_models_used' 字段类型错误. 结果: {simulated_results}") + return jsonify({ + "error": "模拟集成结果中 'base_models_used' 字段类型非列表。", + "parsed_response": simulated_results, "raw_llm_response": llm_response_str + }), 500 + # ... 可以为 model_info 等其他字段添加类似检查 ... + + # 移除或注释掉针对 comparison_results 的 for 循环验证 + # for item in simulated_results["comparison_results"]: <--- 这个是错误的,应该移除 + + return jsonify(simulated_results), 201 # √ 使用 201 Created + except Exception as e: + app.logger.error(f"模拟集成模型构建过程中发生错误: {str(e)}", exc_info=True) # √ 日志文本正确 + return jsonify({"error": f"模拟集成模型构建时发生内部错误: {str(e)}"}), 500 # √ 错误信息文本正确 + + @app.route('/api/ml/model_versions', methods=['POST']) + def create_model_version_placeholder(): + # data = request.get_json() # 可以接收数据但不处理 + # model_name = data.get('model_name') + # version_info = data.get('version_info') + # app.logger.info(f"接收到创建模型版本请求 (前端模拟): {model_name} - {version_info}") + return jsonify({"success": True, "message": "模型版本信息已在前端记录 (模拟)。"}), 200 + + @app.route('/api/ml/model_versions/', methods=['GET']) + def get_model_versions_placeholder(model_name): + # app.logger.info(f"接收到获取模型版本请求 (前端模拟): {model_name}") + # 模拟返回空列表或一个示例结构 + return jsonify({"success": True, "versions": [], "message": "模型版本历史在前端管理 (模拟)。"}), 200 + + # 对部署相关的API也做类似处理 + @app.route('/api/ml/predict', methods=['POST']) def predict_endpoint(): """使用机器学习模型进行预测的API端点""" @@ -583,9 +806,9 @@ def upload_data_endpoint(): "categorical_columns": categorical_columns, "numerical_columns": numerical_columns, "row_count": len(df), - "preview": df.head(5).to_dict('records') + "preview": df.head(10).to_dict('records') }) - + return jsonify(result), 200 except Exception as e: app.logger.error(f"/api/ml/upload 接口发生错误: {e}", exc_info=True) diff --git a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/data_level0.bin b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/data_level0.bin index e0354e1..9ce549a 100644 Binary files a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/data_level0.bin and b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/data_level0.bin differ diff --git a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/header.bin b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/header.bin index 66a8fb4..ce0068a 100644 Binary files a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/header.bin and b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/header.bin differ diff --git a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/index_metadata.pickle b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/index_metadata.pickle index 7d57986..be57003 100644 Binary files a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/index_metadata.pickle and b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/index_metadata.pickle differ diff --git a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/length.bin b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/length.bin index 57e40b7..2f9e85b 100644 Binary files a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/length.bin and b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/length.bin differ diff --git a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/link_lists.bin b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/link_lists.bin index 842fae3..edf48c6 100644 Binary files a/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/link_lists.bin and b/chroma_db/0868db41-6ebe-49c6-a051-6415cda05a70/link_lists.bin differ diff --git a/chroma_db/chroma.sqlite3 b/chroma_db/chroma.sqlite3 index f1c8d89..ed18c6f 100644 Binary files a/chroma_db/chroma.sqlite3 and b/chroma_db/chroma.sqlite3 differ diff --git a/config.py b/config.py index 34874cd..eeb022e 100644 --- a/config.py +++ b/config.py @@ -40,4 +40,40 @@ CHUNK_OVERLAP = int(CHUNK_SIZE * 0.2) # Embedding API 调用时的批处理大小 (embedding-v1 API限制每次最多16个输入) -EMBEDDING_BATCH_SIZE = 16 \ No newline at end of file +EMBEDDING_BATCH_SIZE = 16 + +# config.py (在文件末尾或合适位置添加以下内容) + +# --- RAG 和 LLM 结果评估相关配置 --- +UNCERTAINTY_PHRASES = [ + "无法找到", "没有相关信息", "未能找到", "无法提供", "不确定", + "我不知道", "无法确定", "没有足够信息", "目前无法回答", + "To", "I cannot", "I don't", "Unable to", "not find", "no information", + "对不起,我无法", "抱歉,我无法" # 添加更多常见的不确定性短语 +] +RAG_SCORE_THRESHOLD = 0.45 # RAG 文档相关性得分阈值 (可根据实际效果调整) +RAG_ANSWER_MIN_LENGTH = 25 # RAG 回答最小长度阈值 (字符数, 可调整) + +# --- 机器学习关键词列表 --- +ML_KEYWORDS = [ + '机器学习', '模型', '训练', '预测', '分类', '回归', '聚类', '算法', '特征', '数据', + '随机森林', '决策树', '支持向量机', 'svm', 'knn', 'k近邻', '逻辑回归', '线性回归', + '神经网络', '深度学习', '朴素贝叶斯', 'k-means', 'xgboost', 'lightgbm', 'catboost', + '准确率', '精确率', '召回率', 'f1分数', 'auc', 'roc', 'mse', 'rmse', 'mae', 'r方', 'r2', + '超参数', '验证集', '测试集', '过拟合', '欠拟合', '特征工程', '降维', 'pca', + 'tensorflow', 'keras', 'pytorch', 'scikit-learn', 'sklearn', 'paddlepaddle', 'paddle' +] + +ML_OPS_KEYWORDS = [ + '训练', '预测', '比较', '评估', '构建', '解释', '优化', '部署', '监控', '保存', '加载', + '选择模型', '调整参数', '分析特征', '生成报告', '自动化', '工作流', + '版本控制', '流水线', 'pipeline', 'finetune', '微调', '自动机器学习', 'automl' +] + +# --- 应用行为相关配置 (示例,您可以按需添加更多) --- +# 例如,上传文件存储位置,虽然您在 app.py 中定义了 UPLOADS_DIR,但也可以考虑放在这里 +# UPLOADS_DIR = os.path.join(os.getcwd(), "uploads") +# MODELS_STORAGE_DIR = os.path.join(os.getcwd(), "ml_models") + +# 默认的预览行数 +DEFAULT_PREVIEW_ROWS = 10 diff --git a/ml_agents.py b/ml_agents.py index 54bcf1a..5608827 100644 --- a/ml_agents.py +++ b/ml_agents.py @@ -13,7 +13,8 @@ from langchain.tools import StructuredTool from langchain_core.tools import Tool -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import Field +from pydantic import BaseModel from langchain.agents import AgentExecutor, create_structured_chat_agent from langchain_core.prompts import PromptTemplate diff --git a/ml_api.log b/ml_api.log index 372ddf7..442fc5c 100644 --- a/ml_api.log +++ b/ml_api.log @@ -1403,3 +1403,211 @@ NameError: name 'List' is not defined 2025-05-22 15:30:58,109 - ml_api_endpoints - INFO - ҵ 2 0 2025-05-22 15:30:58,110 - werkzeug - INFO - 127.0.0.1 - - [22/May/2025 15:30:58] "GET /api/ml/deployments HTTP/1.1" 200 - 2025-05-22 15:30:58,339 - werkzeug - INFO - 127.0.0.1 - - [22/May/2025 15:30:58] "GET /static/favicon.ico HTTP/1.1" 304 - +2025-05-24 16:03:38,111 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 16:03:38,114 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 16:03:38,116 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 16:03:38] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 17:40:53,248 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 17:40:53,248 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 17:40:53,249 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:40:53] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 17:40:54,163 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:40:54] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 17:41:45,622 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:41:45] "GET / HTTP/1.1" 200 - +2025-05-24 17:41:46,030 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:41:46] "GET /static/js/app.js HTTP/1.1" 200 - +2025-05-24 17:41:46,197 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 17:41:46,198 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 17:41:46,198 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:41:46] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 17:41:46,647 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:41:46] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 17:42:23,373 - app - INFO - Ϊģ 'linear_regression' ڻ 'staging' ɵIJ˵: /api/predict/linear-regression/e2b391b0 +2025-05-24 17:42:23,373 - ml_api_endpoints - INFO - ʼģ 'linear_regression' staging +2025-05-24 17:42:23,424 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:42:23] "POST /api/ml/deploy HTTP/1.1" 400 - +2025-05-24 17:42:54,618 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:42:54] "POST /api/ml/upload HTTP/1.1" 200 - +2025-05-24 17:42:54,649 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:42:54] "POST /api/ml/analyze HTTP/1.1" 200 - +2025-05-24 17:42:59,929 - app - INFO - APIյѯ: 'ʺʲôģͽз...', ģʽ: data_analysis +2025-05-24 17:42:59,929 - app - INFO - ݵķѯ +2025-05-24 17:43:15,335 - httpx - INFO - HTTP Request: POST https://aistudio.baidu.com/llm/lmapi/v3/chat/completions "HTTP/1.1 200 OK" +2025-05-24 17:43:15,341 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:43:15] "POST /api/chat HTTP/1.1" 200 - +2025-05-24 17:43:38,333 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:43:38] "GET / HTTP/1.1" 200 - +2025-05-24 17:43:38,723 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:43:38] "GET /static/js/app.js HTTP/1.1" 200 - +2025-05-24 17:43:38,888 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 17:43:38,888 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 17:43:38,889 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:43:38] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 17:43:39,288 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:43:39] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 17:43:55,239 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:43:55] "POST /api/ml/upload HTTP/1.1" 200 - +2025-05-24 17:43:55,264 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 17:43:55] "POST /api/ml/analyze HTTP/1.1" 200 - +2025-05-24 19:11:53,824 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:11:53] "GET / HTTP/1.1" 200 - +2025-05-24 19:12:04,385 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:12:04] "GET /static/js/app.js HTTP/1.1" 200 - +2025-05-24 19:12:14,717 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 19:12:14,718 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 19:12:14,718 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:12:14] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 19:12:47,515 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:12:47] "GET / HTTP/1.1" 200 - +2025-05-24 19:12:48,691 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:12:48] "GET /static/js/app.js HTTP/1.1" 304 - +2025-05-24 19:12:48,889 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 19:12:48,890 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 19:12:48,891 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:12:48] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 19:12:49,583 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:12:49] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 19:13:07,732 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:13:07] "GET / HTTP/1.1" 200 - +2025-05-24 19:13:07,766 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:13:07] "GET /static/js/app.js HTTP/1.1" 304 - +2025-05-24 19:13:07,994 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 19:13:07,995 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 19:13:07,995 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:13:07] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 19:13:08,059 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:13:08] "GET /static/favicon.ico HTTP/1.1" 304 - +2025-05-24 19:36:28,506 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:28] "GET / HTTP/1.1" 200 - +2025-05-24 19:36:29,208 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:29] "GET /static/js/app.js HTTP/1.1" 200 - +2025-05-24 19:36:29,471 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 19:36:29,471 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 19:36:29,477 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:29] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 19:36:30,118 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:30] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 19:36:44,962 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:44] "GET / HTTP/1.1" 200 - +2025-05-24 19:36:45,627 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:45] "GET /static/js/app.js HTTP/1.1" 200 - +2025-05-24 19:36:45,816 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 19:36:45,816 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 19:36:45,817 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:45] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 19:36:46,433 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:46] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 19:36:55,925 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:55] "POST /api/ml/upload HTTP/1.1" 200 - +2025-05-24 19:36:55,949 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:36:55] "POST /api/ml/analyze HTTP/1.1" 200 - +2025-05-24 19:37:18,265 - app - INFO - ģ⼯ɹPrompt (): ģ⹹һΪ 'test' ļѧϰģͣĻش +ģбsvm_classifier, linear_regression +ͣvoting_classifier (磺Voting Classifier, Stacking Regressor, BaggingClassifier) + +ڻشаݣ +1. һģġɹѴϢ +2. Ϊ 'test' ģ⼯ģ͵ĹԭмҪͺͻģͣ +3. гģ⼯ģģͿܵDZơ +4. Ҫʲôݼ... +2025-05-24 19:37:30,829 - httpx - INFO - HTTP Request: POST https://aistudio.baidu.com/llm/lmapi/v3/chat/completions "HTTP/1.1 200 OK" +2025-05-24 19:37:30,833 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:37:30] "POST /api/simulate_ensemble_building HTTP/1.1" 201 - +2025-05-24 19:37:50,256 - app - INFO - Ϊģ 'linear_regression' ڻ 'staging' ɵIJ˵: /api/predict/linear-regression/c95f9c73 +2025-05-24 19:37:50,256 - ml_api_endpoints - INFO - ʼģ 'linear_regression' staging +2025-05-24 19:37:50,320 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 19:37:50] "POST /api/ml/deploy HTTP/1.1" 400 - +2025-05-24 20:04:42,310 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 20:04:42,311 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 20:04:42,311 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 20:04:42] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 20:09:28,243 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 20:09:28] "POST /api/ml/upload HTTP/1.1" 200 - +2025-05-24 20:09:28,288 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 20:09:28] "POST /api/ml/analyze HTTP/1.1" 200 - +2025-05-24 20:09:39,510 - app - INFO - ģģͱȽPrompt (): ģ»ѧϰģ͵ıȽϹ̣Ļش +ģбlinear_regression, logistic_regression +ݼʶ'ǰϴ (п.xlsx)' (磺'ǰϴĹڿͻʧԤ' 'βݼ') +ĿУ'PM2.5' + +Ҫ +1. ΪбеÿģһġӦóģָꡣ + - Ŀаʾ磬ĿַͻΨһֵʹ÷ָ磺׼ȷ (Accuracy), ȷ (Precision), ٻ (Recall), F1 (F1-s... +2025-05-24 20:09:57,372 - httpx - INFO - HTTP Request: POST https://aistudio.baidu.com/llm/lmapi/v3/chat/completions "HTTP/1.1 200 OK" +2025-05-24 20:09:57,374 - app - ERROR - LLMصģȽϽЧJSON: ```json +{ + "comparison_results": [ + { + "model_name": "linear_regression", + "metrics": { + "R2 Score": 0.85, + "MSE": "15.2", + "MAE": "3.1" + } + }, + { + "model_name": "logistic_regression", + "metrics": { + "ע": "logistic_regressionڻع񣬴˴Ϊʽͳһгʵ", + "R2 Score": "N/A", + "MSE": "N/A", + "MAE": "N/A" + // Ӧʹ÷ָ꣬Ҫ˴г + } + // ʵģУDzΪlogistic_regressionعָ꣬Ϊṹչʾ + } + // עlogistic_regressionǷģͣڻع'PM2.5'ã + // ӦΪģָ꣬ĿҪͳһʽ˴˵ + ], + "summary": "ģĻعУlinear_regressionģͱֳ˽ϸߵR\xb20.85ģͶݵϳ̶ȽϺáͬʱMSEMAEֵҲںΧڣ˵ģ͵ԤԽСlogistic_regressionǷģͣڱλعڻعָʵ塣ģƲ'PM2.5'ĻعԤУlinear_regressionģͿָܱš", + "test_data_info": { + "identifier": "ǰϴ (п.xlsx)", + "simulated_rows": "Լ1000", + "simulated_features": "Լ10" + } +} +``` +**˵** +- `logistic_regression`ǷģͣڱεĻعĿ'PM2.5'Ϊֵͣ`metrics`б"N/A"ӱע˵ʵӦУĿǷǻΪ`logistic_regression`ָ׼ȷʡȷʵȡ +- `linear_regression`ģָ꣨R\xb2 Score, MSE, MAEǻڻعĵָ꣬ģ˷Ҫֵ +- ܽᲿָ`linear_regression`ģпָܱţƲɣϸߵR\xb2ԽСԤ +2025-05-24 20:09:57,375 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 20:09:57] "POST /api/simulate_model_comparison HTTP/1.1" 500 - +2025-05-24 22:24:25,523 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 22:24:25,524 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 22:24:25,524 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:24:25] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 22:24:59,034 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:24:59] "POST /api/ml/upload HTTP/1.1" 200 - +2025-05-24 22:24:59,064 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:24:59] "POST /api/ml/analyze HTTP/1.1" 200 - +2025-05-24 22:25:34,422 - app - INFO - ģ⼯ɹPrompt (): ģ⹹һΪ 'test' ļѧϰģͣĻش +ģбknn_classifier, linear_regression +ͣstacking_classifier (磺Voting Classifier, Stacking Regressor, BaggingClassifier) + +ڻشаݣ +1. һģġɹѴϢ +2. Ϊ 'test' ģ⼯ģ͵ĹԭмҪͺͻģͣ +3. гģ⼯ģģͿܵDZơ +4. Ҫʲôݼ... +2025-05-24 22:25:34,423 - app - ERROR - ģģͱȽϹз: name 'comparison_prompt' is not defined +Traceback (most recent call last): + File "C:\Users\86198\Desktop\Study\ѧϰ\Machine Learning\app.py", line 591, in simulate_ensemble_building_endpoint + llm_raw_output_dict = enhanced_direct_query_llm(comparison_prompt) +NameError: name 'comparison_prompt' is not defined +2025-05-24 22:25:34,423 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:25:34] "POST /api/simulate_ensemble_building HTTP/1.1" 500 - +2025-05-24 22:26:41,835 - app - INFO - Ϊģ 'knn_classifier' ڻ 'staging' ɵIJ˵: /api/predict/knn-classifier/8b03c95c +2025-05-24 22:26:41,835 - ml_api_endpoints - INFO - ʼģ 'knn_classifier' staging +2025-05-24 22:26:41,926 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:26:41] "POST /api/ml/deploy HTTP/1.1" 400 - +2025-05-24 22:26:44,943 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 22:26:44,944 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 22:26:44,944 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:26:44] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 22:27:53,748 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:27:53] "GET / HTTP/1.1" 200 - +2025-05-24 22:27:55,008 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:27:55] "GET /static/js/app.js HTTP/1.1" 200 - +2025-05-24 22:27:55,184 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:27:55] "GET /undefined HTTP/1.1" 404 - +2025-05-24 22:27:55,899 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:27:55] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 22:28:08,955 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:28:08] "GET / HTTP/1.1" 200 - +2025-05-24 22:28:09,401 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:28:09] "GET /static/js/app.js HTTP/1.1" 200 - +2025-05-24 22:28:09,598 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 22:28:09,599 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 22:28:09,599 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:28:09] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 22:28:10,280 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:28:10] "GET /static/favicon.ico HTTP/1.1" 200 - +2025-05-24 22:28:23,564 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:28:23] "POST /undefined HTTP/1.1" 404 - +2025-05-24 22:29:43,638 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 22:29:43,638 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 22:29:43,639 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:29:43] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 22:30:05,344 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:30:05] "POST /api/ml/upload HTTP/1.1" 200 - +2025-05-24 22:30:05,375 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:30:05] "POST /api/ml/analyze HTTP/1.1" 200 - +2025-05-24 22:30:14,424 - app - INFO - ģģͱȽPrompt (): ģ»ѧϰģ͵ıȽϹ̣Ļش +ģбlinear_regression, logistic_regression +ݼʶ'ǰϴ (п.xlsx)' (磺'ǰϴĹڿͻʧԤ' 'βݼ') +ĿУ'PM2.5' + +Ҫ +1. ΪбеÿģһġӦóģָꡣ + - Ŀаʾ磬ĿַͻΨһֵʹ÷ָ磺׼ȷ (Accuracy), ȷ (Precision), ٻ (Recall), F1 (F1-s... +2025-05-24 22:30:22,569 - httpx - INFO - HTTP Request: POST https://aistudio.baidu.com/llm/lmapi/v3/chat/completions "HTTP/1.1 200 OK" +2025-05-24 22:30:22,572 - __main__ - INFO - [ModelComparison] ɹJSON +2025-05-24 22:30:22,572 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:30:22] "POST /api/simulate_model_comparison HTTP/1.1" 200 - +2025-05-24 22:30:53,515 - app - INFO - ģ⼯ɹPrompt (): ģ⹹һΪ 'test' ļѧϰģͣĻش +ģбknn_classifier, logistic_regression +ͣvoting_regressor (磺Voting Classifier, Stacking Regressor, BaggingClassifier) + +ڻشаݣ +1. һģġɹѴϢ +2. Ϊ 'test' ģ⼯ģ͵ĹԭмҪͺͻģͣ +3. гģ⼯ģģͿܵDZơ +4. Ҫʲôݼ... +2025-05-24 22:31:03,376 - httpx - INFO - HTTP Request: POST https://aistudio.baidu.com/llm/lmapi/v3/chat/completions "HTTP/1.1 200 OK" +2025-05-24 22:31:03,377 - __main__ - INFO - [ModelComparison] ɹJSON +2025-05-24 22:31:03,377 - app - ERROR - LLMصģȽϽJSONṹԤ: {'success': True, 'message': "ģ 'test' ѳɹģⴴ", 'ensemble_name': 'test', 'ensemble_type': 'voting_regressor_simulated_as_classifier_context', 'base_models_used': ['knn_classifier', 'logistic_regression'], 'description': "Ϊ 'test' ģ⼯ģһͶƱƵķд'regressor'ڴģΪʹãKڷ(knn_classifier)߼ع(logistic_regression)ģ͵ԤԤʱÿģͶضݽз࣬Ȼ󼯳ģͻһͶƱͶƱյķ", 'potential_advantages': ['ͨ϶ģ͵Ԥ⣬߷׼ȷԺ³ԡ', 'ͬģͿóͬ͵ģܹۺЩơ', 'ڵһģͣģͿܶ쳣ֵиõĵֿ'], 'suitable_scenarios': ['ҵһģȫ沶׽Եij', 'Ҫ߷׼ȷԺȶʱرݼСӵ¡', 'ڶģͽҪرߣԤܵij'], 'model_info': {'simulated_created_at': '2023-10-05T14:30:00Z', 'simulated_combination_method': 'ͶƱģУÿģ͵ԤͬȨأշɶģֵ֧'}} +2025-05-24 22:31:03,377 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:31:03] "POST /api/simulate_ensemble_building HTTP/1.1" 500 - +2025-05-24 22:35:20,002 - ml_api_endpoints - INFO - ȡѲģб +2025-05-24 22:35:20,004 - ml_api_endpoints - INFO - ҵ 2 0 +2025-05-24 22:35:20,004 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:35:20] "GET /api/ml/deployments HTTP/1.1" 200 - +2025-05-24 22:35:33,827 - app - INFO - ģ⼯ɹPrompt (): ģ⹹һΪ 'test' ļѧϰģͣĻش +ģбlinear_regression, logistic_regression +ͣstacking_regressor (磺Voting Classifier, Stacking Regressor, BaggingClassifier) + +ڻشаݣ +1. һģġɹѴϢ +2. Ϊ 'test' ģ⼯ģ͵ĹԭмҪͺͻģͣ +3. гģ⼯ģģͿܵDZơ +4. Ҫʲô... +2025-05-24 22:35:43,130 - httpx - INFO - HTTP Request: POST https://aistudio.baidu.com/llm/lmapi/v3/chat/completions "HTTP/1.1 200 OK" +2025-05-24 22:35:43,132 - __main__ - INFO - [EnsembleBuilding] ɹJSON +2025-05-24 22:35:43,134 - werkzeug - INFO - 127.0.0.1 - - [24/May/2025 22:35:43] "POST /api/simulate_ensemble_building HTTP/1.1" 201 - diff --git a/project_documentation.md b/project_documentation.md new file mode 100644 index 0000000..83c042b --- /dev/null +++ b/project_documentation.md @@ -0,0 +1,410 @@ +# 项目名称 + +AI机器学习助手 Pro + +## 项目目的 + +这是一个集成了RAG检索增强生成和机器学习模型的智能助手系统,可以回答机器学习相关问题,并提供模型训练、预测、分析和可视化功能。 + +## 后端应用 (Flask - app.py) + +`app.py` 是项目的主要后端应用程序,基于 Flask 框架构建。它负责处理客户端请求,调度机器学习任务,并与 RAG 系统交互。 + +### 主要 API 端点: + +* **`/api/chat` (POST)** + * **功能**: 核心交互接口。根据用户查询的类型(通用知识、机器学习操作、数据分析、代码生成、模型教程等)和提供的上下文(如上传的数据预览、选择的目标列或模型),智能地将请求路由到 RAG 系统、机器学习代理 (ML Agent) 或直接调用大语言模型 (LLM) 进行处理。 + * **请求体**: 包含 `query` (用户输入) 和可选的 `mode` (如 `data_analysis`, `general_llm`), `data_preview`, `target_column`, `model_name` 等。 + * **响应**: 返回包含答案、可能的源文档、是否为机器学习查询等信息的 JSON 对象。 + +* **`/api/ml/train` (POST)** + * **功能**: 训练新的机器学习模型。 + * **请求体**: 包含 `model_type`, `data_path`, `target_column`, 以及可选的 `model_name`, `categorical_columns`, `numerical_columns`, `model_params`, `test_size`。 + * **响应**: 返回包含模型名称、类型、评估指标和成功消息的 JSON 对象。 + +* **`/api/ml/predict` (POST)** + * **功能**: 使用已训练的模型进行预测。 + * **请求体**: 包含 `model_name` 和 `input_data`。 + * **响应**: 返回包含预测结果和输入数据的 JSON 对象。 + +* **`/api/ml/analyze` (POST)** + * **功能**: 分析数据集,提供统计信息、相关性分析和推荐模型。 + * **请求体**: 包含 `data_path` 和可选的 `target_column`。 + * **响应**: 返回包含详细分析结果的 JSON 对象。 + +* **`/api/ml/upload` (POST)** + * **功能**: 上传数据文件 (CSV, XLSX, JSON)。 + * **请求体**: `multipart/form-data` 包含文件。 + * **响应**: 返回文件路径、列名、数据预览等信息的 JSON 对象。 + +### 代码示例: `/api/ml/train` 端点 + +```python +@app.route('/api/ml/train', methods=['POST']) +def train_model_endpoint(): + """训练机器学习模型的API端点""" + data = request.get_json() + if not data: + return jsonify({"error": "请求体为空"}), 400 + + required_fields = ['model_type', 'data_path', 'target_column'] + for field in required_fields: + if field not in data: + return jsonify({"error": f"缺少必要字段 '{field}'"}), 400 + + try: + # 直接调用ml_models.py中的train_model函数 + from ml_models import train_model + + model_type = data['model_type'] + data_path = data['data_path'] + target_column = data['target_column'] + model_name = data.get('model_name') + categorical_columns = data.get('categorical_columns', []) + numerical_columns = data.get('numerical_columns', []) + model_params = data.get('model_params', {}) + test_size = data.get('test_size', 0.2) + + # 如果是Excel文件,转换为CSV以便更好地处理 + if data_path.endswith('.xlsx'): + import pandas as pd + df = pd.read_excel(data_path) + csv_path = data_path.replace('.xlsx', '_processed.csv') + df.to_csv(csv_path, index=False) + data_path = csv_path + + # 训练模型 + result = train_model( + model_type=model_type, + data=data_path, + target_column=target_column, + model_name=model_name, + categorical_columns=categorical_columns, + numerical_columns=numerical_columns, + model_params=model_params, + test_size=test_size + ) + + # 格式化结果以便前端显示 + formatted_result = { + "model_name": result["model_name"], + "model_type": result["model_type"], + "metrics": result["metrics"], + "message": f"成功训练{model_type}模型,模型名称为{result['model_name']}" + } + return jsonify(formatted_result), 200 + except Exception as e: + # app.logger.error(f"/api/ml/train 接口发生错误: {e}", exc_info=True) # Assuming app.logger is configured + return jsonify({"error": f"训练模型时发生错误: {str(e)}"}), 500 +``` +此外, `app.py` 还包含了模型版本控制、模型比较、集成模型构建、模型部署与取消部署、模型解释以及自动模型选择等功能的API端点。它也负责初始化RAG系统并在应用启动时检查各项配置。 + +## 机器学习核心 (`ml_models.py`) + +`ml_models.py` 模块是项目中所有机器学习操作的核心。它封装了数据预处理、模型训练、评估、预测以及更高级的ML功能,旨在提供一个统一的接口来处理各种机器学习任务。 + +### 主要功能: + +- **数据预处理 (`preprocess_data`)**: 包含对数据进行清洗、特征编码(如标签编码)、特征标准化等步骤,为模型训练准备合适格式的数据。 +- **模型训练 (`train_model`)**: 支持多种类型的模型训练,包括回归模型(如线性回归)、分类模型(如逻辑回归、决策树、随机森林、K-近邻、支持向量机、朴素贝叶斯)和聚类模型(如K-Means)。函数接收模型类型、数据、目标列、模型名称、特征列等参数,并返回包含训练结果和评估指标的字典。 +- **模型预测 (`predict`)**: 使用已保存和加载的模型进行预测。 +- **模型加载与列出 (`load_model`, `list_available_models`)**: 提供加载已保存模型和列出所有可用模型的功能。 +- **模型版本控制 (`save_model_with_version`, `list_model_versions`, `load_model_version`)**: 支持模型版本化,可以保存模型的不同版本,并加载特定版本。 +- **集成学习 (`create_ensemble_model`)**: 支持创建集成模型,如投票(Voting)、堆叠(Stacking)和装袋(Bagging)集成,以提升模型性能。 +- **自动机器学习 (`auto_model_selection`)**: 提供自动化模型选择和超参数优化的能力,帮助用户找到最适合特定数据集和任务的模型。 +- **模型解释 (`explain_model_prediction`)**: 对模型的预测结果进行解释,分析特征重要性和贡献。 +- **模型比较 (`compare_models`)**: 在同一测试数据集上比较多个模型的性能。 + +### 代码示例: `train_model` 函数 + +```python +from typing import Dict, Any, List, Optional, Tuple, Union +import pandas as pd # Assuming pandas is used, adjust as necessary + +def train_model( + model_type: str, + data: Union[pd.DataFrame, str], + target_column: str, + model_name: str = None, + categorical_columns: List[str] = None, + numerical_columns: List[str] = None, + model_params: Dict[str, Any] = None, + test_size: float = 0.2 +) -> Dict[str, Any]: + """ + 训练机器学习模型并保存 + + 参数: + model_type: 模型类型(例如:"linear_regression", "logistic_regression"等) + data: DataFrame数据或CSV文件路径 + target_column: 目标变量列名 + model_name: 模型保存名称(如果为None,则使用模型类型作为名称) + categorical_columns: 分类特征列表 + numerical_columns: 数值特征列表 + model_params: 模型参数字典 + test_size: 测试集比例 + + 返回: + Dictionary包含训练结果和评估指标 + + 可能抛出的异常: + ValueError: 当模型类型不支持或数据格式不正确时 + FileNotFoundError: 当数据文件不存在时 + KeyError: 当目标列不存在时 + """ + # ... (Implementation would involve: + # 1. Loading data if 'data' is a path. + # 2. Preprocessing data (handling categorical/numerical features, splitting). + # 3. Initializing model_cls based on model_type. + # 4. Fitting the model: model.fit(X_train, y_train). + # 5. Evaluating the model: calculating metrics. + # 6. Saving the model and preprocessors. + # 7. Returning a dictionary with model_name, metrics, etc.) + pass # Actual implementation is in the ml_models.py file +``` +该模块通过这些功能,为AI助手提供了强大的机器学习后端支持。 + +## 机器学习智能代理 (`ml_agents.py`) + +`ml_agents.py` 模块利用 LangChain 框架,将 `ml_models.py` 中的核心机器学习功能包装成可供语言模型调用的“工具”(Tools)。这使得系统能够通过自然语言接口执行复杂的机器学习任务,例如模型训练、数据分析和结果可视化。 + +### 核心机制: + +- **工具封装**: `ml_models.py` 中的关键函数(如 `train_model`, `predict`, `auto_model_selection`, `generate_visualization` 等)被定义为 LangChain 的 `StructuredTool`。每个工具都具有明确的输入模式,这些模式通常通过 Pydantic 模型(例如 `TrainModelInput`, `PredictInput`)来定义,确保了类型安全和参数的清晰描述。 +- **Agent 执行器**: 系统采用 LangChain Agent(如 `OpenAIFunctionsAgent` 或 `StructuredChatAgent` 结合 `AgentExecutor`)来解析用户的自然语言查询。Agent 的职责是理解用户意图,从可用的工具集中选择最合适的工具或工具序列,并从查询中提取执行这些工具所需的参数。 +- **自然语言驱动的ML任务**: 用户可以用日常语言提出请求,例如“用逻辑回归模型训练一个客户流失预测模型,目标列是 'Churn',数据集是 'data.csv'”。Agent 会将此请求解析,并调用相应的 `_train_model` 工具,同时传入从请求中提取的参数。 + +### 主要功能与交互流程: + +1. **任务理解与工具选择**: `query_ml_agent` 函数是与 ML Agent 交互的主要入口。它接收用户查询后,LangChain Agent 会分析查询内容,判断用户意图,并从一系列预定义的 ML 工具中选择一个或多个来执行。 +2. **参数提取**: Agent 负责从用户的自然语言查询中准确提取执行工具所需的参数,如模型类型、数据文件路径、目标列名称、要生成的图表类型等。 +3. **工具执行**: 一旦选定工具并提取了参数,Agent 就会调用该工具。这实际上会触发执行 `ml_models.py` 中对应的底层函数。 +4. **结果整合与返回**: 工具执行后返回的结果(例如模型训练的评估指标、预测输出、Base64 编码的图表图像、图表相关的表格数据等)会先返回给 Agent。Agent 随后可以将这些信息整合成一个结构化的响应(通常是 JSON),或者生成一段自然语言描述,最终呈现给用户。 + +### 可视化能力: + +- 在`ml_agents.py`中,作为工具暴露给Agent的**包装函数**(wrapper functions)在调用`ml_models.py`中的核心ML操作(如模型训练、特征分析)后,会进一步调用相应的可视化函数(例如`ml_models.generate_visualization`, `ml_models.visualize_feature_importance`等)。 +- 这些包装函数在获得核心ML操作(如模型训练、特征分析)的结果后,会进一步调用可视化函数生成图表。 +- 生成的图表会**编码为 Base64 字符串**,方便在Web界面中直接嵌入显示。 +- 除了图像数据,这些工具通常还会返回与图表直接相关的**表格数据**(例如,特征重要性得分列表、相关系数矩阵的数值)。这使得信息可以以多种形式(图形和表格)呈现给用户,增强了可解释性。 +- 因此,当用户请求一个会产生可视化的任务时(例如“训练模型并显示特征重要性”),Agent调用的工具的包装函数会负责调用核心ML函数和相应的可视化函数,并将所有结果(文本、Base64图像、表格数据)整合后返回。 + +### 代码示例: 工具输入定义 (`TrainModelInput`) + +以下代码片段展示了用于 `_train_model` 工具的输入参数Pydantic模型 `TrainModelInput`。这个模型定义了工具期望接收的参数、它们的类型以及描述,确保了 Agent 能够准确地从用户自然语言中提取信息并传递给工具。 + +```python +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any + +class TrainModelInput(BaseModel): + model_type: str = Field(..., description="Type of the model to train (e.g., 'logistic_regression', 'random_forest_classifier').") + data_path: str = Field(..., description="Path to the dataset file (CSV, Excel, or JSON).") + target_column: str = Field(..., description="Name of the target variable column in the dataset.") + model_name: Optional[str] = Field(None, description="Optional name to save the trained model.") + categorical_columns: Optional[List[str]] = Field(None, description="List of categorical feature column names.") + numerical_columns: Optional[List[str]] = Field(None, description="List of numerical feature column names.") + model_params: Optional[Dict[str, Any]] = Field(None, description="Dictionary of model-specific hyperparameters.") + test_size: Optional[float] = Field(0.2, description="Proportion of the dataset to include in the test split.") + +# Conceptual example of tool creation: +# from langchain.tools import StructuredTool +# def _train_model_wrapper_for_agent(input_args: TrainModelInput) -> dict: +# # This wrapper would call the actual ml_models.train_model +# # using input_args.model_type, input_args.data_path etc. +# # and then format the result, potentially adding visualizations. +# # result_from_ml_models = ml_models.train_model(**input_args.dict()) +# # return formatted_result_with_visuals +# pass +# +# train_model_tool_for_agent = StructuredTool.from_function( +# func=_train_model_wrapper_for_agent, +# name="TrainModel", +# description="Trains a machine learning model based on specified parameters and dataset, and returns results including metrics and visualizations.", +# args_schema=TrainModelInput +# ) +``` +通过这种方式,`ml_agents.py` 作为自然语言接口与底层机器学习功能之间的桥梁,极大地增强了系统的易用性和交互性,使得非技术用户也能方便地利用强大的机器学习能力。 + +## 检索增强生成核心 (`rag_core.py`) + +`rag_core.py` 模块是项目中实现检索增强生成 (Retrieval Augmented Generation, RAG) 功能的核心。它使得系统能够基于本地知识库中的文档内容,结合大型语言模型 (LLM) 来回答用户的问题,提供更具上下文和事实依据的答案。 + +### 主要功能与流程: + +1. **文档加载与处理**: + * `load_documents_from_kb()`: 从 `KNOWLEDGE_BASE_DIR` (在 `config.py` 中定义) 加载多种格式的文档,包括:PDF (`PyPDFLoader`), Word 文档 (`UnstructuredWordDocumentLoader`), 纯文本 (`TextLoader`), 结构化 JSON (`load_and_parse_custom_json` 使用 `config.JSON_JQ_SCHEMA` 即 `'.[] | .sentence'` 配置提取特定字段), 和 CSV 文件 (summarized by `generate_csv_summary_documents`)。 + * `split_documents()`: 加载后的文档被分割成较小的文本块 (chunks) 使用 `RecursiveCharacterTextSplitter`,以便于后续的向量化和检索。分割大小和重叠部分由 `config.py` 中的 `CHUNK_SIZE` 和 `CHUNK_OVERLAP` 控制。 + +2. **向量存储创建与管理**: + * `get_vector_store()`: + * 使用百度文心提供的 Embedding 模型 (`config.BAIDU_EMBEDDING_MODEL_NAME`, 即 `bge-large-zh`) 将文本块转换为向量表示。 + * 这些向量存储在 `ChromaDB` 向量数据库中,该数据库会持久化到 `CHROMA_PERSIST_DIR` (在 `config.py` 中定义)。 + * 系统会尝试从磁盘加载已存在的数据库,如果不存在或强制重建,则会重新处理知识库文档并创建新的数据库。 + * `initialize_rag_system()`: 在应用启动时调用,确保向量数据库和QA链准备就绪。可以配置为强制重建向量数据库。 + +3. **问答 (QA) 链机制**: + * `get_qa_chain()`: + * 创建一个 `RetrievalQA` 链。 + * 该链使用从 ChromaDB 构建的检索器 (retriever) 根据用户问题查询相关的文本块。 + * 检索到的文本块与原始问题一起被传递给一个大型语言模型 (`config.BAIDU_LLM_MODEL_NAME`, 即 `ernie-4.5-turbo-128k`)。 + * LLM 基于这些信息生成最终答案。 + * `query_rag()`: 这是与 RAG 系统交互的主要函数。它接收用户问题,调用 QA 链,并返回包含答案和源文档信息的字典。 + * `direct_query_llm()`: 提供一个直接调用LLM的接口,不经过RAG检索过程。 + +### 代码示例: `query_rag` 函数 + +```python +from typing import Dict, Any +from langchain_core.documents import Document # Or from langchain.schema import Document + +# Assuming qa_chain_instance is a global or class-level variable, +# properly initialized by initialize_rag_system() before this function is called. +# For snippet purposes, its direct availability is assumed. +# global qa_chain_instance + +def query_rag(question: str) -> Dict[str, Any]: + """ + 使用RAG系统查询问题的答案。 + + Args: + question: 用户提出的问题字符串。 + + Returns: + 一个字典,包含: + - "answer": LLM生成的答案。 + - "source_documents": 一个列表,包含检索到的源文档片段及其元数据。 + """ + # global qa_chain_instance # Uncomment if it's a global variable accessed here + if 'qa_chain_instance' not in globals() or qa_chain_instance is None: # A way to check if initialized + # This would typically be handled by initialize_rag_system or raise an error + print("警告: RAG问答链未初始化。请先调用 initialize_rag_system().") + return {"answer": "错误: RAG问答链未初始化。", "source_documents": []} + + try: + print(f"RAG系统接收到查询: {question}") + # Conceptual: result = qa_chain_instance.invoke({"query": question}) + # For snippet purposes, we'll mock a realistic-looking result structure. + result = { + "result": f"Generated answer to: {question}", + "source_documents": [ + Document(page_content="Relevant context from a source document...", metadata={"source": "knowledge_base/doc1.pdf"}) + ] + } + answer = result.get("result", "未能找到明确的答案。") + source_docs_raw = result.get("source_documents", []) + + formatted_sources = [] + if source_docs_raw: + for doc_item in source_docs_raw: + if isinstance(doc_item, Document): + source_info = { + "page_content": doc_item.page_content[:100] + "..." if len(doc_item.page_content) > 100 else doc_item.page_content, # Limit snippet length + "metadata": doc_item.metadata + } + formatted_sources.append(source_info) + + return {"answer": answer, "source_documents": formatted_sources} + except Exception as e: + # In a real app, use proper logging: print(f"RAG查询过程中发生错误: {e}", exc_info=True) + # For snippet: + print(f"RAG查询过程中发生错误: {e}") + return {"answer": f"处理您的问题时发生错误: {str(e)}", "source_documents": []} +``` + +通过这些组件,`rag_core.py` 为系统提供了强大的知识检索和智能问答能力。 + +## 项目配置 (`config.py` 与 `.env`) + +项目的配置管理采用分层方式,结合使用 `.env` 文件和 `config.py` 脚本,以实现灵活性和安全性。 + +### 1. 环境变量 (`.env` 文件) + +- **核心作用**: `.env` 文件是存储所有敏感配置信息的地方,例如 API 密钥、数据库连接字符串以及特定于部署环境的变量(如 `FLASK_ENV`)。**此文件绝对不能提交到版本控制系统(例如 Git)**,以避免安全凭证泄露。通常,项目中会包含一个 `.env.example` 文件,作为用户配置实际 `.env` 文件的模板。 +- **加载机制**: 应用在启动时(通常在 `config.py` 或主应用脚本 `app.py` 的早期),使用 `python-dotenv` 库的 `load_dotenv()` 函数来读取 `.env` 文件,并将其中的键值对加载到操作系统的环境变量中。 +- **主要配置项**: + * `AI_STUDIO_API_KEY`: 用于访问百度AI Studio大模型平台服务的API密钥,主要用于文心千帆大模型及Embedding服务。 + * `ERNIE_API_KEY` / `ERNIE_SECRET_KEY` (可选): 如果项目直接使用百度智能云文心千帆SDK,则可能需要配置这些更具体的凭证。 + * `FLASK_ENV`: 指定Flask应用的运行环境,如 `development` 或 `production`。 + * `FLASK_DEBUG`: 控制是否开启Flask的调试模式 (通常在开发环境设为 `True`)。 + +### 2. 应用固定配置 (`config.py`) + +- **核心作用**: `config.py` 脚本负责定义项目中非敏感的、相对固定的配置参数。它首先会尝试从环境变量中加载由 `.env` 文件设置的敏感信息,然后定义其他应用层面的参数。 +- **参数类型与示例**: + * **API密钥与模型名称**: + * 从环境变量中读取 `AI_STUDIO_API_KEY` 等。 + * 定义默认使用的LLM模型名称,如 `BAIDU_LLM_MODEL_NAME = "ernie-4.5-turbo-128k"`。 + * 定义默认使用的Embedding模型名称,如 `BAIDU_EMBEDDING_MODEL_NAME = "bge-large-zh"`。 + * **文件与目录路径**: + * `BASE_DIR`: 项目的根目录。 + * `KNOWLEDGE_BASE_DIR`: 存放RAG知识库文档的目录。 + * `CHROMA_PERSIST_DIR`: ChromaDB向量数据库的持久化存储目录。 + * `LOG_FILE_PATH`: 日志文件的输出路径。 + * **RAG文本处理参数**: + * `CHUNK_SIZE`: 文档分割时每个文本块的目标大小。 + * `CHUNK_OVERLAP`: 分割后相邻文本块之间的重叠字符数。 + * **RAG检索与Agent行为参数**: + * `RAG_SEARCH_TYPE`: 向量检索类型(如 `similarity`, `mmr`)。 + * `RAG_K`: 检索时返回的最相关文档数量。 + * `RAG_SCORE_THRESHOLD`: 认定检索文档为相关的最低分数阈值。 + * `RAG_ANSWER_MIN_LENGTH`: RAG系统生成答案的期望最小长度。 + * `UNCERTAINTY_PHRASES`: 用于识别LLM回答不确定性的短语列表。 + * `ML_KEYWORDS`, `ML_OPS_KEYWORDS`: 用于辅助判断用户查询是否与机器学习任务相关的关键词列表。 + * **其他应用参数**: + * `JSON_JQ_SCHEMA`: 用于从特定结构的JSON文件中提取内容的 `jq` 模式 (`'.[] | .sentence'`)。 + * `LOG_LEVEL`: 应用的日志记录级别(如 `INFO`, `DEBUG`)。 + +### 代码示例: `config.py` + +以下是 `config.py` 中部分参数定义的示例,展示了如何从环境变量加载敏感数据并定义其他应用参数: + +```python +import os +from dotenv import load_dotenv + +# 加载 .env 文件中定义的环境变量 +# 这使得 os.getenv 可以访问 .env 中设置的值 +load_dotenv() + +# --- API密钥与模型配置 --- +# 从环境变量中获取百度AI Studio的API Key +AI_STUDIO_API_KEY = os.getenv("AI_STUDIO_API_KEY") +# 同样的方式可以获取ERNIE_API_KEY和ERNIE_SECRET_KEY,如果它们被使用 + +# 定义默认使用的语言模型和Embedding模型 +BAIDU_LLM_MODEL_NAME = "ernie-4.5-turbo-128k" +BAIDU_EMBEDDING_MODEL_NAME = "bge-large-zh" # 百度 Ernie Embeddings官方推荐使用此模型名 + +# --- 路径配置 --- +# BASE_DIR 定义为 config.py 文件所在的目录 (即项目根目录, 假设config.py在根目录) +BASE_DIR = os.path.abspath(os.path.dirname(__file__)) +KNOWLEDGE_BASE_DIR = os.path.join(BASE_DIR, "knowledge_base") +CHROMA_PERSIST_DIR = os.path.join(BASE_DIR, "vector_store", "chroma_db_ernie") +LOG_DIR = os.path.join(BASE_DIR, "logs") +LOG_FILE_PATH = os.path.join(LOG_DIR, "ai_assistant.log") +# 确保日志目录存在 +os.makedirs(LOG_DIR, exist_ok=True) + +# --- RAG文本分割参数 --- +CHUNK_SIZE = 800 # 每个文本块的目标大小 (字符数) +CHUNK_OVERLAP = 100 # 文本块之间的重叠大小 (字符数) + +# --- RAG检索与答案生成参数 --- +RAG_SEARCH_TYPE = "similarity" # 向量检索类型, 可选 "mmr" (Maximal Marginal Relevance) +RAG_K = 5 # 检索时返回的最相关文档数量 +RAG_SCORE_THRESHOLD = 0.35 # 检索文档的最低相关性得分阈值 (0到1之间) +RAG_ANSWER_MIN_LENGTH = 30 # RAG系统生成答案的期望最小长度 (字符数) + +# --- Agent行为与关键词 --- +# 用于初步判断用户查询意图的关键词列表 +ML_KEYWORDS = ["模型", "训练", "预测", "分析", "特征", "评估", "算法", "machine learning", "model", "train", "predict", "analyze", "plot", "visualize"] +ML_OPS_KEYWORDS = ["保存", "加载", "版本", "部署", "监控", "save", "load", "version", "deploy", "monitor"] +# 用于识别模型回答中不确定性的短语 +UNCERTAINTY_PHRASES = ["无法找到", "不确定", "不知道", "没有相关信息", "未能", "无法提供", "I'm sorry", "I cannot", "I don't know", "Unable to answer"] + +# --- 其他 --- +JSON_JQ_SCHEMA = "'.[] | .sentence'" # 用于从特定JSON结构中提取文本的jq schéma +LOG_LEVEL = "INFO" # 应用的日志级别 +``` +通过这种配置分离策略,项目在保证敏感信息安全的同时,也为不同部署环境和应用行为的调整提供了便利。 diff --git a/project_overview.md b/project_overview.md new file mode 100644 index 0000000..2e1e5a0 --- /dev/null +++ b/project_overview.md @@ -0,0 +1,7 @@ +# 项目名称 + +AI机器学习助手 Pro + +## 项目目的 + +这是一个集成了RAG检索增强生成和机器学习模型的智能助手系统,可以回答机器学习相关问题,并提供模型训练、预测、分析和可视化功能。 diff --git a/rag_core.py b/rag_core.py index 56f2eed..8d4a9f6 100644 --- a/rag_core.py +++ b/rag_core.py @@ -12,7 +12,7 @@ UnstructuredWordDocumentLoader # 添加DOCX加载器 ) from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.vectorstores import Chroma +from langchain_chroma import Chroma from langchain.chains import RetrievalQA from langchain_core.documents import Document from langchain_community.vectorstores.utils import filter_complex_metadata diff --git a/static/js/app.js b/static/js/app.js index 00376b0..a1e3414 100644 --- a/static/js/app.js +++ b/static/js/app.js @@ -33,11 +33,13 @@ const API_ENDPOINTS = { // Advanced Tools Endpoints MODEL_VERSIONS: '/api/ml/model_versions', // POST to create version GET_MODEL_VERSIONS: '/api/ml/model_versions/', // GET for list. Append model_name, e.g., /api/ml/model_versions/my_model - COMPARE_MODELS: '/api/ml/compare_models', // POST - BUILD_ENSEMBLE: '/api/ml/ensemble', // POST - DEPLOY_MODEL: '/api/ml/deploy', // POST to deploy + // COMPARE_MODELS: '/api/ml/compare_models', // POST + // BUILD_ENSEMBLE: '/api/ml/ensemble', // POST + // DEPLOY_MODEL: '/api/ml/deploy', // POST to deploy DEPLOYED_MODELS: '/api/ml/deployments', // GET for list UNDEPLOY_MODEL: '/api/ml/undeploy/', // POST. Append deployment_id, e.g., /api/ml/undeploy/deployment_id_123 + SIMULATE_COMPARE_MODELS: '/api/simulate_model_comparison', // 新增 + SIMULATE_BUILD_ENSEMBLE: '/api/simulate_ensemble_building', // 新增 }; // 模型类别分组,便于前端展示和选择 (与后端 ml_models.py 保持一致) @@ -971,6 +973,10 @@ function updateQueryInputState() { /** * 初始化查询提交 */ +// app.js + +// ... (其他代码) ... + function initQuerySubmission() { const btn = DOM.submitQueryButton(); const input = DOM.queryInput(); @@ -982,111 +988,125 @@ function initQuerySubmission() { e.preventDefault(); btn.click(); } }); + btn.addEventListener('click', async () => { const query = input.value.trim(); - if (!query) { showToast('请输入查询内容。', 'warning'); return; } - const mode = document.querySelector('input[name="queryMode"]:checked').value; - const body = { - query, - mode, - use_existing_model: true // 默认使用现有模型,除非特定操作需要训练 + if (!query) { + showToast('请输入查询内容。', 'warning'); + return; + } + const currentQueryMode = document.querySelector('input[name="queryMode"]:checked').value; + + const body = { + query: query, + mode: currentQueryMode, + // use_existing_model: true // 这个参数可以根据具体场景动态设置,或让后端判断 }; - - if (mode === 'data_analysis') { - if (!currentData.path || !currentData.analysisCompleted) { - showToast('请先上传并成功分析数据。', 'error'); - setButtonLoading(btn, false, '提交查询', DOM.submitQueryIcon()); - showLoadingSpinner(false); - return; + + // --- 根据功能需求构建请求体 body --- + + if (currentQueryMode === 'data_analysis') { + // 检查数据是否已上传和分析 (您已有此逻辑) + if (!currentData.path || !currentData.analysisCompleted) { + showToast('请先上传并成功分析数据才能进行此模式的提问。', 'error'); + return; } - body.data_path = currentData.path; - // Model and target column are now optional for the payload - // They will be included if selected, but not required - if (selectedModelName) { - body.model_name = selectedModelName; + // body.data_path = currentData.path; // 后端可能需要完整路径 + + // 功能 1, 2, 3 都需要数据预览 + if (currentData.preview && currentData.preview.length > 0) { + body.data_preview = currentData.preview; // 发送完整预览(后端app.py已限制为前10行) + } else { + showToast('数据预览信息缺失,请重新上传或分析数据。', 'warning'); + return; } + + // 功能 2 & 3 需要目标列 if (selectedTargetColumn) { body.target_column = selectedTargetColumn; } + // 功能 3 需要已选模型 + if (selectedModelName) { + body.model_name = selectedModelName; + } - // 添加数据预览(前5行) - if (currentData.preview && currentData.preview.length > 0) { - body.data_preview = currentData.preview.slice(0, 5); - } else { - body.data_preview = []; // 如果没有预览数据,发送空数组 + // 特别标记教程生成请求 (功能3) + // 可以通过问题关键词,或者如果UI上有专门的“生成教程”按钮,则设置一个特定标记 + if (query.toLowerCase().includes('生成教程') && body.data_preview && body.model_name && body.target_column) { + body.generate_tutorial = true; // 后端可以检查这个标记 + showToast('正在请求生成教程...', 'info'); } + // 对于功能1 ("这份数据适合用什么模型进行分析"): + // body 中已包含 query 和 data_preview,后端会据此回答。 + // 对于功能2 ("如果我想使用已有模型预测[目标列名],哪些特征最重要?"): + // body 中已包含 query, data_preview, target_column (和可选的 model_name),后端会据此回答。 + + } else if (currentQueryMode === 'general_llm') { + // 功能 4: 通用大模型问答 + // body 中只需要 query 和 mode,后端会直接调用LLM + // 无需额外添加字段 } + + // --- 发送请求 --- setButtonLoading(btn, true, '处理中...', DOM.submitQueryIcon()); showLoadingSpinner(true, 'AI思考中,请稍候...'); - clearPreviousResults(); + clearPreviousResults(); // 清除旧结果 + const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 120000); // 60s timeout + const timeoutId = setTimeout(() => controller.abort(), 120000); // 120秒超时 + try { const response = await fetch(API_ENDPOINTS.CHAT, { - method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(body), signal: controller.signal, + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body), + signal: controller.signal, }); clearTimeout(timeoutId); + if (!response.ok) { - const err = await response.json().catch(() => ({ error: `服务器错误 (${response.status})` })); - throw new Error(err.error || (response.status === 504 ? '服务器处理超时' : `请求失败 ${response.status}`)); + const errData = await response.json().catch(() => ({ error: `服务器错误,状态码: ${response.status}` })); + throw new Error(errData.error || `请求失败,状态码: ${response.status}`); } - const data = await response.json(); - if (data.error) throw new Error(data.error); - displayChatResponse(data, query); - saveToHistory(query, data, mode); + + const resultData = await response.json(); + if (resultData.error) { + throw new Error(resultData.error); + } + + displayChatResponse(resultData, query); // 您已有的函数,用于显示结果 + // saveToHistory(query, resultData, currentQueryMode); // 您已有的函数 } catch (error) { - console.error('查询错误:', error); clearTimeout(timeoutId); - let msg = error.name === 'AbortError' ? '请求超时' : error.message; + console.error('查询提交错误:', error); + clearTimeout(timeoutId); + let msg = error.name === 'AbortError' ? '请求超时,请稍后再试。' : error.message; + + // 您已有的错误处理和显示逻辑 let errorHTML = ''; - - // 根据错误类型提供更友好的错误信息 if (error.name === 'AbortError') { - errorHTML = ` -
-
- - 请求超时 -
-
-

服务器处理您的请求时间过长。这可能是因为:

- -

建议:尝试简化您的问题,或稍后再试。

- `; + errorHTML = `...`; // 保持您原来的超时HTML } else if (msg.startsWith('Could not parse LLM output: ')) { - // 提取并显示部分AI输出 - const partialOutput = msg.substring('Could not parse LLM output: '.length); - errorHTML = ` -
-
- - AI处理复杂任务时遇到了限制 -
-
-

您的查询过于复杂,AI无法在允许的时间或步骤内完成处理。以下是AI在处理过程中的部分结果:

-
-
${escapeHtml(partialOutput)}
-
-

建议:尝试将您的问题拆分为更小的部分,或者提供更具体的指令。

- `; + errorHTML = `...`; // 保持您原来的解析错误HTML } else { - errorHTML = `

查询失败: ${escapeHtml(msg)}

`; + // 如果后端返回了更具体的错误信息,优先使用 + if (error.raw_llm_response) { // 针对模拟API的JSON解析失败 + errorHTML = `

处理失败: ${escapeHtml(msg)}

原始AI响应 (调试用):

${escapeHtml(error.raw_llm_response)}
`; + } else if (error.parsed_response) { + errorHTML = `

处理失败: ${escapeHtml(msg)}

已解析AI响应 (调试用):

${escapeHtml(JSON.stringify(error.parsed_response, null, 2))}
`; + } + else { + errorHTML = `

查询处理失败: ${escapeHtml(msg)}

`; + } } - - showToast(`查询失败: ${msg}`, 'error'); DOM.responseText().innerHTML = errorHTML; + showToast(`查询失败: ${msg}`, 'error'); } finally { setButtonLoading(btn, false, '提交问题', DOM.submitQueryIcon(), 'fa-paper-plane'); showLoadingSpinner(false); - DOM.responseSection().classList.remove('hidden'); + DOM.responseSection().classList.remove('hidden'); // 确保结果区域可见以显示错误 } }); } - /** * 清除上次结果 */ @@ -1659,56 +1679,314 @@ function populateModelSelector(selectEl, models, placeholder) { }); } -// --- MODEL VERSIONING --- +// --- MODEL VERSIONING (纯前端实现) --- let currentModelForVersioning = null; +const LOCAL_STORAGE_MODEL_VERSIONS_KEY = 'mlAssistant_modelVersions'; + function initModelVersioning() { const selector = DOM.versionModelSelector(); const createBtn = DOM.createVersionBtn(); const saveBtn = DOM.saveVersionBtn(); const cancelBtn = DOM.cancelSaveVersionBtn(); + if (!selector || !createBtn || !saveBtn || !cancelBtn) return; - selector.addEventListener('change', async (e) => { - currentModelForVersioning = e.target.value; + // 从 FIXED_MODEL_DETAILS 填充模型选择器 (因为版本是前端管理的,不依赖后端模型列表) + const modelOptions = Object.values(FIXED_MODEL_DETAILS).map(m => ({ + name: m.internal_name, // 使用 internal_name 作为 value + displayName: m.display_name + })); + populateModelSelector(selector, modelOptions.map(m=> ({name: m.displayName, internal_name: m.name})), "选择模型查看版本"); + + + selector.addEventListener('change', (e) => { + currentModelForVersioning = e.target.value; // 这里的值是 internal_name DOM.versionMetadataForm().classList.add('hidden'); - if (currentModelForVersioning) await fetchAndDisplayModelVersions(currentModelForVersioning); - else DOM.versionTableBody().innerHTML = `选择模型查看版本`; + if (currentModelForVersioning) { + fetchAndDisplayModelVersions_localStorage(currentModelForVersioning); + } else { + DOM.versionTableBody().innerHTML = `选择模型查看版本`; + } }); + createBtn.addEventListener('click', () => { - if (!currentModelForVersioning) { showToast('请先选择模型。', 'warning'); return; } - DOM.versionFormModelName().textContent = getModelDisplayName(currentModelForVersioning); - DOM.versionDescription().value = ''; DOM.versionPerformance().value = ''; - DOM.versionMetadataForm().classList.remove('hidden'); DOM.versionDescription().focus(); + if (!currentModelForVersioning) { + showToast('请先选择一个模型来创建版本。', 'warning'); + return; + } + // 从 FIXED_MODEL_DETAILS 获取显示名称 + const modelDetail = FIXED_MODEL_DETAILS[currentModelForVersioning]; + DOM.versionFormModelName().textContent = modelDetail ? modelDetail.display_name : currentModelForVersioning; + DOM.versionDescription().value = ''; + DOM.versionPerformance().value = ''; + DOM.versionMetadataForm().classList.remove('hidden'); + DOM.versionDescription().focus(); }); - cancelBtn.addEventListener('click', () => DOM.versionMetadataForm().classList.add('hidden')); - saveBtn.addEventListener('click', async () => { + + cancelBtn.addEventListener('click', () => { + DOM.versionMetadataForm().classList.add('hidden'); + }); + + saveBtn.addEventListener('click', () => { if (!currentModelForVersioning) return; - const desc = DOM.versionDescription().value.trim(); - const perf = DOM.versionPerformance().value.trim(); - if (!desc && !perf) { showToast('请输入版本描述或性能指标。', 'warning'); return; } - setButtonLoading(saveBtn, true); - try { - // 调用创建模型版本API - const response = await fetch(API_ENDPOINTS.MODEL_VERSIONS, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - model_name: currentModelForVersioning, - description: desc, - performance_metrics: perf - }) - }); - const result = await response.json(); - if (result.error) throw new Error(result.error); - - showToast(result.message || `模型 "${getModelDisplayName(currentModelForVersioning)}" 新版本已创建。`, 'success'); + + const description = DOM.versionDescription().value.trim(); + const performance_metrics = DOM.versionPerformance().value.trim(); // 修正变量名 + + if (!description && !performance_metrics) { + showToast('请输入版本描述或性能指标。', 'warning'); + return; + } + + setButtonLoading(saveBtn, true, '保存中...'); + + // 使用 localStorage 保存版本 + const allVersions = JSON.parse(localStorage.getItem(LOCAL_STORAGE_MODEL_VERSIONS_KEY) || '{}'); + if (!allVersions[currentModelForVersioning]) { + allVersions[currentModelForVersioning] = []; + } + + const newVersion = { + id: `v${allVersions[currentModelForVersioning].length + 1}-${Date.now().toString(36)}`, // 简易版本ID + model_name: currentModelForVersioning, + description: description, + performance_metrics: performance_metrics, + created_at: new Date().toISOString() + }; + allVersions[currentModelForVersioning].unshift(newVersion); // 添加到开头 + localStorage.setItem(LOCAL_STORAGE_MODEL_VERSIONS_KEY, JSON.stringify(allVersions)); + + setTimeout(() => { // 模拟保存延迟 + showToast(`模型 "${getModelDisplayName(currentModelForVersioning)}" 的新版本已在本地保存。`, 'success'); DOM.versionMetadataForm().classList.add('hidden'); - await fetchAndDisplayModelVersions(currentModelForVersioning); - } catch (error) { showToast(`创建版本失败: ${error.message}`, 'error'); } - finally { setButtonLoading(saveBtn, false); } + fetchAndDisplayModelVersions_localStorage(currentModelForVersioning); + setButtonLoading(saveBtn, false, ' 保存版本'); + }, 500); }); +} + +function fetchAndDisplayModelVersions_localStorage(modelName) { + const tBody = DOM.versionTableBody(); + tBody.innerHTML = ``; + + setTimeout(() => { // 模拟加载延迟 + const allVersions = JSON.parse(localStorage.getItem(LOCAL_STORAGE_MODEL_VERSIONS_KEY) || '{}'); + const versions = allVersions[modelName] || []; + if (!versions.length) { + tBody.innerHTML = `此模型暂无本地版本记录。`; + return; + } + + tBody.innerHTML = versions.map(v => ` + + ${escapeHtml(v.id)} + ${new Date(v.created_at).toLocaleString()} + ${escapeHtml(v.description || '-')} + ${escapeHtml(v.performance_metrics || '-')} + + + + + `).join(''); + }, 300); } + +window.deleteModelVersion_localStorage = function(modelName, versionId) { + if (!confirm(`确定要从本地删除模型 "${getModelDisplayName(modelName)}" 的版本 "${versionId}" 吗?此操作不可恢复。`)) return; + + const allVersions = JSON.parse(localStorage.getItem(LOCAL_STORAGE_MODEL_VERSIONS_KEY) || '{}'); + if (allVersions[modelName]) { + allVersions[modelName] = allVersions[modelName].filter(v => v.id !== versionId); + if (allVersions[modelName].length === 0) { + delete allVersions[modelName]; + } + localStorage.setItem(LOCAL_STORAGE_MODEL_VERSIONS_KEY, JSON.stringify(allVersions)); + showToast(`版本 "${versionId}" 已从本地删除。`, 'success'); + fetchAndDisplayModelVersions_localStorage(modelName); // 刷新列表 + } else { + showToast('未找到要删除的版本。', 'error'); + } +} + +// --- MODEL DEPLOYMENT & MONITORING (纯前端实现) --- +const LOCAL_STORAGE_DEPLOYMENTS_KEY = 'mlAssistant_deployments'; + +function initModelDeployment() { + const deployBtn = DOM.deployModelBtn(); + const refreshBtn = DOM.refreshMonitorBtn(); + + if (!deployBtn || !refreshBtn) return; + + // 填充部署模型选择器 (使用与版本控制相同的模型列表) + const modelOptions = Object.values(FIXED_MODEL_DETAILS).map(m => ({ + name: m.internal_name, + displayName: m.display_name + })); + populateModelSelector(DOM.deployModelSelect(), modelOptions.map(m=> ({name: m.displayName, internal_name: m.name})), "选择要部署的模型"); + + + deployBtn.addEventListener('click', () => { + const modelName = DOM.deployModelSelect().value; // internal_name + const environment = DOM.deployEnvironmentSelect().value; + let endpointName = DOM.deployEndpoint().value.trim(); + + if (!modelName) { showToast('请选择要部署的模型。', 'warning'); return; } + if (!environment) { showToast('请选择部署环境。', 'warning'); return; } + if (!endpointName) { + // 自动生成端点名 + endpointName = `/predict/${modelName.toLowerCase().replace(/_/g, '-')}/${environment.substring(0,3)}/${Date.now().toString(36).slice(-4)}`; + DOM.deployEndpoint().value = endpointName; // 显示到输入框 + showToast(`API端点名称已自动生成: ${endpointName}`, 'info', 2000); + } + if (!/^\/[a-zA-Z0-9\/_.-]+$/.test(endpointName)) { + showToast('API端点名称格式无效,应以 / 开头。', 'warning'); return; + } + + + setButtonLoading(deployBtn, true, '部署中...'); + + const deployments = JSON.parse(localStorage.getItem(LOCAL_STORAGE_DEPLOYMENTS_KEY) || '[]'); + + // 检查端点是否已存在 (在同一环境下) + if (deployments.some(d => d.endpoint === endpointName && d.environment === environment)) { + showToast(`端点 "${endpointName}" 在环境 "${environment}" 中已存在。请使用不同名称。`, 'error'); + setButtonLoading(deployBtn, false, ' 部署模型'); + return; + } + + const newDeployment = { + id: `dep-${Date.now().toString(36)}`, + model_name: modelName, // Store internal_name + model_display_name: getModelDisplayName(modelName), // For display + environment: environment, + endpoint: endpointName, + status: '运行中', // 模拟直接成功 + deployed_at: new Date().toISOString(), + requests: 0, // 模拟统计 + avg_response_ms: Math.floor(Math.random() * (300 - 50 + 1)) + 50 // 模拟响应时间 + }; + deployments.unshift(newDeployment); + localStorage.setItem(LOCAL_STORAGE_DEPLOYMENTS_KEY, JSON.stringify(deployments)); + + setTimeout(() => { // 模拟部署延迟 + showToast(`模型 "${newDeployment.model_display_name}" 已模拟部署到 "${environment}" 环境,端点: ${endpointName}`, 'success'); + fetchAndDisplayDeployedModels_localStorage(); + setButtonLoading(deployBtn, false, ' 部署模型'); + }, 1000); + }); + + refreshBtn.addEventListener('click', fetchAndDisplayDeployedModels_localStorage); + fetchAndDisplayDeployedModels_localStorage(); // 页面加载时获取 +} + +function fetchAndDisplayDeployedModels_localStorage() { + const tBody = DOM.deploymentTableBody(); + const refreshBtn = DOM.refreshMonitorBtn(); + if (refreshBtn) setButtonLoading(refreshBtn, true, '刷新中...', refreshBtn.querySelector('i'), 'fa-sync-alt'); + + tBody.innerHTML = ``; + + setTimeout(() => { // 模拟加载 + const deployments = JSON.parse(localStorage.getItem(LOCAL_STORAGE_DEPLOYMENTS_KEY) || '[]'); + + if (!deployments.length) { + tBody.innerHTML = `暂无已部署模型 (本地模拟)。`; + } else { + tBody.innerHTML = deployments.map(d => ` + + ${escapeHtml(d.model_display_name)} + ${escapeHtml(d.environment)} + ${escapeHtml(d.endpoint)} + ${escapeHtml(d.status)} + + + + `).join(''); + } + // 更新模拟统计 + DOM.deployedModelCount().textContent = deployments.length; + const runningDeployments = deployments.filter(d => d.status === '运行中'); + if (runningDeployments.length > 0) { + const totalResponseTime = runningDeployments.reduce((sum, d) => sum + (d.avg_response_ms || 100), 0); + DOM.avgResponseTime().textContent = `${Math.round(totalResponseTime / runningDeployments.length)}ms`; + DOM.predictionRequests().textContent = runningDeployments.reduce((sum, d) => sum + (d.requests || 0), 0); + } else { + DOM.avgResponseTime().textContent = 'N/A'; + DOM.predictionRequests().textContent = '0'; + } + + if (refreshBtn) setButtonLoading(refreshBtn, false, '刷新', refreshBtn.querySelector('i'), 'fa-sync-alt'); + }, 300); +} + +window.confirmUndeploy_localStorage = function(deploymentId, modelDisplayName) { + if (!confirm(`确定要取消模型 "${modelDisplayName}" (ID: ${deploymentId}) 的本地模拟部署吗?`)) return; + + let deployments = JSON.parse(localStorage.getItem(LOCAL_STORAGE_DEPLOYMENTS_KEY) || '[]'); + const initialLength = deployments.length; + deployments = deployments.filter(d => d.id !== deploymentId); + + if (deployments.length < initialLength) { + localStorage.setItem(LOCAL_STORAGE_DEPLOYMENTS_KEY, JSON.stringify(deployments)); + showToast(`模型 "${modelDisplayName}" 的模拟部署已取消。`, 'success'); + fetchAndDisplayDeployedModels_localStorage(); // 刷新列表和统计 + } else { + showToast('未找到要取消的模拟部署。', 'error'); + } +}; +// function initModelVersioning() { +// const selector = DOM.versionModelSelector(); +// const createBtn = DOM.createVersionBtn(); +// const saveBtn = DOM.saveVersionBtn(); +// const cancelBtn = DOM.cancelSaveVersionBtn(); +// if (!selector || !createBtn || !saveBtn || !cancelBtn) return; +// +// selector.addEventListener('change', async (e) => { +// currentModelForVersioning = e.target.value; +// DOM.versionMetadataForm().classList.add('hidden'); +// if (currentModelForVersioning) await fetchAndDisplayModelVersions(currentModelForVersioning); +// else DOM.versionTableBody().innerHTML = `选择模型查看版本`; +// }); +// createBtn.addEventListener('click', () => { +// if (!currentModelForVersioning) { showToast('请先选择模型。', 'warning'); return; } +// DOM.versionFormModelName().textContent = getModelDisplayName(currentModelForVersioning); +// DOM.versionDescription().value = ''; DOM.versionPerformance().value = ''; +// DOM.versionMetadataForm().classList.remove('hidden'); DOM.versionDescription().focus(); +// }); +// cancelBtn.addEventListener('click', () => DOM.versionMetadataForm().classList.add('hidden')); +// saveBtn.addEventListener('click', async () => { +// if (!currentModelForVersioning) return; +// const desc = DOM.versionDescription().value.trim(); +// const perf = DOM.versionPerformance().value.trim(); +// if (!desc && !perf) { showToast('请输入版本描述或性能指标。', 'warning'); return; } +// setButtonLoading(saveBtn, true); +// try { +// // 调用创建模型版本API +// const response = await fetch(API_ENDPOINTS.MODEL_VERSIONS, { +// method: 'POST', +// headers: { 'Content-Type': 'application/json' }, +// body: JSON.stringify({ +// model_name: currentModelForVersioning, +// description: desc, +// performance_metrics: perf +// }) +// }); +// const result = await response.json(); +// if (result.error) throw new Error(result.error); +// +// showToast(result.message || `模型 "${getModelDisplayName(currentModelForVersioning)}" 新版本已创建。`, 'success'); +// DOM.versionMetadataForm().classList.add('hidden'); +// await fetchAndDisplayModelVersions(currentModelForVersioning); +// } catch (error) { showToast(`创建版本失败: ${error.message}`, 'error'); } +// finally { setButtonLoading(saveBtn, false); } +// }); +// +// } async function fetchAndDisplayModelVersions(modelName) { const tBody = DOM.versionTableBody(); tBody.innerHTML = ``; @@ -1717,221 +1995,453 @@ async function fetchAndDisplayModelVersions(modelName) { const response = await fetch(`${API_ENDPOINTS.GET_MODEL_VERSIONS}${encodeURIComponent(modelName)}`); const result = await response.json(); if (result.error) throw new Error(result.error); - + const versions = result.versions || []; - if (!versions.length) { - tBody.innerHTML = `无版本记录。`; - return; + if (!versions.length) { + tBody.innerHTML = `无版本记录。`; + return; } - + tBody.innerHTML = versions.map(v => ` ${escapeHtml(v.id)}${new Date(v.created_at).toLocaleString()}${escapeHtml(v.description || '-')}${escapeHtml(v.performance_metrics || '-')} `).join(''); - } catch (e) { + } catch (e) { console.error('获取版本失败:', e); - showToast(`获取版本失败: ${e.message}`, 'error'); - tBody.innerHTML = `加载版本失败`; + showToast(`获取版本失败: ${e.message}`, 'error'); + tBody.innerHTML = `加载版本失败`; } } // --- MODEL COMPARISON --- let compareModelCount = 0; const MAX_COMPARE = 3; +// function initModelComparison() { +// const addBtn = DOM.addCompareModelBtn(); +// const startBtn = DOM.startCompareBtn(); +// const dataSel = DOM.compareTestDataSelect(); +// const targetSel = DOM.compareTargetColumnSelect(); +// if (!addBtn || !startBtn) return; +// +// // 加载默认数据集 +// loadDefaultDatasets(); +// +// /** +// * 加载默认数据集 +// */ +// function loadDefaultDatasets() { +// const dataSelect = DOM.compareTestDataSelect(); +// if (dataSelect) { +// dataSelect.innerHTML = ` +// +// +// `; +// } +// } +// +// +// +// // 确保在初始化时已经加载了模型列表 +// const initializeModelSelectors = async () => { +// // 如果模型缓存为空,先获取模型列表 +// if (!allModelsCache || allModelsCache.length === 0) { +// try { +// const response = await fetch(API_ENDPOINTS.MODELS); +// const data = await response.json(); +// allModelsCache = data.models || []; +// } catch (e) { +// console.error("获取模型列表失败:", e); +// showToast("获取模型列表失败,请刷新页面重试", "error"); +// } +// } +// +// // 添加第一个模型选择器 +// addDynamicModelSelector( +// DOM.compareModelsContainer(), +// DOM.compareModelPlaceholder(), +// allModelsCache, +// compareModelCount, +// MAX_COMPARE, +// 'compare-model-select', +// '比较模型', +// (nc) => compareModelCount = nc +// ); +// }; +// +// // 初始化模型选择器 +// initializeModelSelectors(); +// +// // 添加按钮事件监听 +// addBtn.addEventListener('click', () => addDynamicModelSelector( +// DOM.compareModelsContainer(), +// DOM.compareModelPlaceholder(), +// allModelsCache, +// compareModelCount, +// MAX_COMPARE, +// 'compare-model-select', +// '比较模型', +// (newCount) => compareModelCount = newCount +// )); +// +// startBtn.addEventListener('click', async () => { +// const models = Array.from(DOM.compareModelsContainer().querySelectorAll('.compare-model-select')).map(s => s.value).filter(Boolean); +// const testData = dataSel.value; const target = targetSel.value; +// if (models.length < 2) { showToast('请至少选择两个模型。', 'warning'); return; } +// if (!testData) { showToast('请选择测试数据集。', 'warning'); return; } +// if (testData === 'current_uploaded' && (!currentData.path || !currentData.analysisCompleted || !target)) { showToast('当前数据未就绪或未选目标列。', 'warning'); return; } +// DOM.compareResultsContainer().innerHTML = `

`; +// setButtonLoading(startBtn, true); +// try { +// // 调用模型比较API +// const response = await fetch(API_ENDPOINTS.COMPARE_MODELS, { +// method: 'POST', +// headers: { 'Content-Type': 'application/json' }, +// body: JSON.stringify({ +// model_names: models, +// test_data_path: testData, +// target_column: target +// }) +// }); +// const result = await response.json(); +// if (result.error) throw new Error(result.error); +// +// // 显示比较结果 +// let html = `

比较结果

`; +// +// // 确定所有可能的指标 +// const allMetrics = new Set(); +// result.comparison_results.forEach(model => { +// if (model.metrics) { +// Object.keys(model.metrics).forEach(metric => allMetrics.add(metric)); +// } +// }); +// +// // 添加指标列 +// allMetrics.forEach(metric => { +// html += ``; +// }); +// +// html += ``; +// +// // 添加每个模型的结果行 +// result.comparison_results.forEach((model, index) => { +// html += ``; +// +// allMetrics.forEach(metric => { +// const value = model.metrics && model.metrics[metric] !== undefined ? +// model.metrics[metric].toFixed(4) : '-'; +// html += ``; +// }); +// +// html += ``; +// }); +// +// html += `
模型${formatMetricName(metric)}
${model.model_name}${value}
`; +// +// // 添加测试数据信息 +// if (result.test_data) { +// html += `
+//

测试数据信息

+//

路径: ${result.test_data.path}

+//

行数: ${result.test_data.rows}

+//

列数: ${result.test_data.columns}

+//
`; +// } +// +// html += `
`; +// +// DOM.compareResultsContainer().innerHTML = html; +// } catch (e) { +// console.error('比较模型错误:', e); +// showToast(`比较失败: ${e.message}`, 'error'); +// DOM.compareResultsContainer().innerHTML = `

比较失败: ${e.message}

`; +// } +// finally { setButtonLoading(startBtn, false); } +// }); +// +// dataSel.addEventListener('change', async () => { +// // 当选择测试数据集时,加载相应的目标列 +// if (!dataSel.value) return; +// +// targetSel.innerHTML = ''; +// +// if (dataSel.value === 'current_uploaded') { +// // 使用当前上传的数据集的列 +// populateSelectWithOptions(targetSel, currentData.columns, "选择目标列"); +// } else { +// // 从服务器获取数据集的列 +// try { +// const response = await fetch(`/api/ml/analyze?file_path=${encodeURIComponent(dataSel.value)}`); +// const result = await response.json(); +// if (result.error) throw new Error(result.error); +// +// const columns = result.columns || []; +// populateSelectWithOptions(targetSel, columns, "选择目标列"); +// } catch (e) { +// console.error('获取数据集列失败:', e); +// populateSelectWithOptions(targetSel, [], "获取列失败"); +// showToast(`获取数据集列失败: ${e.message}`, 'error'); +// } +// } +// }); +// +/** + * 用选项填充选择器 + */ +function populateSelectWithOptions(selectElement, options, placeholderText = "请选择") { + if (!selectElement) return; + + // 清除现有选项 + selectElement.innerHTML = ''; + + // 添加占位符选项 + const placeholderOption = document.createElement('option'); + placeholderOption.value = ''; + placeholderOption.textContent = placeholderText; + placeholderOption.disabled = true; + placeholderOption.selected = true; + selectElement.appendChild(placeholderOption); + + // 添加所有选项 + options.forEach(option => { + const opt = document.createElement('option'); + opt.value = option; + opt.textContent = option; + selectElement.appendChild(opt); + }); +} + function initModelComparison() { const addBtn = DOM.addCompareModelBtn(); const startBtn = DOM.startCompareBtn(); const dataSel = DOM.compareTestDataSelect(); const targetSel = DOM.compareTargetColumnSelect(); - if (!addBtn || !startBtn) return; - - // 加载默认数据集 - loadDefaultDatasets(); -/** - * 加载默认数据集 - */ -function loadDefaultDatasets() { - const dataSelect = DOM.compareTestDataSelect(); - if (dataSelect) { - dataSelect.innerHTML = ` - - - `; - } -} - - + if (!addBtn || !startBtn || !dataSel || !targetSel) return; // 确保所有元素存在 - // 确保在初始化时已经加载了模型列表 - const initializeModelSelectors = async () => { - // 如果模型缓存为空,先获取模型列表 + // 加载默认数据集选项 (你已有的) + loadDefaultDatasetsForComparison(); // 重命名以区分 + + // 动态添加模型选择器逻辑 (你已有的 addDynamicModelSelector) + // 确保 allModelsCache 是从 FIXED_MODEL_DETAILS (或后端 /api/ml/models 如果你更倾向) 填充 + const initializeCompareModelSelectors = async () => { if (!allModelsCache || allModelsCache.length === 0) { - try { - const response = await fetch(API_ENDPOINTS.MODELS); - const data = await response.json(); - allModelsCache = data.models || []; - } catch (e) { - console.error("获取模型列表失败:", e); - showToast("获取模型列表失败,请刷新页面重试", "error"); - } + // 可以从 FIXED_MODEL_DETAILS 初始化 allModelsCache + allModelsCache = Object.values(FIXED_MODEL_DETAILS).map(m => ({ + internal_name: m.internal_name, + name: m.display_name, // 或者 internal_name,取决于 populateModelSelector + type: getCategoryForModel(m.internal_name) // 确保 getCategoryForModel 可用 + })); + // 或者从后端API获取,如果后端 /api/ml/models 提供了完整的模型列表 + // try { + // const response = await fetch(API_ENDPOINTS.MODELS); + // const data = await response.json(); + // allModelsCache = data.models || []; + // } catch (e) { console.error("获取模型列表失败:", e); } + } + // 至少添加一个比较模型选择器 + if (DOM.compareModelsContainer().querySelectorAll('.compare-model-select').length === 0) { + addDynamicModelSelector( + DOM.compareModelsContainer(), + DOM.compareModelPlaceholder(), + allModelsCache, // 传递 allModelsCache + compareModelCount, // 全局变量 + MAX_COMPARE, + 'compare-model-select', + '比较模型', + (nc) => compareModelCount = nc + ); } - - // 添加第一个模型选择器 - addDynamicModelSelector( - DOM.compareModelsContainer(), - DOM.compareModelPlaceholder(), - allModelsCache, - compareModelCount, - MAX_COMPARE, - 'compare-model-select', - '比较模型', - (nc) => compareModelCount = nc - ); }; - - // 初始化模型选择器 - initializeModelSelectors(); - - // 添加按钮事件监听 + initializeCompareModelSelectors(); + + addBtn.addEventListener('click', () => addDynamicModelSelector( - DOM.compareModelsContainer(), - DOM.compareModelPlaceholder(), - allModelsCache, - compareModelCount, - MAX_COMPARE, - 'compare-model-select', + DOM.compareModelsContainer(), + DOM.compareModelPlaceholder(), + allModelsCache, // 确保 allModelsCache 已填充 + compareModelCount, + MAX_COMPARE, + 'compare-model-select', '比较模型', (newCount) => compareModelCount = newCount )); startBtn.addEventListener('click', async () => { - const models = Array.from(DOM.compareModelsContainer().querySelectorAll('.compare-model-select')).map(s => s.value).filter(Boolean); - const testData = dataSel.value; const target = targetSel.value; - if (models.length < 2) { showToast('请至少选择两个模型。', 'warning'); return; } - if (!testData) { showToast('请选择测试数据集。', 'warning'); return; } - if (testData === 'current_uploaded' && (!currentData.path || !currentData.analysisCompleted || !target)) { showToast('当前数据未就绪或未选目标列。', 'warning'); return; } - DOM.compareResultsContainer().innerHTML = `

`; - setButtonLoading(startBtn, true); + const selectedModelsForCompare = Array.from(DOM.compareModelsContainer().querySelectorAll('.compare-model-select')) + .map(s => s.value) + .filter(Boolean); + const testDataIdentifier = dataSel.value; + const targetColumnForCompare = targetSel.value; + + if (selectedModelsForCompare.length < 2) { + showToast('请至少选择两个模型进行比较。', 'warning'); return; + } + if (!testDataIdentifier) { + showToast('请选择测试数据集。', 'warning'); return; + } + // 如果选择的是 "current_uploaded",则确保已上传数据且选了目标列 + if (testDataIdentifier === 'current_uploaded') { + if (!currentData.path || !currentData.analysisCompleted) { + showToast('当前上传的数据未就绪,请先上传并分析数据。', 'warning'); return; + } + if (!selectedTargetColumn && !targetColumnForCompare) { // 检查是否在主对话区选了目标列 + showToast('请为当前上传的数据选择一个目标列以进行比较。', 'warning'); return; + } + } + if (!targetColumnForCompare && testDataIdentifier !== 'current_uploaded') { + showToast('请选择此数据集的目标列。', 'warning'); return; + } + if (!targetColumnForCompare && testDataIdentifier === 'current_uploaded' && selectedTargetColumn) { + // 如果比较工具没选目标列,但主对话区选了,可以用主对话区的 + // 或者强制用户在比较工具里也选一次 + } + + + DOM.compareResultsContainer().innerHTML = `

`; + setButtonLoading(startBtn, true, '正在模拟比较...'); + try { - // 调用模型比较API - const response = await fetch(API_ENDPOINTS.COMPARE_MODELS, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - model_names: models, - test_data_path: testData, - target_column: target - }) + const body = { + model_names: selectedModelsForCompare, + test_data_identifier: testDataIdentifier === 'current_uploaded' ? `当前上传的数据 (${currentData.fileName || '未命名'})` : testDataIdentifier, + target_column: targetColumnForCompare || selectedTargetColumn // 优先用比较工具选的,否则用主对话区选的 + }; + + const response = await fetch(API_ENDPOINTS.SIMULATE_COMPARE_MODELS, { // 使用模拟API + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body) }); + + if (!response.ok) { + const errData = await response.json().catch(() => ({ error: `服务器错误 (${response.status})` })); + throw new Error(errData.error || `模拟比较失败: ${response.statusText}`); + } + const result = await response.json(); - if (result.error) throw new Error(result.error); - - // 显示比较结果 - let html = `

比较结果

`; - - // 确定所有可能的指标 + if (result.error) { + let displayError = result.error; + if(result.raw_llm_response) displayError += `
LLM原始响应: ${escapeHtml(result.raw_llm_response.substring(0,200))}...`; + DOM.compareResultsContainer().innerHTML = `

${displayError}

`; + throw new Error(result.error); + } + + // 显示模拟的比较结果 (您已有的显示逻辑,可能需要调整以匹配后端模拟API的返回结构) + let html = `

模型模拟比较结果

`; + if (result.test_data_info) { + html += `

基于数据集: ${escapeHtml(result.test_data_info.identifier)} (模拟行数: ${escapeHtml(result.test_data_info.simulated_rows || 'N/A')}, 模拟特征数: ${escapeHtml(result.test_data_info.simulated_features || 'N/A')})

`; + } + html += `
模型
`; + const allMetrics = new Set(); - result.comparison_results.forEach(model => { - if (model.metrics) { - Object.keys(model.metrics).forEach(metric => allMetrics.add(metric)); - } - }); - - // 添加指标列 - allMetrics.forEach(metric => { - html += ``; - }); - + if (result.comparison_results && Array.isArray(result.comparison_results)) { + result.comparison_results.forEach(modelRes => { + if (modelRes.metrics && typeof modelRes.metrics === 'object') { + Object.keys(modelRes.metrics).forEach(metric => allMetrics.add(metric)); + } + }); + } + + allMetrics.forEach(metric => html += ``); html += ``; - - // 添加每个模型的结果行 - result.comparison_results.forEach((model, index) => { - html += ``; - - allMetrics.forEach(metric => { - const value = model.metrics && model.metrics[metric] !== undefined ? - model.metrics[metric].toFixed(4) : '-'; - html += ``; + + if (result.comparison_results && Array.isArray(result.comparison_results)) { + result.comparison_results.forEach(modelRes => { + html += ``; // 使用 getModelDisplayName + allMetrics.forEach(metric => { + const metricValue = modelRes.metrics ? modelRes.metrics[metric] : '-'; + html += ``; + }); + html += ``; }); - - html += ``; - }); - - html += `
模型${formatMetricName(metric)}${escapeHtml(formatMetricName(metric))}
${model.model_name}${value}
${escapeHtml(getModelDisplayName(modelRes.model_name))}${escapeHtml(typeof metricValue === 'number' ? metricValue.toFixed(4) : metricValue)}
`; - - // 添加测试数据信息 - if (result.test_data) { - html += `
-

测试数据信息

-

路径: ${result.test_data.path}

-

行数: ${result.test_data.rows}

-

列数: ${result.test_data.columns}

-
`; + } else { + html += `未能获取有效的比较结果数据。`; + } + html += `
`; + + if (result.summary) { + html += `
模拟总结:

${formatAnswer(result.summary)}

`; } - html += ``; - DOM.compareResultsContainer().innerHTML = html; - } catch (e) { - console.error('比较模型错误:', e); - showToast(`比较失败: ${e.message}`, 'error'); - DOM.compareResultsContainer().innerHTML = `

比较失败: ${e.message}

`; + showToast('模型模拟比较完成!', 'success'); + + } catch (e) { + console.error('模拟模型比较错误:', e); + if (!DOM.compareResultsContainer().innerHTML.includes('text-error')) { //避免重复显示错误 + DOM.compareResultsContainer().innerHTML = `

模拟比较失败: ${escapeHtml(e.message)}

`; + } + showToast(`模拟比较失败: ${e.message}`, 'error'); + } finally { + setButtonLoading(startBtn, false, ' 开始比较'); } - finally { setButtonLoading(startBtn, false); } }); + // dataSel 事件监听器 (您已有的) - 确保它能正确填充 targetSel dataSel.addEventListener('change', async () => { - // 当选择测试数据集时,加载相应的目标列 - if (!dataSel.value) return; - + const selectedDataset = dataSel.value; targetSel.innerHTML = ''; - - if (dataSel.value === 'current_uploaded') { - // 使用当前上传的数据集的列 - populateSelectWithOptions(targetSel, currentData.columns, "选择目标列"); + targetSel.disabled = true; + + if (!selectedDataset) { + populateSelectWithOptions(targetSel, [], "先选择数据集"); + return; + } + + if (selectedDataset === 'current_uploaded') { + if (currentData.columns && currentData.columns.length > 0) { + populateSelectWithOptions(targetSel, currentData.columns, "选择目标列"); + targetSel.disabled = false; + } else { + populateSelectWithOptions(targetSel, [], "当前无可用数据列"); + showToast('当前未上传数据或数据无列信息。', 'warning'); + } } else { - // 从服务器获取数据集的列 - try { - const response = await fetch(`/api/ml/analyze?file_path=${encodeURIComponent(dataSel.value)}`); - const result = await response.json(); - if (result.error) throw new Error(result.error); - - const columns = result.columns || []; - populateSelectWithOptions(targetSel, columns, "选择目标列"); - } catch (e) { - console.error('获取数据集列失败:', e); - populateSelectWithOptions(targetSel, [], "获取列失败"); - showToast(`获取数据集列失败: ${e.message}`, 'error'); + // 对于预设数据集,后端没有提供直接获取列的API,除非 /api/ml/analyze 支持按名称分析 + // 暂时,我们可以假设这些预设数据集的目标列是已知的,或者让用户手动输入 + // 为了演示,可以填充一些示例列或提示用户 + showToast(`选择预设数据集 "${selectedDataset}",请确保目标列适用于此数据集。`, 'info'); + // 假设有一些已知列,或者清空让用户注意 + // populateSelectWithOptions(targetSel, ['示例目标列1', '示例目标列2'], "选择或确认目标列"); + // 更好的做法是,如果这些是真实数据集,后端应该能提供它们的列信息 + // 如果后端 /api/ml/analyze 可以接受数据集标识符并返回列,则可以在这里调用 + // 否则,目标列选择器对于非 current_uploaded 数据集将作用有限 + // 为简单起见,我们先假设如果不是 current_uploaded,用户知道目标列是什么 + // 或者,可以从currentData.columns填充,让用户选择一个,并提示这可能不适用于所选数据集。 + if (currentData.columns && currentData.columns.length > 0) { + populateSelectWithOptions(targetSel, currentData.columns, "选择目标列 (可能不适用)"); + targetSel.disabled = false; + } else { + populateSelectWithOptions(targetSel, [], "无列信息参考"); } } }); - -/** - * 用选项填充选择器 - */ -function populateSelectWithOptions(selectElement, options, placeholderText = "请选择") { - if (!selectElement) return; - - // 清除现有选项 - selectElement.innerHTML = ''; - - // 添加占位符选项 - const placeholderOption = document.createElement('option'); - placeholderOption.value = ''; - placeholderOption.textContent = placeholderText; - placeholderOption.disabled = true; - placeholderOption.selected = true; - selectElement.appendChild(placeholderOption); - - // 添加所有选项 - options.forEach(option => { - const opt = document.createElement('option'); - opt.value = option; - opt.textContent = option; - selectElement.appendChild(opt); - }); -} } +// 辅助函数:为比较工具的下拉列表加载数据集名称 +function loadDefaultDatasetsForComparison() { + const dataSelect = DOM.compareTestDataSelect(); + if (dataSelect) { + // 保留用户已有的选项,这里只确保 "当前上传数据" 存在 + let currentUploadedOptionExists = false; + for (let i = 0; i < dataSelect.options.length; i++) { + if (dataSelect.options[i].value === 'current_uploaded') { + currentUploadedOptionExists = true; + break; + } + } + if (!currentUploadedOptionExists) { + const opt = document.createElement('option'); + opt.value = 'current_uploaded'; + opt.textContent = '当前已上传的数据集'; + // 插入到 "选择数据集" 之后 + if (dataSelect.options.length > 1) { + dataSelect.insertBefore(opt, dataSelect.options[1]); + } else { + dataSelect.appendChild(opt); + } + } + } +} /** * 格式化指标名称 */ @@ -1950,145 +2460,304 @@ function formatMetricName(metric) { } // --- ENSEMBLE BUILDING --- -let ensembleModelCount = 0; const MIN_ENSEMBLE = 2; +// app.js + +// ... (其他代码) ... + function initEnsembleBuilding() { const addBtn = DOM.addEnsembleModelBtn(); const buildBtn = DOM.buildEnsembleBtn(); - if (!addBtn || !buildBtn) return; - - // 确保模型列表已加载 - populateAdvancedToolSelectors(); - - // 加载所有可用模型并初始化选择器 - setTimeout(async () => { - if (!allModelsCache || allModelsCache.length === 0) { - try { - const response = await fetch(API_ENDPOINTS.MODELS); - if (!response.ok) throw new Error(`请求失败 (${response.status})`); - const data = await response.json(); - allModelsCache = data.models || []; - if (allModelsCache.length === 0) { - console.warn('没有可用的模型'); - return; - } - } catch (error) { - console.error('获取模型列表失败:', error); - return; - } - } - - // 初始化选择器值 - const selectors = document.querySelectorAll('.ensemble-model-select'); - selectors.forEach((selector, index) => { - if (index < allModelsCache.length) { - const model = allModelsCache[index]; - selector.value = model.internal_name || model.name; - selector.title = model.description || model.display_name || model.name; - } - }); - }, 500); + const ensembleTypeSel = DOM.ensembleTypeSelect(); + const ensembleNameInput = DOM.ensembleName(); - addBtn.addEventListener('click', () => { + if (!addBtn || !buildBtn || !ensembleTypeSel || !ensembleNameInput) return; + + // 动态添加模型选择器逻辑 (您已有的) + // 确保 allModelsCache 填充方式同上 + const initializeEnsembleModelSelectors = async () => { if (!allModelsCache || allModelsCache.length === 0) { - loadAvailableModels().then(() => { - addDynamicModelSelector( - DOM.ensembleModelsContainer(), - DOM.ensembleModelPlaceholder(), - allModelsCache.filter(m => m.type !== 'ensemble'), - ensembleModelCount, - 10, - 'ensemble-model-select', - '基础模型', - (newCount) => ensembleModelCount = newCount - ); - }); - } else { + allModelsCache = Object.values(FIXED_MODEL_DETAILS).map(m => ({ + internal_name: m.internal_name, + name: m.display_name, + type: getCategoryForModel(m.internal_name) + })); + } + // 初始化时添加 MIN_ENSEMBLE (例如2个) 模型选择器 + const existingSelectors = DOM.ensembleModelsContainer().querySelectorAll('.ensemble-model-select').length; + for (let i = existingSelectors; i < MIN_ENSEMBLE; i++) { addDynamicModelSelector( - DOM.ensembleModelsContainer(), - DOM.ensembleModelPlaceholder(), - allModelsCache.filter(m => m.type !== 'ensemble'), - ensembleModelCount, - 10, - 'ensemble-model-select', - '基础模型', - (newCount) => ensembleModelCount = newCount + DOM.ensembleModelsContainer(), + DOM.ensembleModelPlaceholder(), + allModelsCache.filter(m => m.type !== 'ensemble'), // 基础模型不能是集成模型自身 + ensembleModelCount, // 全局变量 + 10, // Max ensemble components + 'ensemble-model-select', + '基础模型', + (nc) => ensembleModelCount = nc ); } + }; + initializeEnsembleModelSelectors(); + + + addBtn.addEventListener('click', () => { + addDynamicModelSelector( + DOM.ensembleModelsContainer(), + DOM.ensembleModelPlaceholder(), + allModelsCache.filter(m => m.type !== 'ensemble'), // 基础模型不应是集成模型 + ensembleModelCount, + 10, // Max ensemble components + 'ensemble-model-select', + '基础模型', + (newCount) => ensembleModelCount = newCount + ); }); - // 初始化添加两个基础模型选择器 - for(let i=0; i m.type !== 'ensemble'), ensembleModelCount, 10, 'ensemble-model-select', '基础模型', (nc) => ensembleModelCount = nc); buildBtn.addEventListener('click', async () => { - const models = Array.from(DOM.ensembleModelsContainer().querySelectorAll('.ensemble-model-select')).map(s => s.value).filter(Boolean); - const type = DOM.ensembleTypeSelect().value; - const name = DOM.ensembleName().value.trim(); - - // 验证输入 - if (models.length < MIN_ENSEMBLE) { - showToast(`请至少选择 ${MIN_ENSEMBLE} 个基础模型。`, 'warning'); - return; + const selectedBaseModels = Array.from(DOM.ensembleModelsContainer().querySelectorAll('.ensemble-model-select')) + .map(s => s.value) + .filter(Boolean); + const ensembleType = ensembleTypeSel.value; + const ensembleName = ensembleNameInput.value.trim(); + + if (selectedBaseModels.length < MIN_ENSEMBLE) { + showToast(`请至少选择 ${MIN_ENSEMBLE} 个基础模型。`, 'warning'); return; } - if (!type) { - showToast('请选择集成类型。', 'warning'); - return; + if (!ensembleType) { + showToast('请选择集成类型。', 'warning'); return; } - if (!name) { - showToast('请输入集成模型名称。', 'warning'); - return; + if (!ensembleName) { + showToast('请输入集成模型名称。', 'warning'); return; } - if (!/^[a-zA-Z0-9_.-]+$/.test(name)) { - showToast('模型名称只能包含字母、数字、下划线、点和连字符。', 'warning'); - return; + if (!/^[a-zA-Z0-9_.-]+$/.test(ensembleName)) { + showToast('模型名称只能包含字母、数字、下划线、点和连字符。', 'warning'); return; } - - // 显示加载状态 - DOM.ensembleResultContainer().innerHTML = `

`; - setButtonLoading(buildBtn, true, '构建中...'); - + + DOM.ensembleResultContainer().innerHTML = `

`; + setButtonLoading(buildBtn, true, '正在模拟构建...'); + try { - // 调用构建集成模型API - const response = await fetch(API_ENDPOINTS.BUILD_ENSEMBLE, { + const body = { + base_models: selectedBaseModels, + ensemble_type: ensembleType, + ensemble_name: ensembleName + }; + + const response = await fetch(API_ENDPOINTS.SIMULATE_BUILD_ENSEMBLE, { // 使用模拟API method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - base_models: models, - ensemble_type: type, - save_name: name - }) + body: JSON.stringify(body) }); - + if (!response.ok) { - const errorData = await response.json(); - throw new Error(errorData.error || `请求失败 (${response.status})`); + const errData = await response.json().catch(() => ({ error: `服务器错误 (${response.status})` })); + throw new Error(errData.error || `模拟构建失败: ${response.statusText}`); } - const result = await response.json(); - if (!result.success) { - throw new Error(result.error || '未知错误'); + + if (result.error || !result.success) { + let displayError = result.error || "模拟构建时返回了未知错误。"; + if(result.raw_llm_response) displayError += `
LLM原始响应: ${escapeHtml(result.raw_llm_response.substring(0,200))}...`; + DOM.ensembleResultContainer().innerHTML = `

${displayError}

`; + throw new Error(displayError); } - - // 显示成功结果 - DOM.ensembleResultContainer().innerHTML = ` -
-

集成模型 "${escapeHtml(name)}" 构建成功!

-
-

模型类型: ${result.ensemble_type} 集成

-

基础模型: ${result.base_models.map(m => getModelDisplayName(m)).join(', ')}

-

创建时间: ${new Date(result.model_info.created_at).toLocaleString()}

-
-
`; - - showToast('集成模型构建成功!', 'success'); - await loadAvailableModels(); // 刷新所有模型列表 - } catch (e) { - console.error('构建集成模型错误:', e); - showToast(`构建失败: ${e.message}`, 'error'); - DOM.ensembleResultContainer().innerHTML = `

构建失败: ${escapeHtml(e.message)}

`; - } finally { - setButtonLoading(buildBtn, false, ' 构建集成模型'); + + // 显示模拟的集成模型构建结果 + let html = `

集成模型模拟构建结果

`; + html += `

${escapeHtml(result.message || `集成模型 "${escapeHtml(result.ensemble_name)}" 已成功模拟创建!`)}

`; + html += `
`; + if(result.ensemble_name) html += `

集成模型名称: ${escapeHtml(result.ensemble_name)}

`; + if(result.ensemble_type) html += `

集成类型: ${escapeHtml(result.ensemble_type)}

`; + if(result.base_models_used && result.base_models_used.length) { + html += `

基础模型: ${result.base_models_used.map(m => escapeHtml(getModelDisplayName(m))).join(', ')}

`; + } + if(result.description) html += `
模拟工作原理:
${formatAnswer(result.description)}
`; + if(result.potential_advantages) html += `
模拟潜在优势:
${formatAnswer(result.potential_advantages)}
`; + if(result.suitable_scenarios) html += `
模拟适用场景:
${formatAnswer(result.suitable_scenarios)}
`; + if(result.model_info) { + html += `
模拟元数据:`; + html += `
    `; + if(result.model_info.simulated_created_at) html += `
  • 创建时间: ${escapeHtml(new Date(result.model_info.simulated_created_at).toLocaleString())}
  • `; + if(result.model_info.simulated_combination_method) html += `
  • 组合方法: ${escapeHtml(result.model_info.simulated_combination_method)}
  • `; + html += `
`; + } + html += `
`; + DOM.ensembleResultContainer().innerHTML = html; + showToast('集成模型模拟构建完成!', 'success'); + + // 可以在这里考虑是否将这个模拟的集成模型添加到某个前端列表,或者提示用户它已“创建” + // 例如,可以更新 allModelsCache 或 FIXED_MODEL_DETAILS (如果希望它出现在模型选择中) + // 但要注意,这只是一个模拟,它没有真实的模型文件。 + + } catch (e) { + console.error('模拟集成模型构建错误:', e); + if (!DOM.ensembleResultContainer().innerHTML.includes('text-error')) { + DOM.ensembleResultContainer().innerHTML = `

模拟构建失败: ${escapeHtml(e.message)}

`; + } + showToast(`模拟构建失败: ${e.message}`, 'error'); + } finally { + setButtonLoading(buildBtn, false, ' 构建集成模型'); } }); } +async function populateAdvancedToolSelectors() { // modelsList 参数现在可选 + if (!allModelsCache || allModelsCache.length === 0) { + // 使用 FIXED_MODEL_DETAILS 初始化 allModelsCache + allModelsCache = Object.values(FIXED_MODEL_DETAILS).map(m => ({ + internal_name: m.internal_name, + name: m.display_name, // 或者 m.internal_name 用于 populateModelSelector + type: getCategoryForModel(m.internal_name) + })); + // 如果您仍希望从后端获取模型列表用于高级工具,可以保留 fetch(API_ENDPOINTS.MODELS) 的逻辑 + } + // 确保 DOM 元素存在再填充 + if (DOM.versionModelSelector()) { + populateModelSelector(DOM.versionModelSelector(), allModelsCache.map(m=> ({name: m.name, internal_name: m.internal_name})), "选择模型查看版本"); + } + if (DOM.deployModelSelect()) { + populateModelSelector(DOM.deployModelSelect(), allModelsCache.map(m=> ({name: m.name, internal_name: m.internal_name})), "选择要部署的模型"); + } +} +let ensembleModelCount = 0; const MIN_ENSEMBLE = 2; + +// function initEnsembleBuilding() { +// const addBtn = DOM.addEnsembleModelBtn(); +// const buildBtn = DOM.buildEnsembleBtn(); +// if (!addBtn || !buildBtn) return; +// +// // 确保模型列表已加载 +// populateAdvancedToolSelectors(); +// +// // 加载所有可用模型并初始化选择器 +// setTimeout(async () => { +// if (!allModelsCache || allModelsCache.length === 0) { +// try { +// const response = await fetch(API_ENDPOINTS.MODELS); +// if (!response.ok) throw new Error(`请求失败 (${response.status})`); +// const data = await response.json(); +// allModelsCache = data.models || []; +// if (allModelsCache.length === 0) { +// console.warn('没有可用的模型'); +// return; +// } +// } catch (error) { +// console.error('获取模型列表失败:', error); +// return; +// } +// } +// +// // 初始化选择器值 +// const selectors = document.querySelectorAll('.ensemble-model-select'); +// selectors.forEach((selector, index) => { +// if (index < allModelsCache.length) { +// const model = allModelsCache[index]; +// selector.value = model.internal_name || model.name; +// selector.title = model.description || model.display_name || model.name; +// } +// }); +// }, 500); +// +// addBtn.addEventListener('click', () => { +// if (!allModelsCache || allModelsCache.length === 0) { +// loadAvailableModels().then(() => { +// addDynamicModelSelector( +// DOM.ensembleModelsContainer(), +// DOM.ensembleModelPlaceholder(), +// allModelsCache.filter(m => m.type !== 'ensemble'), +// ensembleModelCount, +// 10, +// 'ensemble-model-select', +// '基础模型', +// (newCount) => ensembleModelCount = newCount +// ); +// }); +// } else { +// addDynamicModelSelector( +// DOM.ensembleModelsContainer(), +// DOM.ensembleModelPlaceholder(), +// allModelsCache.filter(m => m.type !== 'ensemble'), +// ensembleModelCount, +// 10, +// 'ensemble-model-select', +// '基础模型', +// (newCount) => ensembleModelCount = newCount +// ); +// } +// }); +// // 初始化添加两个基础模型选择器 +// for(let i=0; i m.type !== 'ensemble'), ensembleModelCount, 10, 'ensemble-model-select', '基础模型', (nc) => ensembleModelCount = nc); +// +// buildBtn.addEventListener('click', async () => { +// const models = Array.from(DOM.ensembleModelsContainer().querySelectorAll('.ensemble-model-select')).map(s => s.value).filter(Boolean); +// const type = DOM.ensembleTypeSelect().value; +// const name = DOM.ensembleName().value.trim(); +// +// // 验证输入 +// if (models.length < MIN_ENSEMBLE) { +// showToast(`请至少选择 ${MIN_ENSEMBLE} 个基础模型。`, 'warning'); +// return; +// } +// if (!type) { +// showToast('请选择集成类型。', 'warning'); +// return; +// } +// if (!name) { +// showToast('请输入集成模型名称。', 'warning'); +// return; +// } +// if (!/^[a-zA-Z0-9_.-]+$/.test(name)) { +// showToast('模型名称只能包含字母、数字、下划线、点和连字符。', 'warning'); +// return; +// } +// +// // 显示加载状态 +// DOM.ensembleResultContainer().innerHTML = `

`; +// setButtonLoading(buildBtn, true, '构建中...'); +// +// try { +// // 调用构建集成模型API +// const response = await fetch(API_ENDPOINTS.BUILD_ENSEMBLE, { +// method: 'POST', +// headers: { 'Content-Type': 'application/json' }, +// body: JSON.stringify({ +// base_models: models, +// ensemble_type: type, +// save_name: name +// }) +// }); +// +// if (!response.ok) { +// const errorData = await response.json(); +// throw new Error(errorData.error || `请求失败 (${response.status})`); +// } +// +// const result = await response.json(); +// if (!result.success) { +// throw new Error(result.error || '未知错误'); +// } +// +// // 显示成功结果 +// DOM.ensembleResultContainer().innerHTML = ` +//
+//

集成模型 "${escapeHtml(name)}" 构建成功!

+//
+//

模型类型: ${result.ensemble_type} 集成

+//

基础模型: ${result.base_models.map(m => getModelDisplayName(m)).join(', ')}

+//

创建时间: ${new Date(result.model_info.created_at).toLocaleString()}

+//
+//
`; +// +// showToast('集成模型构建成功!', 'success'); +// await loadAvailableModels(); // 刷新所有模型列表 +// } catch (e) { +// console.error('构建集成模型错误:', e); +// showToast(`构建失败: ${e.message}`, 'error'); +// DOM.ensembleResultContainer().innerHTML = `

构建失败: ${escapeHtml(e.message)}

`; +// } finally { +// setButtonLoading(buildBtn, false, ' 构建集成模型'); +// } +// }); +// } /** Generic function to add a model selector dynamically */ function addDynamicModelSelector(container, placeholderEl, modelsList, currentCount, maxCount, selectClass, labelPrefix, updateCountCallback) {