@@ -18,18 +18,28 @@ import (
1818 "fmt"
1919 "os"
2020 "sync"
21+ "sync/atomic"
2122
2223 sftpClient "github.com/pkg/sftp"
2324 "golang.org/x/crypto/ssh"
2425)
2526
2627type Client struct {
27- sshClient * ssh.Client
28- sftpClient * sftpClient.Client
29- address string
30- config * ssh.ClientConfig
31- lock sync.RWMutex
32- rLock sync.Mutex
28+ sshClient * ssh.Client
29+ sftpClient * sftpClient.Client
30+ address string
31+ config * ssh.ClientConfig
32+ lock sync.RWMutex
33+ reconnectionCount atomic.Uint64
34+ }
35+
36+ type sftpError struct {
37+ err error
38+ reconnectionCount uint64
39+ }
40+
41+ func (s sftpError ) Error () string {
42+ return s .err .Error ()
3343}
3444
3545func newClient (address string , config * ssh.ClientConfig ) (* Client , error ) {
@@ -66,12 +76,15 @@ func (c *Client) Close() error {
6676func (c * Client ) list (path string ) ([]os.FileInfo , error ) {
6777 var fi []os.FileInfo
6878
69- fn := func () error {
79+ fn := func () * sftpError {
7080 var err error
7181 c .lock .RLock ()
7282 defer c .lock .RUnlock ()
7383 fi , err = c .sftpClient .ReadDir (path )
74- return err
84+ if err != nil {
85+ return & sftpError {err : err , reconnectionCount : c .reconnectionCount .Load ()}
86+ }
87+ return nil
7588 }
7689
7790 err := withReconnection (c , fn )
@@ -87,17 +100,17 @@ func (c *Client) create(path string) (*sftpClient.File, string, error) {
87100
88101 var file * sftpClient.File
89102
90- createFn := func () error {
103+ createFn := func () * sftpError {
91104 c .lock .RLock ()
92105 defer c .lock .RUnlock ()
93106 cErr := c .sftpClient .MkdirAll (dir )
94107 if cErr != nil {
95- return fmt .Errorf ("sftp binding error: error create dir %s: %w" , dir , cErr )
108+ return & sftpError { err : fmt .Errorf ("sftp binding error: error create dir %s: %w" , dir , cErr ), reconnectionCount : c . reconnectionCount . Load ()}
96109 }
97110
98111 file , cErr = c .sftpClient .Create (path )
99112 if cErr != nil {
100- return fmt .Errorf ("sftp binding error: error create file %s: %w" , path , cErr )
113+ return & sftpError { err : fmt .Errorf ("sftp binding error: error create file %s: %w" , path , cErr ), reconnectionCount : c . reconnectionCount . Load ()}
101114 }
102115
103116 return nil
@@ -114,12 +127,16 @@ func (c *Client) create(path string) (*sftpClient.File, string, error) {
114127func (c * Client ) get (path string ) (* sftpClient.File , error ) {
115128 var f * sftpClient.File
116129
117- fn := func () error {
130+ fn := func () * sftpError {
118131 var err error
119132 c .lock .RLock ()
133+
120134 defer c .lock .RUnlock ()
121135 f , err = c .sftpClient .Open (path )
122- return err
136+ if err != nil {
137+ return & sftpError {err : err , reconnectionCount : c .reconnectionCount .Load ()}
138+ }
139+ return nil
123140 }
124141
125142 err := withReconnection (c , fn )
@@ -131,12 +148,15 @@ func (c *Client) get(path string) (*sftpClient.File, error) {
131148}
132149
133150func (c * Client ) delete (path string ) error {
134- fn := func () error {
151+ fn := func () * sftpError {
135152 var err error
136153 c .lock .RLock ()
137154 defer c .lock .RUnlock ()
138155 err = c .sftpClient .Remove (path )
139- return err
156+ if err != nil {
157+ return & sftpError {err : err , reconnectionCount : c .reconnectionCount .Load ()}
158+ }
159+ return nil
140160 }
141161
142162 err := withReconnection (c , fn )
@@ -157,7 +177,7 @@ func (c *Client) ping() error {
157177 return nil
158178}
159179
160- func withReconnection (c * Client , fn func () error ) error {
180+ func withReconnection (c * Client , fn func () * sftpError ) error {
161181 err := fn ()
162182 if err == nil {
163183 return nil
@@ -167,7 +187,7 @@ func withReconnection(c *Client, fn func() error) error {
167187 return err
168188 }
169189
170- rErr := doReconnect (c )
190+ rErr := doReconnect (c , err . reconnectionCount )
171191 if rErr != nil {
172192 return errors .Join (err , rErr )
173193 }
@@ -180,33 +200,14 @@ func withReconnection(c *Client, fn func() error) error {
180200 return nil
181201}
182202
183- // 1) c.rLock (sync.Mutex) — reconnect serialization:
184- // - Ensures only one goroutine performs the reconnect sequence at a time
185- // (ping/check, dial SSH, create SFTP client), preventing a thundering herd
186- // of concurrent reconnect attempts.
187- // - Does NOT protect day-to-day client usage; it only coordinates who
188- // is allowed to perform a reconnect.
189- //
190- // 2) c.lock (sync.RWMutex) — data-plane safety and atomic swap:
191- // - Guards reads/writes of the active client handles (sshClient, sftpClient).
192- // - Regular operations hold RLock while using the clients.
193- // - Reconnect performs a short critical section with Lock to atomically swap
194- // the client pointers; old clients are closed after unlocking to keep the
195- // critical section small and avoid blocking readers.
196- //
197- // Why not a single RWMutex?
198- // - If we used only c.lock and held it while dialing/handshaking, all I/O would
199- // be blocked for the entire network operation, increasing latency and risk of
200- // contention. Worse, reconnects triggered while a caller holds RLock could
201- // deadlock or starve the writer.
202- // - Separating concerns allows: (a) fast, minimal swap under c.lock, and
203- // (b) serialized reconnect work under c.rLock without blocking readers.
204- func doReconnect (c * Client ) error {
205- c .rLock .Lock ()
206- defer c .rLock .Unlock ()
203+ func doReconnect (c * Client , reconnectionCount uint64 ) error {
204+ // No need to reconnect as it has been reconnected
205+ if reconnectionCount != c .reconnectionCount .Load () {
206+ return nil
207+ }
207208
208209 err := c .ping ()
209- if ! shouldReconnect ( err ) {
210+ if err == nil {
210211 return nil
211212 }
212213
@@ -222,11 +223,19 @@ func doReconnect(c *Client) error {
222223 }
223224
224225 // Swap under short lock; close old clients after unlocking.
226+ // Close new clients if not swapped
225227 c .lock .Lock ()
226- oldSftp := c .sftpClient
227- oldSSH := c .sshClient
228- c .sftpClient = newSftpClient
229- c .sshClient = sshClient
228+ var oldSftp * sftpClient.Client
229+ var oldSSH * ssh.Client
230+ if reconnectionCount == c .reconnectionCount .Load () {
231+ oldSftp = c .sftpClient
232+ oldSSH = c .sshClient
233+ c .sftpClient = newSftpClient
234+ c .sshClient = sshClient
235+ c .reconnectionCount .Add (1 )
236+ sshClient = nil
237+ newSftpClient = nil
238+ }
230239 c .lock .Unlock ()
231240
232241 if oldSftp != nil {
@@ -236,13 +245,21 @@ func doReconnect(c *Client) error {
236245 _ = oldSSH .Close ()
237246 }
238247
248+ if newSftpClient != nil {
249+ _ = newSftpClient .Close ()
250+ }
251+
252+ if sshClient != nil {
253+ _ = sshClient .Close ()
254+ }
255+
239256 return nil
240257}
241258
242259func newSSHClient (address string , config * ssh.ClientConfig ) (* ssh.Client , error ) {
243260 sshClient , err := ssh .Dial ("tcp" , address , config )
244261 if err != nil {
245- return nil , fmt .Errorf ("sftp binding error: error create ssh client : %w" , err )
262+ return nil , fmt .Errorf ("sftp binding error: error dialing ssh server : %w" , err )
246263 }
247264 return sshClient , nil
248265}
0 commit comments