@@ -3,7 +3,6 @@ package plugin
33import (
44 "context"
55 "encoding/json"
6- "fmt"
76 "net/http"
87 "net/http/httptest"
98 "testing"
@@ -71,28 +70,13 @@ func Test_errorResponse(t *testing.T) {
7170
7271func Test_OnTrafficFromClinet (t * testing.T ) {
7372 p := & Plugin {
74- Logger : hclog .NewNullLogger (),
75- ModelName : "sqli_model" ,
76- ModelVersion : "2" ,
73+ Logger : hclog .NewNullLogger (),
7774 }
7875
7976 server := httptest .NewServer (
8077 http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
8178 switch r .URL .Path {
82- case TokenizeAndSequencePath :
83- w .WriteHeader (http .StatusOK )
84- w .Header ().Set ("Content-Type" , "application/json" )
85- // This is the tokenized query:
86- // {"query":"select * from users where id = 1 or 1=1"}
87- resp := map [string ][]float32 {
88- "tokens" : {
89- 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 3 , 6 , 5 , 73 , 7 , 68 , 4 , 11 , 12 ,
90- },
91- }
92- data , _ := json .Marshal (resp )
93- _ , err := w .Write (data )
94- require .NoError (t , err )
95- case fmt .Sprintf (PredictPath , p .ModelName , p .ModelVersion ):
79+ case PredictPath :
9680 w .WriteHeader (http .StatusOK )
9781 w .Header ().Set ("Content-Type" , "application/json" )
9882 // This is the output of the deep learning model.
@@ -107,8 +91,7 @@ func Test_OnTrafficFromClinet(t *testing.T) {
10791 )
10892 defer server .Close ()
10993
110- p .TokenizerAPIAddress = server .URL
111- p .ServingAPIAddress = server .URL
94+ p .PredictionAPIAddress = server .URL
11295
11396 query := pgproto3.Query {String : "SELECT * FROM users WHERE id = 1 OR 1=1" }
11497 queryBytes , err := query .Encode (nil )
@@ -136,17 +119,13 @@ func Test_OnTrafficFromClinet(t *testing.T) {
136119func Test_OnTrafficFromClinetFailedTokenization (t * testing.T ) {
137120 plugins := []* Plugin {
138121 {
139- Logger : hclog .NewNullLogger (),
140- ModelName : "sqli_model" ,
141- ModelVersion : "2" ,
122+ Logger : hclog .NewNullLogger (),
142123 // If libinjection is enabled, the response should contain the "response" field,
143124 // and the "signals" field, which means the plugin will terminate the request.
144125 EnableLibinjection : true ,
145126 },
146127 {
147- Logger : hclog .NewNullLogger (),
148- ModelName : "sqli_model" ,
149- ModelVersion : "2" ,
128+ Logger : hclog .NewNullLogger (),
150129 // If libinjection is disabled, the response should not contain the "response" field,
151130 // and the "signals" field, which means the plugin will not terminate the request.
152131 EnableLibinjection : false ,
@@ -156,7 +135,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
156135 server := httptest .NewServer (
157136 http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
158137 switch r .URL .Path {
159- case TokenizeAndSequencePath :
138+ case PredictPath :
160139 w .WriteHeader (http .StatusInternalServerError )
161140 default :
162141 w .WriteHeader (http .StatusNotFound )
@@ -166,8 +145,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
166145 defer server .Close ()
167146
168147 for i := range plugins {
169- plugins [i ].TokenizerAPIAddress = server .URL
170- plugins [i ].ServingAPIAddress = server .URL
148+ plugins [i ].PredictionAPIAddress = server .URL
171149
172150 query := pgproto3.Query {String : "SELECT * FROM users WHERE id = 1 OR 1=1" }
173151 queryBytes , err := query .Encode (nil )
@@ -204,43 +182,22 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
204182func Test_OnTrafficFromClinetFailedPrediction (t * testing.T ) {
205183 plugins := []* Plugin {
206184 {
207- Logger : hclog .NewNullLogger (),
208- ModelName : "sqli_model" ,
209- ModelVersion : "2" ,
185+ Logger : hclog .NewNullLogger (),
210186 // If libinjection is disabled, the response should not contain the "response" field,
211187 // and the "signals" field, which means the plugin will not terminate the request.
212188 EnableLibinjection : false ,
213189 },
214190 {
215- Logger : hclog .NewNullLogger (),
216- ModelName : "sqli_model" ,
217- ModelVersion : "2" ,
191+ Logger : hclog .NewNullLogger (),
218192 // If libinjection is enabled, the response should contain the "response" field,
219193 // and the "signals" field, which means the plugin will terminate the request.
220194 EnableLibinjection : true ,
221195 },
222196 }
223-
224- // This is the same for both plugins.
225- predictPath := fmt .Sprintf (PredictPath , plugins [0 ].ModelName , plugins [1 ].ModelVersion )
226-
227197 server := httptest .NewServer (
228198 http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
229199 switch r .URL .Path {
230- case TokenizeAndSequencePath :
231- w .WriteHeader (http .StatusOK )
232- w .Header ().Set ("Content-Type" , "application/json" )
233- // This is the tokenized query:
234- // {"query":"select * from users where id = 1 or 1=1"}
235- resp := map [string ][]float32 {
236- "tokens" : {
237- 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 3 , 6 , 5 , 73 , 7 , 68 , 4 , 11 , 12 ,
238- },
239- }
240- data , _ := json .Marshal (resp )
241- _ , err := w .Write (data )
242- require .NoError (t , err )
243- case predictPath :
200+ case PredictPath :
244201 w .WriteHeader (http .StatusInternalServerError )
245202 default :
246203 w .WriteHeader (http .StatusNotFound )
@@ -250,8 +207,7 @@ func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
250207 defer server .Close ()
251208
252209 for i := range plugins {
253- plugins [i ].TokenizerAPIAddress = server .URL
254- plugins [i ].ServingAPIAddress = server .URL
210+ plugins [i ].PredictionAPIAddress = server .URL
255211
256212 query := pgproto3.Query {String : "SELECT * FROM users WHERE id = 1 OR 1=1" }
257213 queryBytes , err := query .Encode (nil )
0 commit comments