diff --git a/hooks/base.go b/hooks/base.go new file mode 100644 index 0000000..5149a72 --- /dev/null +++ b/hooks/base.go @@ -0,0 +1,14 @@ +package hooks + +import "context" + +type Base struct { +} + +func (b *Base) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, nil +} + +func (b *Base) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, nil +} diff --git a/hooks/safetyhooks/safetyhooks.go b/hooks/safetyhooks/safetyhooks.go new file mode 100644 index 0000000..45e4790 --- /dev/null +++ b/hooks/safetyhooks/safetyhooks.go @@ -0,0 +1,42 @@ +package safetyhooks + +import ( + "database/sql/driver" + "fmt" + "runtime" + + "github.com/gchaincl/sqlhooks/v2/hooks" +) + +type Hook struct { + hooks.Base +} + +func New() *Hook { + return &Hook{} +} + +// safeRows wrap a driver.Rows interface in order to implement Sharp-Edged +// Finalizers based on https://crawshaw.io/blog/sharp-edged-finalizers. +type safeRows struct { + driver.Rows +} + +func (s *safeRows) Close() { + runtime.SetFinalizer(s, nil) + s.Rows.Close() +} + +func doPanic() { + _, file, line, _ := runtime.Caller(1) + panic(fmt.Sprintf("%s:%d: row not closed", file, line)) +} + +func (h *Hook) Rows(r driver.Rows) driver.Rows { + s := &safeRows{r} + runtime.SetFinalizer(s, func(*safeRows) { + doPanic() + }) + + return r +} diff --git a/hooks/safetyhooks/safetyhooks_test.go b/hooks/safetyhooks/safetyhooks_test.go new file mode 100644 index 0000000..3002891 --- /dev/null +++ b/hooks/safetyhooks/safetyhooks_test.go @@ -0,0 +1,41 @@ +package safetyhooks + +import ( + "database/sql" + "testing" + + "github.com/gchaincl/sqlhooks/v2" + "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" +) + +func setupTestDB(t *testing.T, hooks sqlhooks.Hooks) *sql.DB { + var ( + err error + name = "final" + ) + + sql.Register(name, sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, hooks)) + db, err := sql.Open(name, ":memory:") + require.NoError(t, err) + + _, err = db.Exec("CREATE TABLE test(id int)") + require.NoError(t, err) + + _, err = db.Exec("INSERT INTO test VALUES(1)") + require.NoError(t, err) + + return db +} + +func doQuery(db *sql.DB, query string) (*sql.Rows, error) { + return db.Query(query) +} + +func TestFinalizers(t *testing.T) { + hooks := New() + db := setupTestDB(t, hooks) + + _, err := doQuery(db, "SELECT * from test") + require.NoError(t, err) +} diff --git a/sqlhooks.go b/sqlhooks.go index 1da05ba..59a80bb 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -23,6 +23,12 @@ type OnErrorer interface { OnError(ctx context.Context, err error, query string, args ...interface{}) error } +// RowsWrapper is an optional interface for Hooks representing the hability of +// wrapper rows. +type RowsWrapper interface { + Rows(r driver.Rows) driver.Rows +} + func handlerErr(ctx context.Context, hooks Hooks, err error, query string, args ...interface{}) error { h, ok := hooks.(OnErrorer) if !ok { @@ -219,6 +225,10 @@ func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args return nil, err } + if w, ok := conn.hooks.(RowsWrapper); ok { + results = w.Rows(results) + } + return results, err }