diff --git a/lib/darwin_amd64/libsql_experimental.a b/lib/darwin_amd64/libsql_experimental.a index e7d2679..ee93ca6 100644 Binary files a/lib/darwin_amd64/libsql_experimental.a and b/lib/darwin_amd64/libsql_experimental.a differ diff --git a/lib/include/libsql.h b/lib/include/libsql.h index bad7a7c..8178980 100644 --- a/lib/include/libsql.h +++ b/lib/include/libsql.h @@ -3,6 +3,16 @@ #include +#define LIBSQL_INT 1 + +#define LIBSQL_FLOAT 2 + +#define LIBSQL_TEXT 3 + +#define LIBSQL_BLOB 4 + +#define LIBSQL_NULL 5 + typedef struct libsql_connection libsql_connection; typedef struct libsql_database libsql_database; @@ -17,6 +27,16 @@ typedef struct libsql_stmt libsql_stmt; typedef const libsql_database *libsql_database_t; +typedef struct { + const char *db_path; + const char *primary_url; + const char *auth_token; + char read_your_writes; + const char *encryption_key; + int sync_interval; + char with_webpki; +} libsql_config; + typedef const libsql_connection *libsql_connection_t; typedef const libsql_stmt *libsql_stmt_t; @@ -46,16 +66,36 @@ int libsql_open_sync(const char *db_path, libsql_database_t *out_db, const char **out_err_msg); +int libsql_open_sync_with_webpki(const char *db_path, + const char *primary_url, + const char *auth_token, + char read_your_writes, + const char *encryption_key, + libsql_database_t *out_db, + const char **out_err_msg); + +int libsql_open_sync_with_config(libsql_config config, libsql_database_t *out_db, const char **out_err_msg); + int libsql_open_ext(const char *url, libsql_database_t *out_db, const char **out_err_msg); int libsql_open_file(const char *url, libsql_database_t *out_db, const char **out_err_msg); int libsql_open_remote(const char *url, const char *auth_token, libsql_database_t *out_db, const char **out_err_msg); +int libsql_open_remote_with_webpki(const char *url, + const char *auth_token, + libsql_database_t *out_db, + const char **out_err_msg); + void libsql_close(libsql_database_t db); int libsql_connect(libsql_database_t db, libsql_connection_t *out_conn, const char **out_err_msg); +int libsql_load_extension(libsql_connection_t conn, + const char *path, + const char *entry_point, + const char **out_err_msg); + int libsql_reset(libsql_connection_t conn, const char **out_err_msg); void libsql_disconnect(libsql_connection_t conn); @@ -76,6 +116,8 @@ int libsql_query_stmt(libsql_stmt_t stmt, libsql_rows_t *out_rows, const char ** int libsql_execute_stmt(libsql_stmt_t stmt, const char **out_err_msg); +int libsql_reset_stmt(libsql_stmt_t stmt, const char **out_err_msg); + void libsql_free_stmt(libsql_stmt_t stmt); int libsql_query(libsql_connection_t conn, const char *sql, libsql_rows_t *out_rows, const char **out_err_msg); diff --git a/libsql.go b/libsql.go index ce37502..9ba21f4 100644 --- a/libsql.go +++ b/libsql.go @@ -24,6 +24,7 @@ import ( sqldriver "database/sql/driver" "errors" "fmt" + "golang.org/x/exp/slices" "io" "net/url" "regexp" @@ -40,11 +41,17 @@ func init() { sql.Register("libsql", driver{}) } +type extension struct { + path string + entryPoint string +} + type config struct { authToken *string readYourWrites *bool encryptionKey *string syncInterval *time.Duration + extensions []extension } type Option interface { @@ -103,6 +110,16 @@ func WithSyncInterval(interval time.Duration) Option { }) } +func WithExtension(path, entryPoint string) Option { + return option(func(o *config) error { + if slices.ContainsFunc(o.extensions, func(e extension) bool { return e.path == path }) { + return fmt.Errorf("extension %s already added", path) + } + o.extensions = append(o.extensions, extension{path, entryPoint}) + return nil + }) +} + func NewEmbeddedReplicaConnector(dbPath string, primaryUrl string, opts ...Option) (*Connector, error) { var config config errs := make([]error, 0, len(opts)) @@ -130,7 +147,7 @@ func NewEmbeddedReplicaConnector(dbPath string, primaryUrl string, opts ...Optio if config.syncInterval != nil { syncInterval = *config.syncInterval } - return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey, syncInterval) + return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey, syncInterval, config.extensions) } type driver struct{} @@ -191,7 +208,7 @@ func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) { return &Connector{nativeDbPtr: nativeDbPtr}, nil } -func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, readYourWrites bool, encryptionKey string, syncInterval time.Duration) (*Connector, error) { +func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, readYourWrites bool, encryptionKey string, syncInterval time.Duration, extensions []extension) (*Connector, error) { var closeCh chan struct{} var closeAckCh chan struct{} nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey) @@ -224,10 +241,11 @@ func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, readYour } }() } - return &Connector{nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil + return &Connector{extensions: extensions, nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil } type Connector struct { + extensions []extension nativeDbPtr C.libsql_database_t closeCh chan<- struct{} closeAckCh <-chan struct{} @@ -256,6 +274,26 @@ func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) { if err != nil { return nil, err } + for _, ext := range c.extensions { + err := func() error { + extPath := C.CString(ext.path) + defer C.free(unsafe.Pointer(extPath)) + var extEntryPoint *C.char = nil + if ext.entryPoint != "" { + extEntryPoint = C.CString(ext.entryPoint) + defer C.free(unsafe.Pointer(extEntryPoint)) + } + var errMsg *C.char + statusCode := C.libsql_load_extension(nativeConnPtr, extPath, extEntryPoint, &errMsg) + if statusCode != 0 { + return libsqlError(fmt.Sprintf("failed to load extension %s %s", ext.path, ext.entryPoint), statusCode, errMsg) + } + return nil + }() + if err != nil { + return nil, err + } + } return &conn{nativePtr: nativeConnPtr}, nil }