Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdks/community/go/pkg/client/sse/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"
"time"

"github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -38,7 +39,7 @@ type Frame struct {

type StreamOptions struct {
Context context.Context
Payload interface{}
Payload types.RunAgentInput
Headers map[string]string
}

Expand Down
94 changes: 44 additions & 50 deletions sdks/community/go/pkg/client/sse/client_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,54 @@ import (
"testing"
"time"

"github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// testPayload returns a simple RunAgentInput for testing
func testPayload() types.RunAgentInput {
return types.RunAgentInput{
ThreadId: "test-thread",
RunId: "test-run",
}
}

func TestStream(t *testing.T) {
t.Run("successful stream", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
assert.Equal(t, "text/event-stream", r.Header.Get("Accept"))

w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)

flusher, ok := w.(http.Flusher)
require.True(t, ok)

fmt.Fprintf(w, "data: first message\n\n")
flusher.Flush()

fmt.Fprintf(w, "data: second message\n\n")
flusher.Flush()

fmt.Fprintf(w, "data: {\"type\":\"json\",\"value\":123}\n\n")
flusher.Flush()
}))
defer server.Close()

client := NewClient(Config{
Endpoint: server.URL,
BufferSize: 10,
})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

frames, errors, err := client.Stream(StreamOptions{
Context: ctx,
Payload: map[string]string{"test": "data"},
Payload: testPayload(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -106,10 +115,10 @@ func TestStream(t *testing.T) {

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

select {
case frame := <-frames:
assert.Equal(t, "line1\nline2\nline3", string(frame.Data))
Expand Down Expand Up @@ -170,13 +179,13 @@ func TestStream(t *testing.T) {

_, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)
})
}
})

t.Run("custom headers", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "custom-value", r.Header.Get("X-Custom-Header"))
Expand All @@ -195,15 +204,15 @@ func TestStream(t *testing.T) {

_, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
Headers: map[string]string{
"X-Custom-Header": "custom-value",
"X-Another-Header": "another-value",
},
})
require.NoError(t, err)
})

t.Run("error responses", func(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -250,16 +259,16 @@ func TestStream(t *testing.T) {
client := NewClient(Config{
Endpoint: server.URL,
})

_, _, err := client.Stream(StreamOptions{
Payload: struct{}{},
Payload: testPayload(),
})
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedErr)
})
}
})

t.Run("context cancellation", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand All @@ -283,13 +292,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()

frames, errors, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

messageCount := 0
for {
select {
Expand All @@ -309,32 +318,17 @@ func TestStream(t *testing.T) {
}
})

t.Run("invalid payload marshaling", func(t *testing.T) {
client := NewClient(Config{
Endpoint: "http://localhost",
})

// Create an unmarshalable payload
invalidPayload := make(chan int)

_, _, err := client.Stream(StreamOptions{
Payload: invalidPayload,
})
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to marshal payload")
})

t.Run("invalid endpoint", func(t *testing.T) {
client := NewClient(Config{
Endpoint: "http://[::1]:namedport", // Invalid URL
})

_, _, err := client.Stream(StreamOptions{
Payload: struct{}{},
Payload: testPayload(),
})
require.Error(t, err)
})

t.Run("concurrent reads", func(t *testing.T) {
messageCount := 50
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -358,13 +352,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

var wg sync.WaitGroup
received := make(map[string]bool)
mu := sync.Mutex{}
Expand Down Expand Up @@ -410,13 +404,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

frames, errors, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

// Should receive initial message
select {
case frame := <-frames:
Expand Down Expand Up @@ -463,13 +457,13 @@ func TestStream(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
require.NoError(t, err)

// Consume all frames
go func() {
for range frames {
Expand Down Expand Up @@ -691,12 +685,12 @@ func BenchmarkStream(b *testing.B) {

frames, _, err := client.Stream(StreamOptions{
Context: ctx,
Payload: struct{}{},
Payload: testPayload(),
})
if err != nil {
b.Fatal(err)
}

count := 0
for range frames {
count++
Expand Down
Loading
Loading