@@ -15,6 +15,7 @@ import (
1515 "github.com/hashicorp/go-hclog"
1616 goplugin "github.com/hashicorp/go-plugin"
1717 "github.com/jackc/pgx/v5/pgproto3"
18+ "github.com/prometheus/client_golang/prometheus"
1819 "github.com/spf13/cast"
1920 "google.golang.org/grpc"
2021)
@@ -30,15 +31,17 @@ const (
3031 OutputsField string = "outputs"
3132 TokensField string = "tokens"
3233 StringField string = "String"
34+ ResponseTypeField string = "response_type"
3335
3436 DeepLearningModel string = "deep_learning_model"
3537 Libinjection string = "libinjection"
3638
37- ErrorLevel string = "error"
38- ExceptionLevel string = "EXCEPTION"
39- ErrorNumber string = "42000"
40- DetectionMessage string = "SQL injection detected"
41- ErrorResponseMessage string = "Back off, you're not welcome here."
39+ ResponseType string = "error"
40+ ErrorSeverity string = "EXCEPTION"
41+ ErrorNumber string = "42000"
42+ ErrorMessage string = "SQL injection detected"
43+ ErrorDetail string = "Back off, you're not welcome here."
44+ LogLevel string = "error"
4245
4346 TokenizeAndSequencePath string = "/tokenize_and_sequence"
4447 PredictPath string = "/v1/models/%s/versions/%s:predict"
@@ -55,6 +58,12 @@ type Plugin struct {
5558 ServingAPIAddress string
5659 ModelName string
5760 ModelVersion string
61+ ResponseType string
62+ ErrorMessage string
63+ ErrorSeverity string
64+ ErrorNumber string
65+ ErrorDetail string
66+ LogLevel string
5867}
5968
6069type InjectionDetectionPlugin struct {
@@ -139,7 +148,7 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
139148 if err != nil {
140149 p .Logger .Error ("Failed to make POST request" , ErrorField , err )
141150 if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
142- return p .errorResponse (
151+ return p .prepareResponse (
143152 req ,
144153 map [string ]any {
145154 QueryField : queryString ,
@@ -163,7 +172,7 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
163172 if err != nil {
164173 p .Logger .Error ("Failed to make POST request" , ErrorField , err )
165174 if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
166- return p .errorResponse (
175+ return p .prepareResponse (
167176 req ,
168177 map [string ]any {
169178 QueryField : queryString ,
@@ -189,8 +198,8 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
189198 }
190199
191200 Detections .With (map [string ]string {DetectorField : DeepLearningModel }).Inc ()
192- p .Logger .Warn (DetectionMessage , ScoreField , score , DetectorField , DeepLearningModel )
193- return p .errorResponse (
201+ p .Logger .Warn (p . ErrorMessage , ScoreField , score , DetectorField , DeepLearningModel )
202+ return p .prepareResponse (
194203 req ,
195204 map [string ]any {
196205 QueryField : queryString ,
@@ -200,8 +209,8 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
200209 ), nil
201210 } else if p .EnableLibinjection && injection && ! p .LibinjectionPermissiveMode {
202211 Detections .With (map [string ]string {DetectorField : Libinjection }).Inc ()
203- p .Logger .Warn (DetectionMessage , DetectorField , Libinjection )
204- return p .errorResponse (
212+ p .Logger .Warn (p . ErrorMessage , DetectorField , Libinjection )
213+ return p .prepareResponse (
205214 req ,
206215 map [string ]any {
207216 QueryField : queryString ,
@@ -224,35 +233,36 @@ func (p *Plugin) isSQLi(query string) bool {
224233 // Check if the query is an SQL injection using libinjection.
225234 injection , _ := libinjection .IsSQLi (query )
226235 if injection {
227- p .Logger .Warn (DetectionMessage , DetectorField , Libinjection )
236+ p .Logger .Warn (p . ErrorMessage , DetectorField , Libinjection )
228237 }
229238 p .Logger .Trace ("SQLInjection" , IsInjectionField , cast .ToString (injection ))
230239 return injection
231240}
232241
233- func (p * Plugin ) errorResponse (req * v1.Struct , fields map [string ]any ) * v1.Struct {
234- Preventions .Inc ()
242+ func (p * Plugin ) prepareResponse (req * v1.Struct , fields map [string ]any ) * v1.Struct {
243+ Preventions .With (prometheus. Labels { ResponseTypeField : p . ResponseType }). Inc ()
235244
236- // Create a PostgreSQL error response.
237- errResp := postgres .ErrorResponse (
238- DetectionMessage ,
239- ExceptionLevel ,
240- ErrorNumber ,
241- ErrorResponseMessage ,
242- )
245+ var encapsulatedResponse []byte
243246
244- // Create a ready for query response.
245- readyForQuery := & pgproto3.ReadyForQuery {TxStatus : 'I' }
246- // TODO: Decide whether to terminate the connection.
247- response , err := readyForQuery .Encode (errResp )
248- if err != nil {
249- p .Logger .Error ("Failed to encode ready for query response" , ErrorField , err )
250- return req
247+ if p .ResponseType == "error" {
248+ // Create a PostgreSQL error response.
249+ encapsulatedResponse = postgres .ErrorResponse (
250+ p .ErrorMessage ,
251+ p .ErrorSeverity ,
252+ ErrorNumber ,
253+ ErrorDetail ,
254+ )
255+ } else {
256+ // Create a PostgreSQL empty query response.
257+ encapsulatedResponse , _ = (& pgproto3.EmptyQueryResponse {}).Encode (nil )
251258 }
252259
260+ // Create and encode a ready for query response.
261+ response , _ := (& pgproto3.ReadyForQuery {TxStatus : 'I' }).Encode (encapsulatedResponse )
262+
253263 signals , err := v1 .NewList ([]any {
254264 sdkAct .Terminate ().ToMap (),
255- sdkAct .Log (ErrorLevel , DetectionMessage , fields ).ToMap (),
265+ sdkAct .Log (p . LogLevel , p . ErrorMessage , fields ).ToMap (),
256266 })
257267 if err != nil {
258268 p .Logger .Error ("Failed to create signals" , ErrorField , err )
0 commit comments