Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions pkg/transport/session/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ type Session interface {
Type() SessionType
CreatedAt() time.Time
UpdatedAt() time.Time
Touch()

// Data and metadata methods
GetData() interface{}
Expand Down Expand Up @@ -200,8 +199,8 @@ func (m *Manager) AddSession(session Session) error {
return m.storage.Store(ctx, session)
}

// Get retrieves a session by ID. Returns (session, true) if found,
// and also updates its UpdatedAt timestamp.
// Get retrieves a session by ID. Returns (session, true) if found.
// For LocalStorage, the storage backend updates the session's last-access timestamp on Load.
func (m *Manager) Get(id string) (Session, bool) {
ctx, cancel := context.WithTimeout(context.Background(), defaultOperationTimeout)
defer cancel()
Expand All @@ -210,8 +209,6 @@ func (m *Manager) Get(id string) (Session, bool) {
if err != nil {
return nil, false
}
// Touch the session to update its timestamp
sess.Touch()
return sess, true
}

Expand Down
46 changes: 25 additions & 21 deletions pkg/transport/session/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,53 +121,57 @@ func TestDeleteSession(t *testing.T) {
assert.False(t, ok, "deleted session should not be found")
}

func TestGetUpdatesTimestamp(t *testing.T) {
func TestGetPreventsEviction(t *testing.T) {
t.Parallel()
oldTime := time.Now().Add(-1 * time.Minute)
oldTime := time.Now().Add(-2 * time.Hour)
factory := &stubFactory{fixedTime: oldTime}
ttl := 1 * time.Hour

m := NewManager(time.Hour, factory.New)
m := NewManager(ttl, factory.New)
defer m.Stop()

require.NoError(t, m.AddWithID(uuidTouchme))
s1, ok := m.Get(uuidTouchme)

// LocalStorage.Store() stamps lastAccessNano = time.Now(), so the entry is
// always fresh after AddWithID. Backdate it so the session looks expired and
// would be evicted if Get() did not refresh the timestamp.
ls := m.storage.(*LocalStorage)
val, ok := ls.sessions.Load(uuidTouchme)
require.True(t, ok, "entry must exist in storage before backdating")
val.(*localEntry).lastAccessNano.Store(oldTime.UnixNano())

// Get() refreshes the storage-level last-access time by swapping in a new entry.
_, ok = m.Get(uuidTouchme)
require.True(t, ok)
t0 := s1.UpdatedAt()

time.Sleep(10 * time.Millisecond)
s2, ok2 := m.Get(uuidTouchme)
require.True(t, ok2)
t1 := s2.UpdatedAt()
// Cleanup with a cutoff of "now minus ttl" should NOT evict the session
// because Get() just refreshed its last-access timestamp.
require.NoError(t, m.cleanupExpiredOnce())

assert.True(t, t1.After(t0), "UpdatedAt should update on repeated Get()")
_, stillPresent := m.Get(uuidTouchme)
assert.True(t, stillPresent, "session should survive cleanup after a recent Get()")
}
func TestCleanupExpired_ManualTrigger(t *testing.T) {
t.Parallel()

// Stub factory: all sessions start with UpdatedAt = `now`
now := time.Now()
factory := &stubFactory{fixedTime: now}
factory := &stubFactory{fixedTime: time.Now()}
ttl := 50 * time.Millisecond

m := NewManager(ttl, factory.New)
defer m.Stop()

require.NoError(t, m.AddWithID(uuidOld))

// Retrieve and expire session manually
sess, ok := m.Get(uuidOld)
require.True(t, ok)
ps := sess.(*ProxySession)
ps.updated = now.Add(-ttl * 2)
// Wait for the session's last-access time to become older than the TTL.
time.Sleep(ttl * 2)

// Run cleanup manually
// Run cleanup — the stale session should be evicted.
m.cleanupExpiredOnce()

// Now it should be gone
_, okOld := m.Get(uuidOld)
assert.False(t, okOld, "expired session should have been cleaned")

// Add fresh session and assert it remains after cleanup
// A freshly-added session must survive the next cleanup run.
require.NoError(t, m.AddWithID(uuidNew))
m.cleanupExpiredOnce()
_, okNew := m.Get(uuidNew)
Expand Down
9 changes: 4 additions & 5 deletions pkg/transport/session/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ type Storage interface {
// Load retrieves a session by ID from the storage backend.
// Returns ErrSessionNotFound if the session doesn't exist.
//
// Implementations may refresh the backend's eviction TTL on every Load (e.g. Redis
// GETEX) to prevent active sessions from expiring between reads, because Manager.Get
// calls Touch on the returned object but does not call Store. This TTL refresh is a
// backend-level eviction concern and is distinct from the session's application-level
// UpdatedAt timestamp, which Load must NOT update.
// Implementations should refresh their backend's eviction TTL on every Load to
// prevent active sessions from expiring between reads. For Redis, this is done via
// GETEX. For LocalStorage, Load updates a storage-owned last-access timestamp so
// that DeleteExpired does not evict sessions that are actively being accessed.
Load(ctx context.Context, id string) (Session, error)

// Delete removes a session from the storage backend.
Expand Down
93 changes: 66 additions & 27 deletions pkg/transport/session/storage_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,33 @@ import (
"io"
"log/slog"
"sync"
"sync/atomic"
"time"
)

// localEntry wraps a session with a storage-owned last-access timestamp.
// All eviction decisions in LocalStorage are based on this timestamp, not on
// any field carried by the Session itself. This ensures every session type gets
// correct TTL behaviour regardless of its own implementation details.
type localEntry struct {
session Session
lastAccessNano atomic.Int64
}

func newLocalEntry(session Session) *localEntry {
e := &localEntry{session: session}
e.lastAccessNano.Store(time.Now().UnixNano())
return e
}

func (e *localEntry) lastAccess() time.Time {
return time.Unix(0, e.lastAccessNano.Load())
}

// LocalStorage implements the Storage interface using an in-memory sync.Map.
// This is the default storage backend for single-instance deployments.
type LocalStorage struct {
sessions sync.Map
sessions sync.Map // map[string]*localEntry
}

// NewLocalStorage creates a new local in-memory storage backend.
Expand All @@ -33,11 +53,13 @@ func (s *LocalStorage) Store(_ context.Context, session Session) error {
return fmt.Errorf("cannot store session with empty ID")
}

s.sessions.Store(session.ID(), session)
s.sessions.Store(session.ID(), newLocalEntry(session))
return nil
}

// Load retrieves a session from local storage.
// Load retrieves a session from local storage and refreshes its last-access timestamp.
// The timestamp update happens inside LocalStorage so that eviction is correct for
// all session types, not just those that implement a Touch() method.
func (s *LocalStorage) Load(_ context.Context, id string) (Session, error) {
if id == "" {
return nil, fmt.Errorf("cannot load session with empty ID")
Expand All @@ -48,12 +70,24 @@ func (s *LocalStorage) Load(_ context.Context, id string) (Session, error) {
return nil, ErrSessionNotFound
}

session, ok := val.(Session)
entry, ok := val.(*localEntry)
if !ok {
return nil, fmt.Errorf("invalid session type in storage")
}

return session, nil
// Refresh last-access time by swapping in a new entry pointer. This is
// intentional: if we mutated lastAccessNano in-place, DeleteExpired could
// still evict the session via CompareAndDelete (it holds the same pointer).
// Swapping the pointer makes CompareAndDelete fail for any DeleteExpired
// goroutine that snapshotted the old pointer, preventing eviction of active
// sessions under concurrent load.
newEntry := newLocalEntry(entry.session)
s.sessions.CompareAndSwap(id, entry, newEntry)
// If CAS fails, another goroutine already replaced this entry (e.g. a
// concurrent Store or Load). Either way the map holds a fresh pointer, so
// DeleteExpired will not evict it incorrectly.

return entry.session, nil
}

// Delete removes a session from local storage.
Expand All @@ -66,14 +100,14 @@ func (s *LocalStorage) Delete(_ context.Context, id string) error {
return nil
}

// DeleteExpired removes all sessions that haven't been updated since the given time.
// DeleteExpired removes all sessions whose last-access time is before the given cutoff.
func (s *LocalStorage) DeleteExpired(ctx context.Context, before time.Time) error {
var toDelete []struct {
id string
session Session
id string
entry *localEntry
}

// First pass: collect expired sessions
// First pass: collect expired entries
s.sessions.Range(func(key, val any) bool {
// Check for context cancellation
select {
Expand All @@ -82,20 +116,20 @@ func (s *LocalStorage) DeleteExpired(ctx context.Context, before time.Time) erro
default:
}

if session, ok := val.(Session); ok {
if session.UpdatedAt().Before(before) {
if entry, ok := val.(*localEntry); ok {
if entry.lastAccess().Before(before) {
if id, ok := key.(string); ok {
toDelete = append(toDelete, struct {
id string
session Session
}{id, session})
id string
entry *localEntry
}{id, entry})
}
}
}
return true
})

// Second pass: close and delete expired sessions
// Second pass: close and delete expired entries
for _, item := range toDelete {
// Check for context cancellation before processing each session
select {
Expand All @@ -105,24 +139,24 @@ func (s *LocalStorage) DeleteExpired(ctx context.Context, before time.Time) erro
}

// Re-check expiration and use CompareAndDelete to handle race conditions:
// - Session may have been touched via Manager.Get().Touch() and is no longer expired
// - Session may have been replaced via Store/UpsertSession with a new object
// Only proceed if the stored value is still the same session object and still expired
if item.session.UpdatedAt().Before(before) {
// CompareAndDelete ensures we only delete if the value hasn't been replaced
if deleted := s.sessions.CompareAndDelete(item.id, item.session); deleted {
// - Entry may have been touched via LocalStorage.Load and is no longer expired
// - Entry may have been replaced via Store/UpsertSession with a new object
// Only proceed if the stored value is still the same entry and still expired.
if item.entry.lastAccess().Before(before) {
// CompareAndDelete ensures we only delete if the entry hasn't been replaced
if deleted := s.sessions.CompareAndDelete(item.id, item.entry); deleted {
// Successfully deleted - now close if implements io.Closer
if closer, ok := item.session.(io.Closer); ok {
if closer, ok := item.entry.session.(io.Closer); ok {
if err := closer.Close(); err != nil {
slog.Warn("failed to close session during cleanup",
"session_id", item.id,
"error", err)
}
}
}
// If CompareAndDelete returned false, the session was already replaced/deleted - skip it
// If CompareAndDelete returned false, the entry was already replaced/deleted - skip it
}
// If re-check shows session is no longer expired (was touched), skip it
// If re-check shows entry is no longer expired (was touched via Load), skip it
}

return nil
Expand Down Expand Up @@ -154,8 +188,13 @@ func (s *LocalStorage) Count() int {
return count
}

// Range iterates over all sessions in storage.
// This is a helper method not part of the Storage interface.
// Range iterates over all sessions in storage, passing the session (not the
// internal wrapper) to f. This is a helper method not part of the Storage interface.
func (s *LocalStorage) Range(f func(key, value interface{}) bool) {
s.sessions.Range(f)
s.sessions.Range(func(key, val interface{}) bool {
if entry, ok := val.(*localEntry); ok {
return f(key, entry.session)
}
return f(key, val)
})
}
Loading
Loading