diff --git a/core/app/service/auth.go b/core/app/service/auth.go index 077af7951920..9da432f85c08 100644 --- a/core/app/service/auth.go +++ b/core/app/service/auth.go @@ -17,6 +17,7 @@ import ( "github.com/1Panel-dev/1Panel/core/buserr" "github.com/1Panel-dev/1Panel/core/constant" "github.com/1Panel-dev/1Panel/core/global" + "github.com/1Panel-dev/1Panel/core/init/session/psession" "github.com/1Panel-dev/1Panel/core/utils/encrypt" "github.com/1Panel-dev/1Panel/core/utils/mfa" "github.com/1Panel-dev/1Panel/core/utils/passkey" @@ -143,15 +144,8 @@ func (u *AuthService) generateSession(c *gin.Context, name string) (*dto.UserLog return nil, err } - sessionUser, err := global.SESSION.Get(c) - if err != nil { - err := global.SESSION.Set(c, sessionUser, httpsSetting.Value == constant.StatusEnable, lifeTime) - if err != nil { - return nil, err - } - return &dto.UserLoginInfo{Name: name}, nil - } - if err := global.SESSION.Set(c, sessionUser, httpsSetting.Value == constant.StatusEnable, lifeTime); err != nil { + sessionUser := psession.SessionUser{Name: name} + if err := global.SESSION.SetFresh(c, sessionUser, httpsSetting.Value == constant.StatusEnable, lifeTime); err != nil { return nil, err } diff --git a/core/init/router/proxy.go b/core/init/router/proxy.go index 49fdd188311b..8ebc44189c2b 100644 --- a/core/init/router/proxy.go +++ b/core/init/router/proxy.go @@ -80,7 +80,10 @@ func checkSession(c *gin.Context) bool { if err != nil { return false } - _ = global.SESSION.Set(c, psession, ssl == constant.StatusEnable, lifeTime) + if _, err := global.SESSION.RefreshIfNeeded(c, psession, ssl == constant.StatusEnable, lifeTime); err != nil { + global.LOG.Warnf("proxy refresh session failed, path=%s, err=%v", c.Request.URL.Path, err) + return false + } return true } diff --git a/core/init/session/psession/psession.go b/core/init/session/psession/psession.go index 0e2b4479c945..987581de5fb4 100644 --- a/core/init/session/psession/psession.go +++ b/core/init/session/psession/psession.go @@ -1,20 +1,15 @@ package psession import ( - "encoding/json" + "crypto/rand" + "encoding/hex" "errors" - "log" - "os" + "sync" + "sync/atomic" "time" "github.com/1Panel-dev/1Panel/core/constant" "github.com/gin-gonic/gin" - "github.com/glebarez/sqlite" - "github.com/gorilla/securecookie" - "github.com/gorilla/sessions" - "github.com/wader/gormstore/v2" - "gorm.io/gorm" - "gorm.io/gorm/logger" ) type SessionUser struct { @@ -22,93 +17,232 @@ type SessionUser struct { Name string `json:"name"` } +type sessionItem struct { + User SessionUser + ExpiredAt time.Time +} + type PSession struct { - Store *gormstore.Store - db *gorm.DB -} - -func NewPSession(dbPath string) *PSession { - newLogger := logger.New( - log.New(os.Stdout, "\r\n", log.LstdFlags), - logger.Config{ - SlowThreshold: time.Second, - LogLevel: logger.Silent, - IgnoreRecordNotFoundError: true, - Colorful: false, - }, - ) - db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{ - DisableForeignKeyConstraintWhenMigrating: true, - Logger: newLogger, - }) - if err != nil { - panic(err) - } - sqlDB, dbError := db.DB() - if dbError != nil { - panic(dbError) - } - sqlDB.SetMaxOpenConns(4) - sqlDB.SetMaxIdleConns(1) - sqlDB.SetConnMaxIdleTime(15 * time.Minute) - sqlDB.SetConnMaxLifetime(time.Hour) - - store := gormstore.New(db, securecookie.GenerateRandomKey(32)) + mu sync.RWMutex + sessions map[string]sessionItem + cleanupCursor uint64 + lastFullCleanup time.Time +} + +func NewPSession(_ string) *PSession { return &PSession{ - Store: store, - db: db, + sessions: make(map[string]sessionItem), } } func (p *PSession) Get(c *gin.Context) (SessionUser, error) { var result SessionUser - session, err := p.Store.Get(c.Request, constant.SessionName) - if err != nil { - return result, err + + sessionID, err := c.Cookie(constant.SessionName) + if err != nil || sessionID == "" { + return result, errors.New("ErrSessionDataNotFound") } - data, ok := session.Values["user"] + + p.mu.RLock() + item, ok := p.sessions[sessionID] + p.mu.RUnlock() if !ok { return result, errors.New("ErrSessionDataNotFound") } - bytes, ok := data.([]byte) - if !ok { - return result, errors.New("ErrSessionDataFormat") + if !item.ExpiredAt.IsZero() && time.Now().After(item.ExpiredAt) { + p.mu.Lock() + delete(p.sessions, sessionID) + p.mu.Unlock() + return result, errors.New("ErrSessionDataNotFound") } - err = json.Unmarshal(bytes, &result) - return result, err + return item.User, nil } func (p *PSession) Set(c *gin.Context, user SessionUser, secure bool, ttlSeconds int) error { - session, err := p.Store.Get(c.Request, constant.SessionName) - if err != nil { - return err + return p.set(c, user, secure, ttlSeconds, false) +} + +func (p *PSession) SetFresh(c *gin.Context, user SessionUser, secure bool, ttlSeconds int) error { + return p.set(c, user, secure, ttlSeconds, true) +} + +func (p *PSession) set(c *gin.Context, user SessionUser, secure bool, ttlSeconds int, forceNew bool) error { + sessionID, err := c.Cookie(constant.SessionName) + if forceNew { + if err == nil && sessionID != "" { + p.mu.Lock() + delete(p.sessions, sessionID) + p.mu.Unlock() + } + sessionID = "" } - data, err := json.Marshal(user) - if err != nil { - return err + if err != nil || sessionID == "" { + sessionID, err = generateSessionID() + if err != nil { + return err + } } - session.Values["user"] = data - session.Options = &sessions.Options{ - Path: "/", - MaxAge: ttlSeconds, - HttpOnly: true, - Secure: secure, + + expiredAt := time.Now().Add(time.Duration(ttlSeconds) * time.Second) + p.mu.Lock() + p.sessions[sessionID] = sessionItem{ + User: user, + ExpiredAt: expiredAt, } - return p.Store.Save(c.Request, c.Writer, session) + p.mu.Unlock() + p.cleanupExpiredOnWrite() + + c.SetCookie(constant.SessionName, sessionID, ttlSeconds, "/", "", secure, true) + return nil } -func (p *PSession) Delete(c *gin.Context) error { - session, err := p.Store.Get(c.Request, constant.SessionName) - if err != nil { - return err +func (p *PSession) RefreshIfNeeded(c *gin.Context, user SessionUser, secure bool, ttlSeconds int) (bool, error) { + sessionID, err := c.Cookie(constant.SessionName) + if err != nil || sessionID == "" { + return false, p.Set(c, user, secure, ttlSeconds) + } + + now := time.Now() + window := refreshWindow(ttlSeconds) + + p.mu.RLock() + item, ok := p.sessions[sessionID] + p.mu.RUnlock() + if !ok { + return false, p.Set(c, user, secure, ttlSeconds) } + if !item.ExpiredAt.IsZero() && now.After(item.ExpiredAt) { + p.mu.Lock() + delete(p.sessions, sessionID) + p.mu.Unlock() + return false, errors.New("ErrSessionDataNotFound") + } + if item.ExpiredAt.Sub(now) > window { + return false, nil + } + return true, p.Set(c, user, secure, ttlSeconds) +} - session.Values = make(map[interface{}]interface{}) - session.Options.MaxAge = -1 - return p.Store.Save(c.Request, c.Writer, session) +func (p *PSession) Delete(c *gin.Context) error { + sessionID, err := c.Cookie(constant.SessionName) + if err == nil && sessionID != "" { + p.mu.Lock() + delete(p.sessions, sessionID) + p.mu.Unlock() + } + return nil } func (p *PSession) Clean() error { - p.db.Table("sessions").Where("1=1").Delete(nil) + p.mu.Lock() + p.sessions = make(map[string]sessionItem) + p.lastFullCleanup = time.Time{} + p.mu.Unlock() return nil } + +func generateSessionID() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func refreshWindow(ttlSeconds int) time.Duration { + if ttlSeconds <= 0 { + return 0 + } + windowSeconds := ttlSeconds / 10 + if windowSeconds < 60 { + windowSeconds = 60 + } + if windowSeconds > 300 { + windowSeconds = 300 + } + if windowSeconds >= ttlSeconds { + windowSeconds = ttlSeconds - 1 + } + if windowSeconds <= 0 { + windowSeconds = 1 + } + return time.Duration(windowSeconds) * time.Second +} + +func (p *PSession) cleanupExpiredOnWrite() { + const ( + sampleSize = 32 + fullCleanupThreshold = 1024 + fullCleanupMinInterval = time.Minute + ) + + now := time.Now() + + p.mu.RLock() + size := len(p.sessions) + lastFullCleanup := p.lastFullCleanup + p.mu.RUnlock() + + if size == 0 { + return + } + if size >= fullCleanupThreshold && now.Sub(lastFullCleanup) >= fullCleanupMinInterval { + p.cleanupExpiredAll(now) + return + } + p.cleanupExpiredSample(now, sampleSize) +} + +func (p *PSession) cleanupExpiredSample(now time.Time, limit int) { + if limit <= 0 { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + + total := len(p.sessions) + if total == 0 { + return + } + start := int(atomic.AddUint64(&p.cleanupCursor, uint64(limit)) % uint64(total)) + checked := 0 + + idx := 0 + for key, item := range p.sessions { + if idx < start { + idx++ + continue + } + if !item.ExpiredAt.IsZero() && now.After(item.ExpiredAt) { + delete(p.sessions, key) + } + checked++ + idx++ + if checked >= limit { + break + } + } + if checked < limit { + for key, item := range p.sessions { + if checked >= limit { + break + } + if !item.ExpiredAt.IsZero() && now.After(item.ExpiredAt) { + delete(p.sessions, key) + } + checked++ + } + } +} + +func (p *PSession) cleanupExpiredAll(now time.Time) { + p.mu.Lock() + for key, item := range p.sessions { + if !item.ExpiredAt.IsZero() && now.After(item.ExpiredAt) { + delete(p.sessions, key) + } + } + p.lastFullCleanup = now + p.mu.Unlock() +} diff --git a/core/init/session/session.go b/core/init/session/session.go index be040619d93c..a2e6e9f9c09c 100644 --- a/core/init/session/session.go +++ b/core/init/session/session.go @@ -1,13 +1,11 @@ package session import ( - "path" - "github.com/1Panel-dev/1Panel/core/global" "github.com/1Panel-dev/1Panel/core/init/session/psession" ) func Init() { - global.SESSION = psession.NewPSession(path.Join(global.CONF.Base.InstallDir, "1panel/db/session.db")) - global.LOG.Info("init session successfully") + global.SESSION = psession.NewPSession("") + global.LOG.Info("init in-memory session successfully") } diff --git a/core/middleware/session.go b/core/middleware/session.go index 47f642fa50b7..4edf612580f6 100644 --- a/core/middleware/session.go +++ b/core/middleware/session.go @@ -42,7 +42,16 @@ func SessionAuth() gin.HandlerFunc { global.LOG.Errorf("create operation record failed, err: %v", err) return } - _ = global.SESSION.Set(c, psession, ssl == constant.StatusEnable, lifeTime) + if _, err := global.SESSION.RefreshIfNeeded(c, psession, ssl == constant.StatusEnable, lifeTime); err != nil { + errItem := err.Error() + if errItem == "ErrSessionDataFormat" || errItem == "ErrSessionDataNotFound" { + helper.BadAuth(c, "ErrNotLogin", buserr.New(errItem)) + return + } + global.LOG.Warnf("refresh session failed, path=%s, err=%v", c.Request.URL.Path, err) + helper.BadAuth(c, "ErrNotLogin", err) + return + } c.Next() } }