Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion server/cmd/api/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
"sync"
"time"

"github.com/onkernel/kernel-images/server/lib/cdpmonitor"
"github.com/onkernel/kernel-images/server/lib/devtoolsproxy"
"github.com/onkernel/kernel-images/server/lib/events"
"github.com/onkernel/kernel-images/server/lib/logger"
"github.com/onkernel/kernel-images/server/lib/nekoclient"
oapi "github.com/onkernel/kernel-images/server/lib/oapi"
Expand Down Expand Up @@ -68,11 +70,24 @@ type ApiService struct {
// xvfbResizeMu serializes background Xvfb restarts to prevent races
// when multiple CDP fast-path resizes fire in quick succession.
xvfbResizeMu sync.Mutex

// CDP event capture pipeline and monitor. monitorMu serializes capture
// start/stop requests and shutdown.
captureSession *events.CaptureSession
cdpMonitor *cdpmonitor.Monitor
monitorMu sync.Mutex
}

var _ oapi.StrictServerInterface = (*ApiService)(nil)

func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFactory, upstreamMgr *devtoolsproxy.UpstreamManager, stz scaletozero.Controller, nekoAuthClient *nekoclient.AuthClient) (*ApiService, error) {
func New(
recordManager recorder.RecordManager,
factory recorder.FFmpegRecorderFactory,
upstreamMgr *devtoolsproxy.UpstreamManager,
stz scaletozero.Controller,
nekoAuthClient *nekoclient.AuthClient,
captureSession *events.CaptureSession,
displayNum int,
) (*ApiService, error) {
switch {
case recordManager == nil:
return nil, fmt.Errorf("recordManager cannot be nil")
Expand All @@ -82,8 +97,12 @@ func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFa
return nil, fmt.Errorf("upstreamMgr cannot be nil")
case nekoAuthClient == nil:
return nil, fmt.Errorf("nekoAuthClient cannot be nil")
case captureSession == nil:
return nil, fmt.Errorf("captureSession cannot be nil")
}

mon := cdpmonitor.New(upstreamMgr, captureSession.Publish, displayNum)

return &ApiService{
recordManager: recordManager,
factory: factory,
Expand All @@ -94,6 +113,8 @@ func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFa
stz: stz,
nekoAuthClient: nekoAuthClient,
policy: &policy.Policy{},
captureSession: captureSession,
cdpMonitor: mon,
}, nil
}

Expand Down Expand Up @@ -313,5 +334,9 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ
}

// Shutdown tears down the event-capture side of the service and then stops
// all recordings.
//
// While holding monitorMu it stops the CDP monitor and closes the capture
// session; the error returned by captureSession.Close is discarded. The lock
// is released before recordManager.StopAll(ctx) runs, and StopAll's error is
// the value returned to the caller.
func (s *ApiService) Shutdown(ctx context.Context) error {
	// Scope the critical section to a closure so the unlock is tied to it.
	func() {
		s.monitorMu.Lock()
		defer s.monitorMu.Unlock()

		s.cdpMonitor.Stop()
		// NOTE(review): the Close error is intentionally dropped here, as in
		// the original; consider surfacing it if shutdown callers care.
		_ = s.captureSession.Close()
	}()

	return s.recordManager.StopAll(ctx)
}
29 changes: 18 additions & 11 deletions server/cmd/api/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"log/slog"

"github.com/onkernel/kernel-images/server/lib/devtoolsproxy"
"github.com/onkernel/kernel-images/server/lib/events"
"github.com/onkernel/kernel-images/server/lib/nekoclient"
oapi "github.com/onkernel/kernel-images/server/lib/oapi"
"github.com/onkernel/kernel-images/server/lib/recorder"
Expand All @@ -25,7 +26,7 @@ func TestApiService_StartRecording(t *testing.T) {

t.Run("success", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)

resp, err := svc.StartRecording(ctx, oapi.StartRecordingRequestObject{})
Expand All @@ -39,7 +40,7 @@ func TestApiService_StartRecording(t *testing.T) {

t.Run("already recording", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)

// First start should succeed
Expand All @@ -54,7 +55,7 @@ func TestApiService_StartRecording(t *testing.T) {

t.Run("custom ids don't collide", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)

for i := 0; i < 5; i++ {
Expand Down Expand Up @@ -87,7 +88,7 @@ func TestApiService_StopRecording(t *testing.T) {

t.Run("no active recording", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)

resp, err := svc.StopRecording(ctx, oapi.StopRecordingRequestObject{})
Expand All @@ -100,7 +101,7 @@ func TestApiService_StopRecording(t *testing.T) {
rec := &mockRecorder{id: "default", isRecordingFlag: true}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)
resp, err := svc.StopRecording(ctx, oapi.StopRecordingRequestObject{})
require.NoError(t, err)
Expand All @@ -115,7 +116,7 @@ func TestApiService_StopRecording(t *testing.T) {

force := true
req := oapi.StopRecordingRequestObject{Body: &oapi.StopRecordingJSONRequestBody{ForceStop: &force}}
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)
resp, err := svc.StopRecording(ctx, req)
require.NoError(t, err)
Expand All @@ -129,7 +130,7 @@ func TestApiService_DownloadRecording(t *testing.T) {

t.Run("not found", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)
resp, err := svc.DownloadRecording(ctx, oapi.DownloadRecordingRequestObject{})
require.NoError(t, err)
Expand All @@ -149,7 +150,7 @@ func TestApiService_DownloadRecording(t *testing.T) {
rec := &mockRecorder{id: "default", isRecordingFlag: true, recordingData: randomBytes(minRecordingSizeInBytes - 1)}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)
// will return a 202 when the recording is too small
resp, err := svc.DownloadRecording(ctx, oapi.DownloadRecordingRequestObject{})
Expand Down Expand Up @@ -179,7 +180,7 @@ func TestApiService_DownloadRecording(t *testing.T) {
rec := &mockRecorder{id: "default", recordingData: data}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)
resp, err := svc.DownloadRecording(ctx, oapi.DownloadRecordingRequestObject{})
require.NoError(t, err)
Expand All @@ -199,7 +200,7 @@ func TestApiService_Shutdown(t *testing.T) {
rec := &mockRecorder{id: "default", isRecordingFlag: true}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)

require.NoError(t, svc.Shutdown(ctx))
Expand Down Expand Up @@ -303,10 +304,16 @@ func newMockNekoClient(t *testing.T) *nekoclient.AuthClient {
return client
}

// newCaptureSession builds a CaptureSession for tests from
// events.NewRingBuffer(64) and a FileWriter created with os.TempDir().
// No capture session ID is set; callers that need one must call Start.
func newCaptureSession() *events.CaptureSession {
	return events.NewCaptureSession(
		events.NewRingBuffer(64),
		events.NewFileWriter(os.TempDir()),
	)
}

func TestApiService_PatchChromiumFlags(t *testing.T) {
ctx := context.Background()
mgr := recorder.NewFFmpegManager()
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, newMockFactory(), newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)

// Test with valid flags
Expand Down
2 changes: 1 addition & 1 deletion server/cmd/api/api/display_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func testFFmpegFactory(t *testing.T, tempDir string) recorder.FFmpegRecorderFact

func newTestServiceWithFactory(t *testing.T, mgr recorder.RecordManager, factory recorder.FFmpegRecorderFactory) *ApiService {
t.Helper()
svc, err := New(mgr, factory, newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t))
svc, err := New(mgr, factory, newTestUpstreamManager(), scaletozero.NewNoopController(), newMockNekoClient(t), newCaptureSession(), 0)
require.NoError(t, err)
return svc
}
Expand Down
35 changes: 35 additions & 0 deletions server/cmd/api/api/events.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package api

import (
"context"
"net/http"

"github.com/google/uuid"
"github.com/onkernel/kernel-images/server/lib/logger"
)

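// StartCapture handles POST /events/start.
// Generates a new capture session ID, seeds the pipeline, then starts the
// CDP monitor. If already running, the monitor is stopped and
// restarted with a fresh session ID.
//
// NOTE(review): the new session ID is stored before the monitor is
// (re)started, so events published by a still-running monitor in that window
// are stamped with the new ID.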
func (s *ApiService) StartCapture(w http.ResponseWriter, r *http.Request) {
s.monitorMu.Lock()
defer s.monitorMu.Unlock()

s.captureSession.Start(uuid.New().String())

if err := s.cdpMonitor.Start(context.Background()); err != nil {
logger.FromContext(r.Context()).Error("failed to start CDP monitor", "err", err)
http.Error(w, "failed to start capture", http.StatusInternalServerError)
return
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Session ID changed before monitor restart causes misattribution

Medium Severity

captureSession.Start(newID) is called before cdpMonitor.Start(), which atomically changes the session ID while the old monitor may still be publishing events. Since captureSessionID is an atomic.Pointer readable without holding monitorMu, any events published by the old monitor's goroutines during the transition window will be incorrectly tagged with the new session ID. The session ID change needs to happen after the old monitor is fully stopped, not before.

Fix in Cursor Fix in Web

w.WriteHeader(http.StatusOK)
}

// StopCapture handles POST /events/stop.
// It takes monitorMu, stops the CDP monitor, and writes a 200 OK status. The
// request is not inspected, and the capture session is not closed here.
func (s *ApiService) StopCapture(w http.ResponseWriter, r *http.Request) {
	mu := &s.monitorMu
	mu.Lock()
	defer mu.Unlock()

	s.cdpMonitor.Stop()
	w.WriteHeader(http.StatusOK)
}
12 changes: 12 additions & 0 deletions server/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/onkernel/kernel-images/server/cmd/config"
"github.com/onkernel/kernel-images/server/lib/chromedriverproxy"
"github.com/onkernel/kernel-images/server/lib/devtoolsproxy"
"github.com/onkernel/kernel-images/server/lib/events"
"github.com/onkernel/kernel-images/server/lib/logger"
"github.com/onkernel/kernel-images/server/lib/nekoclient"
oapi "github.com/onkernel/kernel-images/server/lib/oapi"
Expand Down Expand Up @@ -90,12 +91,19 @@ func main() {
os.Exit(1)
}

// Construct events pipeline
eventsRing := events.NewRingBuffer(1024)
eventsFileWriter := events.NewFileWriter("/var/log")
captureSession := events.NewCaptureSession(eventsRing, eventsFileWriter)

apiService, err := api.New(
recorder.NewFFmpegManager(),
recorder.NewFFmpegRecorderFactory(config.PathToFFmpeg, defaultParams, stz),
upstreamMgr,
stz,
nekoAuthClient,
captureSession,
config.DisplayNum,
)
if err != nil {
slogger.Error("failed to create api service", "err", err)
Expand All @@ -120,6 +128,10 @@ func main() {
w.Header().Set("Content-Type", "application/json")
w.Write(jsonData)
})
// capture events
r.Post("/events/start", apiService.StartCapture)
r.Post("/events/stop", apiService.StopCapture)

// PTY attach endpoint (WebSocket) - not part of OpenAPI spec
// Uses WebSocket for bidirectional streaming, which works well through proxies.
r.Get("/process/{process_id}/attach", func(w http.ResponseWriter, r *http.Request) {
Expand Down
41 changes: 41 additions & 0 deletions server/lib/cdpmonitor/monitor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package cdpmonitor

import (
"context"
"sync/atomic"

"github.com/onkernel/kernel-images/server/lib/events"
)

// UpstreamProvider abstracts *devtoolsproxy.UpstreamManager for testability.
// New accepts this interface rather than the concrete manager type.
type UpstreamProvider interface {
// Current returns a string.
// NOTE(review): presumably the current upstream DevTools URL — confirm
// against UpstreamManager.
Current() string
// Subscribe returns a receive-only string channel and a func().
// NOTE(review): presumably a stream of upstream changes plus an
// unsubscribe/cancel callback — confirm against UpstreamManager.
Subscribe() (<-chan string, func())
}

// PublishFunc publishes an Event to the pipeline.
// It takes the event by value and returns nothing, so implementations have no
// way to report a failure to the caller.
type PublishFunc func(ev events.Event)

// Monitor manages a CDP WebSocket connection with auto-attach session fan-out.
// Single-use per capture session: call Start to begin, Stop to tear down.
//
// NOTE(review): Start's own comment says it restarts when already running,
// which reads differently from "single-use" above — confirm the intended
// lifecycle.
type Monitor struct {
// running is the flag reported by IsRunning.
running atomic.Bool
}

// New creates a Monitor. displayNum is the X display for ffmpeg screenshots.
//
// In this implementation all three arguments (upstream provider, publish
// function, display number) are bound to blank identifiers and discarded; the
// returned Monitor is zero-valued.
func New(_ UpstreamProvider, _ PublishFunc, _ int) *Monitor {
return &Monitor{}
}

// IsRunning reports whether the monitor is actively capturing.
// It returns the current value of the atomic running flag. No code in this
// file stores to that flag, so as written here it stays false.
func (m *Monitor) IsRunning() bool {
return m.running.Load()
}

// Start begins CDP capture. Restarts if already running.
//
// The current body is a placeholder: it ignores its context argument, changes
// no state (the running flag is not set), and always returns nil.
func (m *Monitor) Start(_ context.Context) error {
return nil
}

// Stop tears down the monitor. Safe to call multiple times.
// The current body is empty, so calling it has no effect.
func (m *Monitor) Stop() {}
29 changes: 19 additions & 10 deletions server/lib/events/capturesession.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,32 @@ package events
import (
"log/slog"
"sync"
"sync/atomic"
"time"
)

// CaptureSession is a single-use write path that wraps events in envelopes and
// fans them out to a FileWriter (durable) and RingBuffer (in-memory). Publish
// concurrently; Close flushes the FileWriter.
// fans them out to a FileWriter (durable) and RingBuffer (in-memory). Call Start
// once with a capture session ID, then Publish concurrently. Close flushes the
// FileWriter; there is no restart or terminal event.
type CaptureSession struct {
mu sync.Mutex
ring *RingBuffer
files *FileWriter
seq uint64
captureSessionID string
seq atomic.Uint64
captureSessionID atomic.Pointer[string]
}

func NewCaptureSession(captureSessionID string, ring *RingBuffer, files *FileWriter) *CaptureSession {
return &CaptureSession{ring: ring, files: files, captureSessionID: captureSessionID}
func NewCaptureSession(ring *RingBuffer, files *FileWriter) *CaptureSession {
s := &CaptureSession{ring: ring, files: files}
empty := ""
s.captureSessionID.Store(&empty)
return s
}

// Start records captureSessionID as the session's current ID. Publish loads
// this value when it builds each envelope, so envelopes created after Start
// returns carry the new ID.
func (s *CaptureSession) Start(captureSessionID string) {
	id := captureSessionID
	s.captureSessionID.Store(&id)
}

// Publish wraps ev in an Envelope, truncates if needed, then writes to
Expand All @@ -28,16 +38,15 @@ func (s *CaptureSession) Publish(ev Event) {
defer s.mu.Unlock()

if ev.Ts == 0 {
ev.Ts = time.Now().UnixMicro()
ev.Ts = time.Now().UnixMilli()
}
if ev.DetailLevel == "" {
ev.DetailLevel = DetailStandard
}

s.seq++
env := Envelope{
CaptureSessionID: s.captureSessionID,
Seq: s.seq,
CaptureSessionID: *s.captureSessionID.Load(),
Seq: s.seq.Add(1),
Event: ev,
}
env, data := truncateIfNeeded(env)
Expand Down
Loading
Loading