Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ type Client struct {
Transport *transport
}

// NewClient constructs a new client given a URL to a Postgrest instance.
func NewClient(rawURL, schema string, headers map[string]string) *Client {
// NewClientWithError constructs a new client given a URL to a Postgrest instance.
func NewClientWithError(rawURL, schema string, headers map[string]string) (*Client, error) {
// Create URL from rawURL
baseURL, err := url.Parse(rawURL)
if err != nil {
return &Client{ClientError: err}
return nil, err
}

t := transport{
Expand Down Expand Up @@ -55,30 +55,42 @@ func NewClient(rawURL, schema string, headers map[string]string) *Client {
c.Transport.header.Set(key, value)
}

return &c
return &c, nil
}

func (c *Client) Ping() bool {
req, err := http.NewRequest("GET", path.Join(c.Transport.baseURL.Path, ""), nil)
// NewClient constructs a new client given a URL to a Postgrest instance.
func NewClient(rawURL, schema string, headers map[string]string) *Client {
client, err := NewClientWithError(rawURL, schema, headers)
if err != nil {
c.ClientError = err
return &Client{ClientError: err}
}
return client
}

return false
func (c *Client) PingWithError() error {
req, err := http.NewRequest("GET", path.Join(c.Transport.baseURL.Path, ""), nil)
if err != nil {
return err
}

resp, err := c.session.Do(req)
if err != nil {
c.ClientError = err

return false
return err
}

if resp.Status != "200 OK" {
c.ClientError = errors.New("ping failed")
return errors.New("ping failed")
}

return nil
}

func (c *Client) Ping() bool {
err := c.PingWithError()
if err != nil {
c.ClientError = err
return false
}

return true
}

Expand Down Expand Up @@ -106,16 +118,15 @@ func (c *Client) From(table string) *QueryBuilder {
return &QueryBuilder{client: c, tableName: table, headers: map[string]string{}, params: map[string]string{}}
}

// Rpc executes a Postgres function (a.k.a., Remote Prodedure Call), given the
// RpcWithError executes a Postgres function (a.k.a., Remote Prodedure Call), given the
// function name and, optionally, a body, returning the result as a string.
func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
func (c *Client) RpcWithError(name string, count string, rpcBody interface{}) (string, error) {
// Get body if it exists
var byteBody []byte = nil
if rpcBody != nil {
jsonBody, err := json.Marshal(rpcBody)
if err != nil {
c.ClientError = err
return ""
return "", err
}
byteBody = jsonBody
}
Expand All @@ -124,8 +135,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
url := path.Join(c.Transport.baseURL.Path, "rpc", name)
req, err := http.NewRequest("POST", url, readerBody)
if err != nil {
c.ClientError = err
return ""
return "", err
}

if count != "" && (count == `exact` || count == `planned` || count == `estimated`) {
Expand All @@ -134,24 +144,32 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {

resp, err := c.session.Do(req)
if err != nil {
c.ClientError = err
return ""
return "", err
}

body, err := io.ReadAll(resp.Body)
if err != nil {
c.ClientError = err
return ""
return "", err
}

result := string(body)

err = resp.Body.Close()
if err != nil {
return "", err
}

return result, nil
}

// Rpc executes a Postgres function (a.k.a., Remote Prodedure Call), given the
// function name and, optionally, a body, returning the result as a string.
func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
result, err := c.RpcWithError(name, count, rpcBody)
if err != nil {
c.ClientError = err
return ""
}

return result
}

Expand Down
19 changes: 9 additions & 10 deletions execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ type ExecuteError struct {
Message string `json:"message"`
}

func executeHelper(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) ([]byte, countType, error) {
if client.ClientError != nil {
return nil, 0, client.ClientError
func executeHelper(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string, err error) ([]byte, countType, error) {
if err != nil {
return nil, 0, err
}

readerBody := bytes.NewBuffer(body)
baseUrl := path.Join(append([]string{client.Transport.baseURL.Path}, urlFragments...)...)
req, err := http.NewRequestWithContext(ctx, method, baseUrl, readerBody)
Expand Down Expand Up @@ -87,17 +86,17 @@ func executeHelper(ctx context.Context, client *Client, method string, body []by
return respBody, count, nil
}

func executeString(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) (string, countType, error) {
resp, count, err := executeHelper(ctx, client, method, body, urlFragments, headers, params)
func executeString(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string, err error) (string, countType, error) {
resp, count, err := executeHelper(ctx, client, method, body, urlFragments, headers, params, err)
return string(resp), count, err
}

func execute(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) ([]byte, countType, error) {
return executeHelper(ctx, client, method, body, urlFragments, headers, params)
func execute(ctx context.Context, client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string, err error) ([]byte, countType, error) {
return executeHelper(ctx, client, method, body, urlFragments, headers, params, err)
}

func executeTo(ctx context.Context, client *Client, method string, body []byte, to interface{}, urlFragments []string, headers map[string]string, params map[string]string) (countType, error) {
resp, count, err := executeHelper(ctx, client, method, body, urlFragments, headers, params)
func executeTo(ctx context.Context, client *Client, method string, body []byte, to interface{}, urlFragments []string, headers map[string]string, params map[string]string, err error) (countType, error) {
resp, count, err := executeHelper(ctx, client, method, body, urlFragments, headers, params, err)

if err != nil {
return count, err
Expand Down
33 changes: 19 additions & 14 deletions filterbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package postgrest
import (
"context"
"encoding/json"
"errors"
"fmt"
"regexp"
"slices"
"strconv"
"strings"
)
Expand All @@ -17,43 +19,44 @@ type FilterBuilder struct {
tableName string
headers map[string]string
params map[string]string
err error
}

// ExecuteString runs the PostgREST query, returning the result as a JSON
// string.
func (f *FilterBuilder) ExecuteString() (string, int64, error) {
return executeString(context.Background(), f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params)
return executeString(context.Background(), f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params, f.err)
}

// ExecuteStringWithContext runs the PostgREST query, returning the result as
// a JSON string.
func (f *FilterBuilder) ExecuteStringWithContext(ctx context.Context) (string, int64, error) {
return executeString(ctx, f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params)
return executeString(ctx, f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params, f.err)
}

// Execute runs the PostgREST query, returning the result as a byte slice.
func (f *FilterBuilder) Execute() ([]byte, int64, error) {
return execute(context.Background(), f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params)
return execute(context.Background(), f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params, f.err)
}

// ExecuteWithContext runs the PostgREST query with the given context,
// returning the result as a byte slice.
func (f *FilterBuilder) ExecuteWithContext(ctx context.Context) ([]byte, int64, error) {
return execute(ctx, f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params)
return execute(ctx, f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params, f.err)
}

// ExecuteTo runs the PostgREST query, encoding the result to the supplied
// interface. Note that the argument for the to parameter should always be a
// reference to a slice.
func (f *FilterBuilder) ExecuteTo(to interface{}) (countType, error) {
return executeTo(context.Background(), f.client, f.method, f.body, to, []string{f.tableName}, f.headers, f.params)
return executeTo(context.Background(), f.client, f.method, f.body, to, []string{f.tableName}, f.headers, f.params, f.err)
}

// ExecuteToWithContext runs the PostgREST query with the given context,
// encoding the result to the supplied interface. Note that the argument for
// the to parameter should always be a reference to a slice.
func (f *FilterBuilder) ExecuteToWithContext(ctx context.Context, to interface{}) (countType, error) {
return executeTo(ctx, f.client, f.method, f.body, to, []string{f.tableName}, f.headers, f.params)
return executeTo(ctx, f.client, f.method, f.body, to, []string{f.tableName}, f.headers, f.params, f.err)
}

var filterOperators = []string{"eq", "neq", "gt", "gte", "lt", "lte", "like", "ilike", "is", "in", "cs", "cd", "sl", "sr", "nxl", "nxr", "adj", "ov", "fts", "plfts", "phfts", "wfts"}
Expand All @@ -74,19 +77,16 @@ func (f *FilterBuilder) appendFilter(column, filterValue string) *FilterBuilder
}

func isOperator(value string) bool {
for _, operator := range filterOperators {
if value == operator {
return true
}
}
return false
return slices.Contains(filterOperators, value)
}

// Filter adds a filtering operator to the query. For a list of available
// operators, see: https://postgrest.org/en/stable/api.html#operators
func (f *FilterBuilder) Filter(column, operator, value string) *FilterBuilder {
if !isOperator(operator) {
f.client.ClientError = fmt.Errorf("invalid filter operator")
err := fmt.Errorf("invalid Filter operator: %s", operator)
f.client.ClientError = err
f.err = errors.Join(f.err, err)
return f
}
return f.appendFilter(column, fmt.Sprintf("%s.%s", operator, value))
Expand Down Expand Up @@ -200,6 +200,7 @@ func (f *FilterBuilder) ContainsObject(column string, value interface{}) *Filter
sum, err := json.Marshal(value)
if err != nil {
f.client.ClientError = err
f.err = errors.Join(f.err, fmt.Errorf("error marshaling value for ContainsObject: %w", err))
return f
}
return f.appendFilter(column, "cs."+string(sum))
Expand All @@ -208,7 +209,9 @@ func (f *FilterBuilder) ContainsObject(column string, value interface{}) *Filter
func (f *FilterBuilder) ContainedByObject(column string, value interface{}) *FilterBuilder {
sum, err := json.Marshal(value)
if err != nil {
err := fmt.Errorf("error marshaling value for ContainedByObject: %w", err)
f.client.ClientError = err
f.err = errors.Join(f.err, err)
return f
}
return f.appendFilter(column, "cd."+string(sum))
Expand Down Expand Up @@ -257,7 +260,9 @@ func (f *FilterBuilder) TextSearch(column, userQuery, config, tsType string) *Fi
} else if tsType == "" {
typePart = ""
} else {
f.client.ClientError = fmt.Errorf("invalid text search type")
err := fmt.Errorf("invalid text search type: %s", tsType)
f.client.ClientError = err
f.err = errors.Join(f.err, err)
return f
}
if config != "" {
Expand Down
Loading