From 7db4c845d80ddcb8e5e049de7da403ff1efe743a Mon Sep 17 00:00:00 2001 From: littleniannian Date: Fri, 28 Nov 2025 14:06:46 +0800 Subject: [PATCH 1/6] feat: sql workbench audit --- internal/apiserver/service/router.go | 1 + .../service/sql_workbench_service.go | 656 ++++++++++++++++++ 2 files changed, 657 insertions(+) diff --git a/internal/apiserver/service/router.go b/internal/apiserver/service/router.go index ac69410e..991f2422 100644 --- a/internal/apiserver/service/router.go +++ b/internal/apiserver/service/router.go @@ -278,6 +278,7 @@ func (s *APIServer) initRouter() error { } sqlWorkbenchV1.Use(s.SqlWorkbenchController.SqlWorkbenchService.Login()) + sqlWorkbenchV1.Use(s.SqlWorkbenchController.SqlWorkbenchService.Intercept()) sqlWorkbenchV1.Use(middleware.ProxyWithConfig(middleware.ProxyConfig{ Skipper: middleware.DefaultSkipper, Balancer: middleware.NewRandomBalancer(targets), diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index bef57f71..4f4b81ee 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -1,8 +1,13 @@ package sql_workbench import ( + "bytes" + "compress/gzip" "context" + "encoding/base64" + "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" @@ -13,9 +18,11 @@ import ( "github.com/actiontech/dms/internal/dms/biz" pkgConst "github.com/actiontech/dms/internal/dms/pkg/constant" "github.com/actiontech/dms/internal/dms/storage" + dbmodel "github.com/actiontech/dms/internal/dms/storage/model" "github.com/actiontech/dms/internal/sql_workbench/client" config "github.com/actiontech/dms/internal/sql_workbench/config" "github.com/actiontech/dms/pkg/dms-common/api/jwt" + _const "github.com/actiontech/dms/pkg/dms-common/pkg/const" pkgHttp "github.com/actiontech/dms/pkg/dms-common/pkg/http" utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log" "github.com/labstack/echo/v4" @@ -1003,3 +1010,652 @@ func makeHttpRequest(ctx context.Context, url string, headers map[string]string, } return nil } + +// intercept 拦截工作台odc请求进行加工 +func (sqlWorkbenchService *SqlWorkbenchService) Intercept() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // 只拦截包含 /streamExecute 的请求 + if !strings.Contains(c.Request().URL.Path, "/streamExecute") { + return next(c) + } + + // 读取请求体 + bodyBytes, err := io.ReadAll(c.Request().Body) + if err != nil { + sqlWorkbenchService.log.Errorf("Failed to read request body: %v", err) + return next(c) + } + // 恢复请求体,供后续处理使用 + c.Request().Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + // 解析请求体获取 SQL 和 datasource ID + sql, datasourceID, err := sqlWorkbenchService.parseStreamExecuteRequest(bodyBytes) + if err != nil { + sqlWorkbenchService.log.Debugf("Failed to parse streamExecute request, skipping audit: %v", err) + return next(c) + } + + if sql == "" || datasourceID == "" { + sqlWorkbenchService.log.Debugf("SQL or datasource ID is empty, skipping audit") + return next(c) + } + + // 获取当前用户 ID + dmsUserId, err := sqlWorkbenchService.getDMSUserIdFromRequest(c) + if err != nil { + sqlWorkbenchService.log.Errorf("Failed to get DMS user ID: %v", err) + return next(c) + } + + // 从缓存表获取 dms_db_service_id + dmsDBServiceID, err := sqlWorkbenchService.getDMSDBServiceIDFromCache(c.Request().Context(), datasourceID, dmsUserId) + if err != nil { + sqlWorkbenchService.log.Errorf("Failed to get dms_db_service_id from cache: %v", err) + return next(c) + } + + if dmsDBServiceID == "" { + sqlWorkbenchService.log.Debugf("dms_db_service_id not found in cache for datasource: %s", datasourceID) + return next(c) + } + + // 获取 DBService 信息 + dbService, err := sqlWorkbenchService.dbServiceUsecase.GetDBService(c.Request().Context(), dmsDBServiceID) + if err != nil { + sqlWorkbenchService.log.Errorf("Failed to get DBService: %v", err) + return next(c) + } + + // 检查是否启用 SQL 审核 + if !sqlWorkbenchService.isEnableSQLAudit(dbService) { + sqlWorkbenchService.log.Debugf("SQL audit is not enabled for DBService: %s", dmsDBServiceID) + return next(c) + } + + // 调用 SQLE 审核接口 + auditResult, err := sqlWorkbenchService.callSQLEAudit(c.Request().Context(), sql, dbService) + if err != nil { + sqlWorkbenchService.log.Errorf("Failed to call SQLE audit: %v", err) + return next(c) + } + + // 拦截响应并添加审核结果 + return sqlWorkbenchService.interceptAndAddAuditResult(c, next, auditResult, dbService) + } + } +} + +// parseStreamExecuteRequest 解析 streamExecute 请求体,提取 SQL 和 datasource ID +func (sqlWorkbenchService *SqlWorkbenchService) parseStreamExecuteRequest(bodyBytes []byte) (sql string, datasourceID string, err error) { + var requestBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { + return "", "", fmt.Errorf("failed to unmarshal request body: %v", err) + } + + // 从 sql 字段获取 SQL + if sqlVal, ok := requestBody["sql"]; ok { + if sqlStr, ok := sqlVal.(string); ok { + sql = sqlStr + } + } + + // 从 sid 字段解析 datasource ID + // sid 格式: sid:{base64编码的JSON}:d:dms + // base64 JSON 包含: {"dbId":623,"dsId":28,"from":"192.168.21.47","logicalSession":false,"realId":"ee9b8ab276"} + if sidVal, ok := requestBody["sid"]; ok { + if sidStr, ok := sidVal.(string); ok { + dsId, parseErr := sqlWorkbenchService.parseSidToDatasourceID(sidStr) + if parseErr != nil { + sqlWorkbenchService.log.Debugf("Failed to parse sid to datasource ID: %v", parseErr) + } else { + datasourceID = dsId + } + } + } + + return sql, datasourceID, nil +} + +// parseSidToDatasourceID 从 sid 字符串中解析出 datasource ID +// sid 格式: sid:{base64编码的JSON}:d:dms +func (sqlWorkbenchService *SqlWorkbenchService) parseSidToDatasourceID(sid string) (string, error) { + // 检查 sid 格式: sid:...:d:dms + if !strings.HasPrefix(sid, "sid:") { + return "", fmt.Errorf("invalid sid format, missing 'sid:' prefix") + } + + // 移除 "sid:" 前缀 + sid = strings.TrimPrefix(sid, "sid:") + + // 查找最后一个 ":d:dms" 后缀并移除 + if idx := strings.LastIndex(sid, ":d:dms"); idx != -1 { + sid = sid[:idx] + } + + // 解码 base64 + decodedBytes, err := base64.StdEncoding.DecodeString(sid) + if err != nil { + return "", fmt.Errorf("failed to decode base64 sid: %v", err) + } + + // 解析 JSON + var sidData struct { + DbId int `json:"dbId"` + DsId int `json:"dsId"` + From string `json:"from"` + LogicalSession bool `json:"logicalSession"` + RealId string `json:"realId"` + } + + if err := json.Unmarshal(decodedBytes, &sidData); err != nil { + return "", fmt.Errorf("failed to unmarshal sid JSON: %v", err) + } + + // 返回 dsId 作为字符串 + return fmt.Sprintf("%d", sidData.DsId), nil +} + +// getDMSUserIdFromRequest 从请求中获取 DMS 用户 ID +func (sqlWorkbenchService *SqlWorkbenchService) getDMSUserIdFromRequest(c echo.Context) (string, error) { + var dmsToken string + for _, cookie := range c.Cookies() { + if cookie.Name == pkgConst.DMSToken { + dmsToken = cookie.Value + break + } + } + + if dmsToken == "" { + return "", fmt.Errorf("dms token is empty") + } + + dmsUserId, err := jwt.ParseUidFromJwtTokenStr(dmsToken) + if err != nil { + return "", fmt.Errorf("failed to parse dms user id from token: %v", err) + } + + return dmsUserId, nil +} + +// getDMSDBServiceIDFromCache 从 sql_workbench_datasource_caches 表获取 dms_db_service_id +func (sqlWorkbenchService *SqlWorkbenchService) getDMSDBServiceIDFromCache(ctx context.Context, datasourceID, dmsUserID string) (string, error) { + // 尝试将 datasourceID 转换为 int64(ODC 的 datasource ID 通常是数字) + var sqlWorkbenchDatasourceID int64 + if _, err := fmt.Sscanf(datasourceID, "%d", &sqlWorkbenchDatasourceID); err != nil { + // 如果转换失败,尝试直接使用字符串作为 datasource ID + sqlWorkbenchService.log.Debugf("Failed to convert datasourceID to int64, trying to find by string: %s", datasourceID) + } + + // 从缓存表中查找,需要根据 sql_workbench_datasource_id 和 dms_user_id 查找 + // 由于缓存表可能没有直接存储 sql_workbench_datasource_id,我们需要通过其他方式查找 + // 这里先尝试通过用户 ID 获取所有数据源,然后匹配 + datasources, err := sqlWorkbenchService.sqlWorkbenchDatasourceRepo.GetSqlWorkbenchDatasourcesByUserID(ctx, dmsUserID) + if err != nil { + return "", fmt.Errorf("failed to get datasources by user id: %v", err) + } + + // 如果 datasourceID 是数字,尝试匹配 SqlWorkbenchDatasourceID + if sqlWorkbenchDatasourceID > 0 { + for _, ds := range datasources { + if ds.SqlWorkbenchDatasourceID == sqlWorkbenchDatasourceID { + return ds.DMSDBServiceID, nil + } + } + } + + // 如果找不到,返回第一个匹配的数据源(临时方案,后续可能需要更精确的匹配逻辑) + if len(datasources) > 0 { + // 这里可以根据实际业务逻辑选择合适的数据源 + // 暂时返回第一个数据源的 dms_db_service_id + return datasources[0].DMSDBServiceID, nil + } + + return "", fmt.Errorf("no datasource found for datasourceID: %s, userID: %s", datasourceID, dmsUserID) +} + +// isEnableSQLAudit 检查是否启用 SQL 审核 +func (sqlWorkbenchService *SqlWorkbenchService) isEnableSQLAudit(dbService *biz.DBService) bool { + if dbService.SQLEConfig == nil || dbService.SQLEConfig.SQLQueryConfig == nil { + return false + } + return dbService.SQLEConfig.AuditEnabled && dbService.SQLEConfig.SQLQueryConfig.AuditEnabled +} + +// callSQLEAudit 调用 SQLE 直接审核接口 +func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService) (*auditSQLReply, error) { + // 获取 SQLE 服务地址 + target, err := sqlWorkbenchService.proxyTargetRepo.GetProxyTargetByName(ctx, _const.SqleComponentName) + if err != nil { + return nil, fmt.Errorf("failed to get SQLE proxy target: %v", err) + } + + sqleAddr := fmt.Sprintf("%s/v2/sql_audit", target.URL.String()) + + // 构建审核请求 + auditReq := auditSQLReq{ + InstanceType: dbService.DBType, + SQLContent: sql, + SQLType: "sql", + ProjectId: dbService.ProjectUID, + RuleTemplateName: dbService.SQLEConfig.SQLQueryConfig.RuleTemplateName, + } + + // 设置请求头 + header := map[string]string{ + "Authorization": pkgHttp.DefaultDMSToken, + } + + // 调用 SQLE 审核接口 + reply := &auditSQLReply{} + if err := pkgHttp.POST(ctx, sqleAddr, header, auditReq, reply); err != nil { + return nil, fmt.Errorf("failed to call SQLE audit API: %v", err) + } + + if reply.Code != 0 { + return nil, fmt.Errorf("SQLE audit API returned error code %v: %v", reply.Code, reply.Message) + } + + return reply, nil +} + +// interceptAndAddAuditResult 拦截响应并添加审核结果 +func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c echo.Context, next echo.HandlerFunc, auditResult *auditSQLReply, dbService *biz.DBService) error { + // 创建响应拦截器 + srw := newStreamExecuteResponseWriter(c) + cloudbeaverResBuf := srw.Buffer + c.Response().Writer = srw + + defer func() { + // 在 defer 中处理响应 + if srw.status != 0 { + srw.original.WriteHeader(srw.status) + } + + // 读取响应内容 + responseBytes := cloudbeaverResBuf.Bytes() + if len(responseBytes) == 0 { + return + } + + // 如果是 gzip 压缩响应,先解压 + responseBytes, wasGzip, err := sqlWorkbenchService.decodeResponseBody(cloudbeaverResBuf.Bytes(), srw.Header().Get("Content-Encoding")) + if err != nil { + sqlWorkbenchService.log.Debugf("Failed to decode response body, returning original response: %v", err) + srw.original.Write(cloudbeaverResBuf.Bytes()) + return + } + + // 如果解压过,先移除 Content-Encoding,后续根据需要重新设置 + if wasGzip { + srw.original.Header().Del("Content-Encoding") + } + + // 解析响应 JSON + var responseBody StreamExecuteResponse + if err := json.Unmarshal(responseBytes, &responseBody); err != nil { + sqlWorkbenchService.log.Debugf("Failed to unmarshal response, returning original response: %v", err) + // 如果解析失败,直接返回原始响应 + srw.original.Write(cloudbeaverResBuf.Bytes()) + return + } + + // 添加审核结果到响应的 data 字段中 + if auditResult != nil && auditResult.Data != nil && auditResult.Data.PassRate != 1 { + // 将 SQLE 审核结果整合到每个 SQL 条目中 + sqlWorkbenchService.mergeSQLEAuditResults(&responseBody.Data, auditResult.Data, dbService) + + // 在 data 级别添加汇总的审核结果信息 + responseBody.Data.SQLEAuditResult = &SQLEAuditResultSummary{ + AuditLevel: auditResult.Data.AuditLevel, + Score: auditResult.Data.Score, + PassRate: auditResult.Data.PassRate, + } + } + + // 重新序列化响应 + modifiedResponse, err := json.Marshal(responseBody) + if err != nil { + sqlWorkbenchService.log.Errorf("Failed to marshal modified response: %v", err) + srw.original.Write(cloudbeaverResBuf.Bytes()) + return + } + + finalResponse := modifiedResponse + if wasGzip { + encoded, err := sqlWorkbenchService.encodeResponseBody(modifiedResponse) + if err != nil { + sqlWorkbenchService.log.Errorf("Failed to re-encode gzip response: %v", err) + srw.original.Write(cloudbeaverResBuf.Bytes()) + return + } + finalResponse = encoded + srw.original.Header().Set("Content-Encoding", "gzip") + } + + // 更新 Content-Length + header := srw.original.Header() + header.Set("Content-Length", fmt.Sprintf("%d", len(finalResponse))) + + // 如果拦截过程中未显式写入状态码,默认使用 200 + statusCode := srw.status + if statusCode == 0 { + statusCode = http.StatusOK + } + srw.original.WriteHeader(statusCode) + + // 写入修改后的响应 + if _, err := srw.original.Write(finalResponse); err != nil { + sqlWorkbenchService.log.Errorf("Failed to write modified response: %v", err) + } + }() + + // 执行下一个处理器 + return next(c) +} + +// decodeResponseBody 根据 Content-Encoding 判断是否需要解压 +func (sqlWorkbenchService *SqlWorkbenchService) decodeResponseBody(body []byte, contentEncoding string) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + encoding := strings.ToLower(contentEncoding) + isGzip := strings.Contains(encoding, "gzip") || (len(body) >= 2 && body[0] == 0x1f && body[1] == 0x8b) + if !isGzip { + return body, false, nil + } + + gzipReader, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + return nil, true, fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzipReader.Close() + + decompressed, err := io.ReadAll(gzipReader) + if err != nil { + return nil, true, fmt.Errorf("failed to decompress gzip body: %w", err) + } + + return decompressed, true, nil +} + +// encodeResponseBody 将响应体按照 gzip 编码 +func (sqlWorkbenchService *SqlWorkbenchService) encodeResponseBody(body []byte) ([]byte, error) { + var buf bytes.Buffer + gzipWriter := gzip.NewWriter(&buf) + if _, err := gzipWriter.Write(body); err != nil { + gzipWriter.Close() + return nil, fmt.Errorf("failed to gzip response body: %w", err) + } + if err := gzipWriter.Close(); err != nil { + return nil, fmt.Errorf("failed to finalize gzip response body: %w", err) + } + return buf.Bytes(), nil +} + +// mergeSQLEAuditResults 将 SQLE 审核结果整合到 sqls 数组中 +func (sqlWorkbenchService *SqlWorkbenchService) mergeSQLEAuditResults(data *StreamExecuteData, auditData *auditResDataV2, dbService *biz.DBService) { + // 创建 SQL 到审核结果的映射 + sqlAuditMap := make(map[string]*auditSQLResV2) + for i := range auditData.SQLResults { + sqlResult := &auditData.SQLResults[i] + // 使用 exec_sql 作为 key,去除首尾空格和分号 + normalizedSQL := strings.TrimSpace(strings.TrimSuffix(sqlResult.ExecSQL, ";")) + sqlAuditMap[normalizedSQL] = sqlResult + } + + // 获取数据源设置的审核放行等级 + allowQueryWhenLessThanAuditLevel := dbService.GetAllowQueryWhenLessThanAuditLevel() + + // 判断是否需要审批 + needApproval := sqlWorkbenchService.shouldRequireApproval(auditData.SQLResults, allowQueryWhenLessThanAuditLevel) + + // 设置 ApprovalRequired 字段 + data.ApprovalRequired = needApproval + + // 遍历 sqls 数组,为每个 SQL 条目添加 SQLE 审核结果 + for i := range data.SQLs { + sqlItem := &data.SQLs[i] + + // 尝试从 originalSql 或 executedSql 匹配审核结果 + var matchedAuditResult *auditSQLResV2 + normalizedSQL := strings.TrimSpace(strings.TrimSuffix(sqlItem.SQLTuple.OriginalSQL, ";")) + if auditResult, found := sqlAuditMap[normalizedSQL]; found { + matchedAuditResult = auditResult + } + + if matchedAuditResult == nil { + normalizedSQL = strings.TrimSpace(strings.TrimSuffix(sqlItem.SQLTuple.ExecutedSQL, ";")) + if auditResult, found := sqlAuditMap[normalizedSQL]; found { + matchedAuditResult = auditResult + } + } + + // 如果找到匹配的审核结果,将其转换为 violatedRules 格式并添加 + if matchedAuditResult != nil { + sqleViolatedRules := sqlWorkbenchService.convertSQLEAuditToViolatedRules(matchedAuditResult) + if len(sqleViolatedRules) > 0 { + sqlItem.ViolatedRules = sqleViolatedRules + } + } + } +} + +// shouldRequireApproval 根据审核放行等级判断是否需要审批 +func (sqlWorkbenchService *SqlWorkbenchService) shouldRequireApproval(sqlResults []auditSQLResV2, allowQueryWhenLessThanAuditLevel string) bool { + // 如果没有设置审核放行等级,默认需要审批 + if allowQueryWhenLessThanAuditLevel == "" { + return true + } + + // 遍历所有 SQL 审核结果 + for _, sqlResult := range sqlResults { + // 检查是否有执行失败的审核项 + for _, auditItem := range sqlResult.AuditResult { + if auditItem.ExecutionFailed { + return true + } + } + + // 比较审核等级:如果 SQL 的审核等级大于允许的等级,则需要审批 + // 使用 RuleLevel 的 LessOrEqual 方法进行比较 + sqlAuditLevel := dbmodel.RuleLevel(sqlResult.AuditLevel) + allowedLevel := dbmodel.RuleLevel(allowQueryWhenLessThanAuditLevel) + + // 如果 SQL 的审核等级大于允许的等级(即 !LessOrEqual),则需要审批 + if !sqlAuditLevel.LessOrEqual(allowedLevel) { + return true + } + } + + // 所有 SQL 的审核等级都小于等于允许的等级,不需要审批 + return false +} + +// convertSQLEAuditToViolatedRules 将 SQLE 审核结果转换为 violatedRules 格式 +func (sqlWorkbenchService *SqlWorkbenchService) convertSQLEAuditToViolatedRules(auditResult *auditSQLResV2) []StreamExecuteRule { + var violatedRules []StreamExecuteRule + + // 将 SQLE 的 audit_result 转换为 violatedRules 格式 + for _, auditItem := range auditResult.AuditResult { + // 映射 level 字符串到数字 + levelNum := sqlWorkbenchService.mapAuditLevelToNumber(auditItem.Level) + + violatedRule := StreamExecuteRule{ + AppliedDialectTypes: nil, + CreateTime: nil, + Enabled: nil, + ID: nil, + Level: levelNum, + Metadata: nil, + OrganizationID: nil, + Properties: nil, + RulesetID: nil, + UpdateTime: nil, + Violation: StreamExecuteViolation{ + Level: levelNum, + LocalizedMessage: auditItem.Message, + Offset: 0, // SQLE 审核结果可能没有 offset 信息 + Start: 0, + Stop: 0, + Text: auditResult.ExecSQL, + }, + } + violatedRules = append(violatedRules, violatedRule) + } + + return violatedRules +} + +// mapAuditLevelToNumber 将审核级别字符串映射到数字 +// normal=0, notice=1, warn=2, error=3 +func (sqlWorkbenchService *SqlWorkbenchService) mapAuditLevelToNumber(level string) int { + switch strings.ToLower(level) { + case "normal": + return 0 + case "notice": + return 1 + case "warn": + return 2 + case "error": + return 3 + default: + return 1 // 默认为 notice + } +} + +// StreamExecuteResponse streamExecute 接口响应结构 +type StreamExecuteResponse struct { + Data StreamExecuteData `json:"data"` + DurationMillis int64 `json:"durationMillis"` + HTTPStatus string `json:"httpStatus"` + RequestID string `json:"requestId"` + Server string `json:"server"` + Successful bool `json:"successful"` + Timestamp float64 `json:"timestamp"` + TraceID string `json:"traceId"` +} + +// StreamExecuteData streamExecute 响应中的 data 字段 +type StreamExecuteData struct { + ApprovalRequired bool `json:"approvalRequired"` + LogicalSQL bool `json:"logicalSql"` + RequestID *string `json:"requestId"` + SQLs []StreamExecuteSQLItem `json:"sqls"` + UnauthorizedDBResources interface{} `json:"unauthorizedDBResources"` + ViolatedRules []interface{} `json:"violatedRules"` + SQLEAuditResult *SQLEAuditResultSummary `json:"sqleAuditResult,omitempty"` +} + +// StreamExecuteSQLItem SQL 条目 +type StreamExecuteSQLItem struct { + SQLTuple StreamExecuteSQLTuple `json:"sqlTuple"` + ViolatedRules []StreamExecuteRule `json:"violatedRules"` +} + +// StreamExecuteSQLTuple SQL 元组 +type StreamExecuteSQLTuple struct { + ExecutedSQL string `json:"executedSql"` + Offset int `json:"offset"` + OriginalSQL string `json:"originalSql"` + SQLID string `json:"sqlId"` +} + +// StreamExecuteRule 违反的规则 +type StreamExecuteRule struct { + AppliedDialectTypes interface{} `json:"appliedDialectTypes"` + CreateTime interface{} `json:"createTime"` + Enabled interface{} `json:"enabled"` + ID interface{} `json:"id"` + Level int `json:"level"` + Metadata interface{} `json:"metadata"` + OrganizationID interface{} `json:"organizationId"` + Properties interface{} `json:"properties"` + RulesetID interface{} `json:"rulesetId"` + UpdateTime interface{} `json:"updateTime"` + Violation StreamExecuteViolation `json:"violation"` +} + +// StreamExecuteViolation 违反详情 +type StreamExecuteViolation struct { + Level int `json:"level"` + LocalizedMessage string `json:"localizedMessage"` + Offset int `json:"offset"` + Start int `json:"start"` + Stop int `json:"stop"` + Text string `json:"text"` +} + +// SQLEAuditResultSummary SQLE 审核结果汇总 +type SQLEAuditResultSummary struct { + AuditLevel string `json:"audit_level"` + Score int32 `json:"score"` + PassRate float64 `json:"pass_rate"` +} + +// streamExecuteResponseWriter 响应拦截器,用于捕获响应内容 +type streamExecuteResponseWriter struct { + echo.Response + Buffer *bytes.Buffer + original http.ResponseWriter + status int +} + +func newStreamExecuteResponseWriter(c echo.Context) *streamExecuteResponseWriter { + buf := new(bytes.Buffer) + return &streamExecuteResponseWriter{ + Response: *c.Response(), + Buffer: buf, + original: c.Response().Writer, + } +} + +func (w *streamExecuteResponseWriter) Write(b []byte) (int, error) { + // 如果未设置状态码,则补默认值 + if w.status == 0 { + w.WriteHeader(http.StatusOK) + } + // 写入 buffer,不立即写给客户端 + return w.Buffer.Write(b) +} + +func (w *streamExecuteResponseWriter) WriteHeader(code int) { + w.status = code +} + +// auditSQLReq SQLE 审核请求结构 +type auditSQLReq struct { + InstanceType string `json:"instance_type"` + SQLContent string `json:"sql_content"` + SQLType string `json:"sql_type"` + ProjectId string `json:"project_id"` + RuleTemplateName string `json:"rule_template_name"` +} + +// auditSQLReply SQLE 审核响应结构 +type auditSQLReply struct { + Code int `json:"code"` + Message string `json:"message"` + Data *auditResDataV2 `json:"data"` +} + +// auditResDataV2 审核结果数据 +type auditResDataV2 struct { + AuditLevel string `json:"audit_level"` + Score int32 `json:"score"` + PassRate float64 `json:"pass_rate"` + SQLResults []auditSQLResV2 `json:"sql_results"` +} + +// auditSQLResV2 单个 SQL 审核结果 +type auditSQLResV2 struct { + Number uint `json:"number"` + ExecSQL string `json:"exec_sql"` + AuditResult []struct { + Level string `json:"level"` + Message string `json:"message"` + ExecutionFailed bool `json:"execution_failed"` + } `json:"audit_result"` + AuditLevel string `json:"audit_level"` +} From 004b1cb9d5950ff2864045fa608fcdc1ac754458 Mon Sep 17 00:00:00 2001 From: littleniannian Date: Thu, 4 Dec 2025 14:25:37 +0800 Subject: [PATCH 2/6] fix: sql workbench sql audit --- internal/apiserver/service/router.go | 2 +- .../service/sql_workbench_service.go | 101 +++++++++++++----- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/internal/apiserver/service/router.go b/internal/apiserver/service/router.go index b37c159b..b9149d13 100644 --- a/internal/apiserver/service/router.go +++ b/internal/apiserver/service/router.go @@ -285,7 +285,7 @@ func (s *APIServer) initRouter() error { SqlWorkbenchService: s.SqlWorkbenchController.SqlWorkbenchService, })) - sqlWorkbenchV1.Use(s.SqlWorkbenchController.SqlWorkbenchService.Intercept()) + sqlWorkbenchV1.Use(s.SqlWorkbenchController.SqlWorkbenchService.AuditMiddleware()) sqlWorkbenchV1.Use(middleware.ProxyWithConfig(middleware.ProxyConfig{ Skipper: middleware.DefaultSkipper, Balancer: middleware.NewRandomBalancer(targets), diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index 7b6d6bbb..d09bfc38 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -12,6 +12,7 @@ import ( "net/url" "strings" "sync" + "time" dmsV1 "github.com/actiontech/dms/api/dms/service/v1" "github.com/actiontech/dms/internal/apiserver/conf" @@ -994,8 +995,8 @@ func makeHttpRequest(ctx context.Context, url string, headers map[string]string, return nil } -// intercept 拦截工作台odc请求进行加工 -func (sqlWorkbenchService *SqlWorkbenchService) Intercept() echo.MiddlewareFunc { +// AuditMiddleware 拦截工作台odc请求进行加工 +func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { // 只拦截包含 /streamExecute 的请求 @@ -1111,8 +1112,8 @@ func (sqlWorkbenchService *SqlWorkbenchService) parseSidToDatasourceID(sid strin // 移除 "sid:" 前缀 sid = strings.TrimPrefix(sid, "sid:") - // 查找最后一个 ":d:dms" 后缀并移除 - if idx := strings.LastIndex(sid, ":d:dms"); idx != -1 { + // 查找最后一个 ":d" 后缀并移除从 ":d" 开始的所有字符 + if idx := strings.LastIndex(sid, ":d"); idx != -1 { sid = sid[:idx] } @@ -1244,6 +1245,67 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex // interceptAndAddAuditResult 拦截响应并添加审核结果 func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c echo.Context, next echo.HandlerFunc, auditResult *auditSQLReply, dbService *biz.DBService) error { + // 判断是否需要审批 + allowQueryWhenLessThanAuditLevel := dbService.GetAllowQueryWhenLessThanAuditLevel() + needApproval := sqlWorkbenchService.shouldRequireApproval(auditResult.Data.SQLResults, allowQueryWhenLessThanAuditLevel) + + // 如果需要审批,直接返回审核结果,不请求真实的 streamExecute 接口 + if needApproval { + return sqlWorkbenchService.buildAuditResponseWithoutExecution(c, auditResult, dbService) + } + + // 不需要审批,执行真实请求并添加审核结果 + return sqlWorkbenchService.executeAndAddAuditResult(c, next, auditResult, dbService) +} + +// buildAuditResponseWithoutExecution 构造审核响应,不执行真实的 SQL +func (sqlWorkbenchService *SqlWorkbenchService) buildAuditResponseWithoutExecution(c echo.Context, auditResult *auditSQLReply, dbService *biz.DBService) error { + // 构造 SQL 条目列表 + sqlItems := make([]StreamExecuteSQLItem, 0, len(auditResult.Data.SQLResults)) + for _, sqlResult := range auditResult.Data.SQLResults { + // 转换审核结果为 violatedRules 格式 + violatedRules := sqlWorkbenchService.convertSQLEAuditToViolatedRules(&sqlResult) + + sqlItem := StreamExecuteSQLItem{ + SQLTuple: StreamExecuteSQLTuple{ + ExecutedSQL: sqlResult.ExecSQL, + Offset: int(sqlResult.Number), + OriginalSQL: sqlResult.ExecSQL, + SQLID: fmt.Sprintf("sqle-audit-%d", sqlResult.Number), + }, + ViolatedRules: violatedRules, + } + sqlItems = append(sqlItems, sqlItem) + } + + // 构造响应数据 + responseData := StreamExecuteData{ + ApprovalRequired: true, // 需要审批 + LogicalSQL: false, + RequestID: nil, + SQLs: sqlItems, + UnauthorizedDBResources: nil, + ViolatedRules: []interface{}{}, + } + + // 构造完整响应 + response := StreamExecuteResponse{ + Data: responseData, + DurationMillis: 0, + HTTPStatus: "OK", + RequestID: fmt.Sprintf("dms-audit-%d", time.Now().UnixNano()), + Server: "DMS", + Successful: true, + Timestamp: float64(time.Now().Unix()), + TraceID: c.Response().Header().Get("X-Trace-ID"), + } + + // 返回 JSON 响应 + return c.JSON(http.StatusOK, response) +} + +// executeAndAddAuditResult 执行真实请求并添加审核结果 +func (sqlWorkbenchService *SqlWorkbenchService) executeAndAddAuditResult(c echo.Context, next echo.HandlerFunc, auditResult *auditSQLReply, dbService *biz.DBService) error { // 创建响应拦截器 srw := newStreamExecuteResponseWriter(c) cloudbeaverResBuf := srw.Buffer @@ -1287,13 +1349,6 @@ func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c ech if auditResult != nil && auditResult.Data != nil && auditResult.Data.PassRate != 1 { // 将 SQLE 审核结果整合到每个 SQL 条目中 sqlWorkbenchService.mergeSQLEAuditResults(&responseBody.Data, auditResult.Data, dbService) - - // 在 data 级别添加汇总的审核结果信息 - responseBody.Data.SQLEAuditResult = &SQLEAuditResultSummary{ - AuditLevel: auditResult.Data.AuditLevel, - Score: auditResult.Data.Score, - PassRate: auditResult.Data.PassRate, - } } // 重新序列化响应 @@ -1388,14 +1443,9 @@ func (sqlWorkbenchService *SqlWorkbenchService) mergeSQLEAuditResults(data *Stre sqlAuditMap[normalizedSQL] = sqlResult } - // 获取数据源设置的审核放行等级 - allowQueryWhenLessThanAuditLevel := dbService.GetAllowQueryWhenLessThanAuditLevel() - - // 判断是否需要审批 - needApproval := sqlWorkbenchService.shouldRequireApproval(auditData.SQLResults, allowQueryWhenLessThanAuditLevel) - - // 设置 ApprovalRequired 字段 - data.ApprovalRequired = needApproval + // 设置 ApprovalRequired 为 false,表示已通过审核可以执行 + // 注意:如果需要审批,在 interceptAndAddAuditResult 中已经拦截,不会执行到这里 + data.ApprovalRequired = false // 遍历 sqls 数组,为每个 SQL 条目添加 SQLE 审核结果 for i := range data.SQLs { @@ -1522,13 +1572,12 @@ type StreamExecuteResponse struct { // StreamExecuteData streamExecute 响应中的 data 字段 type StreamExecuteData struct { - ApprovalRequired bool `json:"approvalRequired"` - LogicalSQL bool `json:"logicalSql"` - RequestID *string `json:"requestId"` - SQLs []StreamExecuteSQLItem `json:"sqls"` - UnauthorizedDBResources interface{} `json:"unauthorizedDBResources"` - ViolatedRules []interface{} `json:"violatedRules"` - SQLEAuditResult *SQLEAuditResultSummary `json:"sqleAuditResult,omitempty"` + ApprovalRequired bool `json:"approvalRequired"` + LogicalSQL bool `json:"logicalSql"` + RequestID *string `json:"requestId"` + SQLs []StreamExecuteSQLItem `json:"sqls"` + UnauthorizedDBResources interface{} `json:"unauthorizedDBResources"` + ViolatedRules []interface{} `json:"violatedRules"` } // StreamExecuteSQLItem SQL 条目 From f31fa0f69aed752bf7488dbb58bb14d0f41b24d8 Mon Sep 17 00:00:00 2001 From: littleniannian Date: Thu, 4 Dec 2025 16:03:38 +0800 Subject: [PATCH 3/6] fix: map audit rule level --- .../sql_workbench/service/sql_workbench_service.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index d09bfc38..444fbf46 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -1477,9 +1477,9 @@ func (sqlWorkbenchService *SqlWorkbenchService) mergeSQLEAuditResults(data *Stre // shouldRequireApproval 根据审核放行等级判断是否需要审批 func (sqlWorkbenchService *SqlWorkbenchService) shouldRequireApproval(sqlResults []auditSQLResV2, allowQueryWhenLessThanAuditLevel string) bool { - // 如果没有设置审核放行等级,默认需要审批 + // 如果没有设置审核放行等级,那么直接放行 if allowQueryWhenLessThanAuditLevel == "" { - return true + return false } // 遍历所有 SQL 审核结果 @@ -1548,13 +1548,13 @@ func (sqlWorkbenchService *SqlWorkbenchService) mapAuditLevelToNumber(level stri case "normal": return 0 case "notice": - return 1 + return 3 case "warn": - return 2 + return 1 case "error": - return 3 + return 2 default: - return 1 // 默认为 notice + return 0 // 默认为 notice } } From 1e9542d33fbf1d69a3d820dce16c4885ac867d8e Mon Sep 17 00:00:00 2001 From: littleniannian Date: Thu, 4 Dec 2025 18:38:32 +0800 Subject: [PATCH 4/6] fix: code opt --- .../service/sql_workbench_service.go | 48 ++++++------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index 444fbf46..26edc5e9 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -7,6 +7,8 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/actiontech/dms/internal/pkg/cloudbeaver" + "github.com/actiontech/dms/internal/pkg/utils" "io" "net/http" "net/url" @@ -1007,8 +1009,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar // 读取请求体 bodyBytes, err := io.ReadAll(c.Request().Body) if err != nil { - sqlWorkbenchService.log.Errorf("Failed to read request body: %v", err) - return next(c) + return fmt.Errorf("failed to read request body: %w", err) } // 恢复请求体,供后续处理使用 c.Request().Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -1016,52 +1017,44 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar // 解析请求体获取 SQL 和 datasource ID sql, datasourceID, err := sqlWorkbenchService.parseStreamExecuteRequest(bodyBytes) if err != nil { - sqlWorkbenchService.log.Debugf("Failed to parse streamExecute request, skipping audit: %v", err) - return next(c) + return fmt.Errorf("failed to parse streamExecute request, skipping audit: %v", err) } if sql == "" || datasourceID == "" { - sqlWorkbenchService.log.Debugf("SQL or datasource ID is empty, skipping audit") - return next(c) + return fmt.Errorf("SQL or datasource ID is empty, skipping audit") } // 获取当前用户 ID dmsUserId, err := sqlWorkbenchService.getDMSUserIdFromRequest(c) if err != nil { - sqlWorkbenchService.log.Errorf("Failed to get DMS user ID: %v", err) - return next(c) + return fmt.Errorf("failed to get DMS user ID: %v", err) } // 从缓存表获取 dms_db_service_id dmsDBServiceID, err := sqlWorkbenchService.getDMSDBServiceIDFromCache(c.Request().Context(), datasourceID, dmsUserId) if err != nil { - sqlWorkbenchService.log.Errorf("Failed to get dms_db_service_id from cache: %v", err) - return next(c) + return fmt.Errorf("failed to get dms_db_service_id from cache: %v", err) } if dmsDBServiceID == "" { - sqlWorkbenchService.log.Debugf("dms_db_service_id not found in cache for datasource: %s", datasourceID) - return next(c) + return fmt.Errorf("dms_db_service_id not found in cache for datasource: %s", datasourceID) } // 获取 DBService 信息 dbService, err := sqlWorkbenchService.dbServiceUsecase.GetDBService(c.Request().Context(), dmsDBServiceID) if err != nil { - sqlWorkbenchService.log.Errorf("Failed to get DBService: %v", err) - return next(c) + return fmt.Errorf("failed to get DBService: %v", err) } // 检查是否启用 SQL 审核 if !sqlWorkbenchService.isEnableSQLAudit(dbService) { - sqlWorkbenchService.log.Debugf("SQL audit is not enabled for DBService: %s", dmsDBServiceID) - return next(c) + return fmt.Errorf("SQL audit is not enabled for DBService: %s", dmsDBServiceID) } // 调用 SQLE 审核接口 auditResult, err := sqlWorkbenchService.callSQLEAudit(c.Request().Context(), sql, dbService) if err != nil { - sqlWorkbenchService.log.Errorf("Failed to call SQLE audit: %v", err) - return next(c) + return fmt.Errorf("call SQLE audit failed: %v", err) } // 拦截响应并添加审核结果 @@ -1217,7 +1210,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex sqleAddr := fmt.Sprintf("%s/v2/sql_audit", target.URL.String()) // 构建审核请求 - auditReq := auditSQLReq{ + auditReq := cloudbeaver.AuditSQLReq{ InstanceType: dbService.DBType, SQLContent: sql, SQLType: "sql", @@ -1397,9 +1390,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) decodeResponseBody(body []byte, if len(body) == 0 { return body, false, nil } - - encoding := strings.ToLower(contentEncoding) - isGzip := strings.Contains(encoding, "gzip") || (len(body) >= 2 && body[0] == 0x1f && body[1] == 0x8b) + isGzip := utils.IsGzip(body) if !isGzip { return body, false, nil } @@ -1542,7 +1533,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) convertSQLEAuditToViolatedRules( } // mapAuditLevelToNumber 将审核级别字符串映射到数字 -// normal=0, notice=1, warn=2, error=3 +// normal=0, notice=3, warn=1, error=2 func (sqlWorkbenchService *SqlWorkbenchService) mapAuditLevelToNumber(level string) int { switch strings.ToLower(level) { case "normal": @@ -1554,7 +1545,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) mapAuditLevelToNumber(level stri case "error": return 2 default: - return 0 // 默认为 notice + return 0 // 默认为 normal } } @@ -1656,15 +1647,6 @@ func (w *streamExecuteResponseWriter) WriteHeader(code int) { w.status = code } -// auditSQLReq SQLE 审核请求结构 -type auditSQLReq struct { - InstanceType string `json:"instance_type"` - SQLContent string `json:"sql_content"` - SQLType string `json:"sql_type"` - ProjectId string `json:"project_id"` - RuleTemplateName string `json:"rule_template_name"` -} - // auditSQLReply SQLE 审核响应结构 type auditSQLReply struct { Code int `json:"code"` From 83b74ee1274db0fa8a105065bf2da6a1b41bc78a Mon Sep 17 00:00:00 2001 From: littleniannian Date: Thu, 4 Dec 2025 19:32:34 +0800 Subject: [PATCH 5/6] fix: save audit failed operation log --- .../service/sql_workbench_service.go | 120 +++++++++++++++++- 1 file changed, 114 insertions(+), 6 deletions(-) diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index 26edc5e9..1b5e755b 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -7,8 +7,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/actiontech/dms/internal/pkg/cloudbeaver" - "github.com/actiontech/dms/internal/pkg/utils" "io" "net/http" "net/url" @@ -16,6 +14,9 @@ import ( "sync" "time" + "github.com/actiontech/dms/internal/pkg/cloudbeaver" + "github.com/actiontech/dms/internal/pkg/utils" + dmsV1 "github.com/actiontech/dms/api/dms/service/v1" "github.com/actiontech/dms/internal/apiserver/conf" "github.com/actiontech/dms/internal/dms/biz" @@ -25,11 +26,14 @@ import ( "github.com/actiontech/dms/internal/sql_workbench/client" config "github.com/actiontech/dms/internal/sql_workbench/config" "github.com/actiontech/dms/pkg/dms-common/api/jwt" + "github.com/actiontech/dms/pkg/dms-common/i18nPkg" _const "github.com/actiontech/dms/pkg/dms-common/pkg/const" pkgHttp "github.com/actiontech/dms/pkg/dms-common/pkg/http" utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log" + pkgRand "github.com/actiontech/dms/pkg/rand" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "golang.org/x/text/language" ) const SQL_WORKBENCH_URL = "/odc_query" @@ -95,6 +99,7 @@ type SqlWorkbenchService struct { sqlWorkbenchUserRepo biz.SqlWorkbenchUserRepo sqlWorkbenchDatasourceRepo biz.SqlWorkbenchDatasourceRepo proxyTargetRepo biz.ProxyTargetRepo + cbOperationLogUsecase *biz.CbOperationLogUsecase } func NewAndInitSqlWorkbenchService(logger utilLog.Logger, opts *conf.DMSOptions) (*SqlWorkbenchService, error) { @@ -162,6 +167,10 @@ func NewAndInitSqlWorkbenchService(logger utilLog.Logger, opts *conf.DMSOptions) sqlWorkbenchUserRepo := storage.NewSqlWorkbenchRepo(logger, st) sqlWorkbenchDatasourceRepo := storage.NewSqlWorkbenchDatasourceRepo(logger, st) + // 初始化操作日志相关 + cbOperationLogRepo := storage.NewCbOperationLogRepo(logger, st) + cbOperationLogUsecase := biz.NewCbOperationLogUsecase(logger, cbOperationLogRepo, opPermissionVerifyUsecase, proxyTargetRepo, biz.NewSystemVariableUsecase(logger, storage.NewSystemVariableRepo(logger, st))) + return &SqlWorkbenchService{ cfg: opts.SqlWorkBenchOpts, log: utilLog.NewHelper(logger, utilLog.WithMessageKey("sql_workbench.service")), @@ -173,6 +182,7 @@ func NewAndInitSqlWorkbenchService(logger utilLog.Logger, opts *conf.DMSOptions) sqlWorkbenchUserRepo: sqlWorkbenchUserRepo, sqlWorkbenchDatasourceRepo: sqlWorkbenchDatasourceRepo, proxyTargetRepo: proxyTargetRepo, + cbOperationLogUsecase: cbOperationLogUsecase, }, nil } @@ -1058,7 +1068,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar } // 拦截响应并添加审核结果 - return sqlWorkbenchService.interceptAndAddAuditResult(c, next, auditResult, dbService) + return sqlWorkbenchService.interceptAndAddAuditResult(c, next, dmsUserId, auditResult, dbService) } } } @@ -1237,14 +1247,14 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex } // interceptAndAddAuditResult 拦截响应并添加审核结果 -func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c echo.Context, next echo.HandlerFunc, auditResult *auditSQLReply, dbService *biz.DBService) error { +func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c echo.Context, next echo.HandlerFunc, userId string, auditResult *auditSQLReply, dbService *biz.DBService) error { // 判断是否需要审批 allowQueryWhenLessThanAuditLevel := dbService.GetAllowQueryWhenLessThanAuditLevel() needApproval := sqlWorkbenchService.shouldRequireApproval(auditResult.Data.SQLResults, allowQueryWhenLessThanAuditLevel) // 如果需要审批,直接返回审核结果,不请求真实的 streamExecute 接口 if needApproval { - return sqlWorkbenchService.buildAuditResponseWithoutExecution(c, auditResult, dbService) + return sqlWorkbenchService.buildAuditResponseWithoutExecution(c, userId, auditResult, dbService) } // 不需要审批,执行真实请求并添加审核结果 @@ -1252,7 +1262,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c ech } // buildAuditResponseWithoutExecution 构造审核响应,不执行真实的 SQL -func (sqlWorkbenchService *SqlWorkbenchService) buildAuditResponseWithoutExecution(c echo.Context, auditResult *auditSQLReply, dbService *biz.DBService) error { +func (sqlWorkbenchService *SqlWorkbenchService) buildAuditResponseWithoutExecution(c echo.Context, userId string, auditResult *auditSQLReply, dbService *biz.DBService) error { // 构造 SQL 条目列表 sqlItems := make([]StreamExecuteSQLItem, 0, len(auditResult.Data.SQLResults)) for _, sqlResult := range auditResult.Data.SQLResults { @@ -1293,10 +1303,108 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildAuditResponseWithoutExecuti TraceID: c.Response().Header().Get("X-Trace-ID"), } + // 记录审核拦截的操作日志 + if err := sqlWorkbenchService.saveAuditBlockedLog(c, userId, auditResult, dbService); err != nil { + sqlWorkbenchService.log.Errorf("failed to save audit blocked log: %v", err) + // 日志记录失败不影响响应返回 + } + // 返回 JSON 响应 return c.JSON(http.StatusOK, response) } +// saveAuditBlockedLog 记录审核拦截的操作日志 +func (sqlWorkbenchService *SqlWorkbenchService) saveAuditBlockedLog(c echo.Context, userId string, auditResult *auditSQLReply, dbService *biz.DBService) error { + ctx := context.Background() + + // 提取 session ID + sessionID := extractSessionID(c.Request().URL.Path) + + // 对每个被拦截的 SQL 记录日志 + for _, sqlResult := range auditResult.Data.SQLResults { + // 生成 UID + uid, err := pkgRand.GenStrUid() + if err != nil { + sqlWorkbenchService.log.Errorf("failed to generate UID: %v", err) + continue + } + + // 获取 SQL 内容 + sql := sqlResult.ExecSQL + + // 构建操作详情 + opDetail := i18nPkg.ConvertStr2I18nAsDefaultLang(sql) + + // 转换审核结果 + auditResults := sqlWorkbenchService.convertToAuditResults(&sqlResult) + + // 判断是否审核通过 + isAuditPass := false // 被拦截的 SQL 默认未通过审核 + + // 创建操作日志 - 标记为审核失败(被拦截) + now := time.Now() + cbOperationLog := biz.CbOperationLog{ + UID: uid, + OpPersonUID: userId, + OpTime: &now, + DBServiceUID: dbService.UID, + OpType: biz.CbOperationLogTypeSql, + I18nOpDetail: opDetail, + OpSessionID: &sessionID, + ProjectID: dbService.ProjectUID, + OpHost: c.RealIP(), + AuditResults: auditResults, + IsAuditPass: &isAuditPass, + ExecResult: "", + ExecTotalSec: 0, + ResultSetRowCount: 0, + } + + // 保存操作日志 + if err := sqlWorkbenchService.cbOperationLogUsecase.SaveCbOperationLog(ctx, &cbOperationLog); err != nil { + sqlWorkbenchService.log.Errorf("failed to save operation log: %v", err) + // 继续处理其他 SQL + } + } + + return nil +} + +// convertToAuditResults 将审核结果转换为 model.AuditResults 格式 +func (sqlWorkbenchService *SqlWorkbenchService) convertToAuditResults(sqlResult *auditSQLResV2) dbmodel.AuditResults { + var auditResults dbmodel.AuditResults + for _, result := range sqlResult.AuditResult { + auditResult := dbmodel.AuditResult{ + Level: result.Level, + RuleName: "", // SQLE 返回的审核结果中没有 RuleName,如果需要可以从 Message 中提取 + ExecutionFailed: result.ExecutionFailed, + I18nAuditResultInfo: dbmodel.I18nAuditResultInfo{ + language.Chinese: dbmodel.AuditResultInfo{ + Message: result.Message, + ErrorInfo: "", + }, + }, + } + auditResults = append(auditResults, auditResult) + } + return auditResults +} + +// extractSessionID 从路径中提取 session ID +func extractSessionID(path string) string { + // 匹配类似: /api/v2/datasource/sessions/sid:{sessionId}/sqls/getMoreResults + parts := strings.Split(path, "/") + for i, part := range parts { + if strings.HasPrefix(part, "sid:") { + return part + } + if i < len(parts)-1 && part == "sessions" && strings.HasPrefix(parts[i+1], "sid:") { + return parts[i+1] + } + } + return "" +} + // executeAndAddAuditResult 执行真实请求并添加审核结果 func (sqlWorkbenchService *SqlWorkbenchService) executeAndAddAuditResult(c echo.Context, next echo.HandlerFunc, auditResult *auditSQLReply, dbService *biz.DBService) error { // 创建响应拦截器 From ff96645409feb40b5e194bbfc483b196c2cf20ba Mon Sep 17 00:00:00 2001 From: littleniannian Date: Fri, 5 Dec 2025 12:36:41 +0800 Subject: [PATCH 6/6] fix: decode gzip by utils and remove auditSQLReply --- internal/dms/storage/model/model.go | 1 + internal/pkg/cloudbeaver/graphql.go | 4 +- .../service/sql_workbench_service.go | 63 +++++-------------- 3 files changed, 18 insertions(+), 50 deletions(-) diff --git a/internal/dms/storage/model/model.go b/internal/dms/storage/model/model.go index 89b337d4..4f912e11 100644 --- a/internal/dms/storage/model/model.go +++ b/internal/dms/storage/model/model.go @@ -569,6 +569,7 @@ type AuditResult struct { RuleName string `json:"rule_name"` ExecutionFailed bool `json:"execution_failed"` I18nAuditResultInfo I18nAuditResultInfo `json:"i18n_audit_result_info"` + Message string `json:"message"` } type RuleLevel string diff --git a/internal/pkg/cloudbeaver/graphql.go b/internal/pkg/cloudbeaver/graphql.go index 7ef3876c..74da5ba9 100644 --- a/internal/pkg/cloudbeaver/graphql.go +++ b/internal/pkg/cloudbeaver/graphql.go @@ -146,7 +146,7 @@ type AuditResDataV2 struct { SQLResults []AuditSQLResV2 `json:"sql_results"` } -type auditSQLReply struct { +type AuditSQLReply struct { Code int `json:"code" example:"0"` Message string `json:"message" example:"ok"` Data *AuditResDataV2 `json:"data"` @@ -176,7 +176,7 @@ func (r *MutationResolverImpl) AuditSQL(ctx context.Context, sql string, connect RuleTemplateName: directAuditParams.RuleTemplateName, } - reply := &auditSQLReply{} + reply := &AuditSQLReply{} if err = pkgHttp.POST(ctx, directAuditParams.SQLEAddr, header, req, reply); err != nil { return false, nil, err } diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index 1b5e755b..86171d7d 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -1210,7 +1210,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) isEnableSQLAudit(dbService *biz. } // callSQLEAudit 调用 SQLE 直接审核接口 -func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService) (*auditSQLReply, error) { +func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService) (*cloudbeaver.AuditSQLReply, error) { // 获取 SQLE 服务地址 target, err := sqlWorkbenchService.proxyTargetRepo.GetProxyTargetByName(ctx, _const.SqleComponentName) if err != nil { @@ -1234,7 +1234,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex } // 调用 SQLE 审核接口 - reply := &auditSQLReply{} + reply := &cloudbeaver.AuditSQLReply{} if err := pkgHttp.POST(ctx, sqleAddr, header, auditReq, reply); err != nil { return nil, fmt.Errorf("failed to call SQLE audit API: %v", err) } @@ -1247,7 +1247,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex } // interceptAndAddAuditResult 拦截响应并添加审核结果 -func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c echo.Context, next echo.HandlerFunc, userId string, auditResult *auditSQLReply, dbService *biz.DBService) error { +func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c echo.Context, next echo.HandlerFunc, userId string, auditResult *cloudbeaver.AuditSQLReply, dbService *biz.DBService) error { // 判断是否需要审批 allowQueryWhenLessThanAuditLevel := dbService.GetAllowQueryWhenLessThanAuditLevel() needApproval := sqlWorkbenchService.shouldRequireApproval(auditResult.Data.SQLResults, allowQueryWhenLessThanAuditLevel) @@ -1262,7 +1262,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) interceptAndAddAuditResult(c ech } // buildAuditResponseWithoutExecution 构造审核响应,不执行真实的 SQL -func (sqlWorkbenchService *SqlWorkbenchService) buildAuditResponseWithoutExecution(c echo.Context, userId string, auditResult *auditSQLReply, dbService *biz.DBService) error { +func (sqlWorkbenchService *SqlWorkbenchService) buildAuditResponseWithoutExecution(c echo.Context, userId string, auditResult *cloudbeaver.AuditSQLReply, dbService *biz.DBService) error { // 构造 SQL 条目列表 sqlItems := make([]StreamExecuteSQLItem, 0, len(auditResult.Data.SQLResults)) for _, sqlResult := range auditResult.Data.SQLResults { @@ -1314,7 +1314,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildAuditResponseWithoutExecuti } // saveAuditBlockedLog 记录审核拦截的操作日志 -func (sqlWorkbenchService *SqlWorkbenchService) saveAuditBlockedLog(c echo.Context, userId string, auditResult *auditSQLReply, dbService *biz.DBService) error { +func (sqlWorkbenchService *SqlWorkbenchService) saveAuditBlockedLog(c echo.Context, userId string, auditResult *cloudbeaver.AuditSQLReply, dbService *biz.DBService) error { ctx := context.Background() // 提取 session ID @@ -1371,7 +1371,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) saveAuditBlockedLog(c echo.Conte } // convertToAuditResults 将审核结果转换为 model.AuditResults 格式 -func (sqlWorkbenchService *SqlWorkbenchService) convertToAuditResults(sqlResult *auditSQLResV2) dbmodel.AuditResults { +func (sqlWorkbenchService *SqlWorkbenchService) convertToAuditResults(sqlResult *cloudbeaver.AuditSQLResV2) dbmodel.AuditResults { var auditResults dbmodel.AuditResults for _, result := range sqlResult.AuditResult { auditResult := dbmodel.AuditResult{ @@ -1406,7 +1406,7 @@ func extractSessionID(path string) string { } // executeAndAddAuditResult 执行真实请求并添加审核结果 -func (sqlWorkbenchService *SqlWorkbenchService) executeAndAddAuditResult(c echo.Context, next echo.HandlerFunc, auditResult *auditSQLReply, dbService *biz.DBService) error { +func (sqlWorkbenchService *SqlWorkbenchService) executeAndAddAuditResult(c echo.Context, next echo.HandlerFunc, auditResult *cloudbeaver.AuditSQLReply, dbService *biz.DBService) error { // 创建响应拦截器 srw := newStreamExecuteResponseWriter(c) cloudbeaverResBuf := srw.Buffer @@ -1502,19 +1502,11 @@ func (sqlWorkbenchService *SqlWorkbenchService) decodeResponseBody(body []byte, if !isGzip { return body, false, nil } - - gzipReader, err := gzip.NewReader(bytes.NewReader(body)) + gzipBytes, err := utils.DecodeGzipBytes(body) if err != nil { - return nil, true, fmt.Errorf("failed to create gzip reader: %w", err) + return nil, true, fmt.Errorf("Gzip decode error: %v", err) } - defer gzipReader.Close() - - decompressed, err := io.ReadAll(gzipReader) - if err != nil { - return nil, true, fmt.Errorf("failed to decompress gzip body: %w", err) - } - - return decompressed, true, nil + return gzipBytes, true, nil } // encodeResponseBody 将响应体按照 gzip 编码 @@ -1532,9 +1524,9 @@ func (sqlWorkbenchService *SqlWorkbenchService) encodeResponseBody(body []byte) } // mergeSQLEAuditResults 将 SQLE 审核结果整合到 sqls 数组中 -func (sqlWorkbenchService *SqlWorkbenchService) mergeSQLEAuditResults(data *StreamExecuteData, auditData *auditResDataV2, dbService *biz.DBService) { +func (sqlWorkbenchService *SqlWorkbenchService) mergeSQLEAuditResults(data *StreamExecuteData, auditData *cloudbeaver.AuditResDataV2, dbService *biz.DBService) { // 创建 SQL 到审核结果的映射 - sqlAuditMap := make(map[string]*auditSQLResV2) + sqlAuditMap := make(map[string]*cloudbeaver.AuditSQLResV2) for i := range auditData.SQLResults { sqlResult := &auditData.SQLResults[i] // 使用 exec_sql 作为 key,去除首尾空格和分号 @@ -1551,7 +1543,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) mergeSQLEAuditResults(data *Stre sqlItem := &data.SQLs[i] // 尝试从 originalSql 或 executedSql 匹配审核结果 - var matchedAuditResult *auditSQLResV2 + var matchedAuditResult *cloudbeaver.AuditSQLResV2 normalizedSQL := strings.TrimSpace(strings.TrimSuffix(sqlItem.SQLTuple.OriginalSQL, ";")) if auditResult, found := sqlAuditMap[normalizedSQL]; found { matchedAuditResult = auditResult @@ -1575,7 +1567,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) mergeSQLEAuditResults(data *Stre } // shouldRequireApproval 根据审核放行等级判断是否需要审批 -func (sqlWorkbenchService *SqlWorkbenchService) shouldRequireApproval(sqlResults []auditSQLResV2, allowQueryWhenLessThanAuditLevel string) bool { +func (sqlWorkbenchService *SqlWorkbenchService) shouldRequireApproval(sqlResults []cloudbeaver.AuditSQLResV2, allowQueryWhenLessThanAuditLevel string) bool { // 如果没有设置审核放行等级,那么直接放行 if allowQueryWhenLessThanAuditLevel == "" { return false @@ -1606,7 +1598,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) shouldRequireApproval(sqlResults } // convertSQLEAuditToViolatedRules 将 SQLE 审核结果转换为 violatedRules 格式 -func (sqlWorkbenchService *SqlWorkbenchService) convertSQLEAuditToViolatedRules(auditResult *auditSQLResV2) []StreamExecuteRule { +func (sqlWorkbenchService *SqlWorkbenchService) convertSQLEAuditToViolatedRules(auditResult *cloudbeaver.AuditSQLResV2) []StreamExecuteRule { var violatedRules []StreamExecuteRule // 将 SQLE 的 audit_result 转换为 violatedRules 格式 @@ -1755,29 +1747,4 @@ func (w *streamExecuteResponseWriter) WriteHeader(code int) { w.status = code } -// auditSQLReply SQLE 审核响应结构 -type auditSQLReply struct { - Code int `json:"code"` - Message string `json:"message"` - Data *auditResDataV2 `json:"data"` -} - -// auditResDataV2 审核结果数据 -type auditResDataV2 struct { - AuditLevel string `json:"audit_level"` - Score int32 `json:"score"` - PassRate float64 `json:"pass_rate"` - SQLResults []auditSQLResV2 `json:"sql_results"` -} -// auditSQLResV2 单个 SQL 审核结果 -type auditSQLResV2 struct { - Number uint `json:"number"` - ExecSQL string `json:"exec_sql"` - AuditResult []struct { - Level string `json:"level"` - Message string `json:"message"` - ExecutionFailed bool `json:"execution_failed"` - } `json:"audit_result"` - AuditLevel string `json:"audit_level"` -}