Skip to content

Commit 69d8f53

Browse files
committed
chore: Use only one lock
Spin up docker Signed-off-by: Javier Aliaga <javier@diagrid.io>
1 parent b203628 commit 69d8f53

File tree

6 files changed

+190
-69
lines changed

6 files changed

+190
-69
lines changed

bindings/sftp/client.go

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,28 @@ import (
1818
"fmt"
1919
"os"
2020
"sync"
21+
"sync/atomic"
2122

2223
sftpClient "github.com/pkg/sftp"
2324
"golang.org/x/crypto/ssh"
2425
)
2526

2627
type Client struct {
27-
sshClient *ssh.Client
28-
sftpClient *sftpClient.Client
29-
address string
30-
config *ssh.ClientConfig
31-
lock sync.RWMutex
32-
rLock sync.Mutex
28+
sshClient *ssh.Client
29+
sftpClient *sftpClient.Client
30+
address string
31+
config *ssh.ClientConfig
32+
lock sync.RWMutex
33+
reconnectionCount atomic.Uint64
34+
}
35+
36+
type sftpError struct {
37+
err error
38+
reconnectionCount uint64
39+
}
40+
41+
func (s sftpError) Error() string {
42+
return s.err.Error()
3343
}
3444

3545
func newClient(address string, config *ssh.ClientConfig) (*Client, error) {
@@ -66,12 +76,15 @@ func (c *Client) Close() error {
6676
func (c *Client) list(path string) ([]os.FileInfo, error) {
6777
var fi []os.FileInfo
6878

69-
fn := func() error {
79+
fn := func() *sftpError {
7080
var err error
7181
c.lock.RLock()
7282
defer c.lock.RUnlock()
7383
fi, err = c.sftpClient.ReadDir(path)
74-
return err
84+
if err != nil {
85+
return &sftpError{err: err, reconnectionCount: c.reconnectionCount.Load()}
86+
}
87+
return nil
7588
}
7689

7790
err := withReconnection(c, fn)
@@ -87,17 +100,17 @@ func (c *Client) create(path string) (*sftpClient.File, string, error) {
87100

88101
var file *sftpClient.File
89102

90-
createFn := func() error {
103+
createFn := func() *sftpError {
91104
c.lock.RLock()
92105
defer c.lock.RUnlock()
93106
cErr := c.sftpClient.MkdirAll(dir)
94107
if cErr != nil {
95-
return fmt.Errorf("sftp binding error: error create dir %s: %w", dir, cErr)
108+
return &sftpError{err: fmt.Errorf("sftp binding error: error create dir %s: %w", dir, cErr), reconnectionCount: c.reconnectionCount.Load()}
96109
}
97110

98111
file, cErr = c.sftpClient.Create(path)
99112
if cErr != nil {
100-
return fmt.Errorf("sftp binding error: error create file %s: %w", path, cErr)
113+
return &sftpError{err: fmt.Errorf("sftp binding error: error create file %s: %w", path, cErr), reconnectionCount: c.reconnectionCount.Load()}
101114
}
102115

103116
return nil
@@ -114,12 +127,16 @@ func (c *Client) create(path string) (*sftpClient.File, string, error) {
114127
func (c *Client) get(path string) (*sftpClient.File, error) {
115128
var f *sftpClient.File
116129

117-
fn := func() error {
130+
fn := func() *sftpError {
118131
var err error
119132
c.lock.RLock()
133+
120134
defer c.lock.RUnlock()
121135
f, err = c.sftpClient.Open(path)
122-
return err
136+
if err != nil {
137+
return &sftpError{err: err, reconnectionCount: c.reconnectionCount.Load()}
138+
}
139+
return nil
123140
}
124141

125142
err := withReconnection(c, fn)
@@ -131,12 +148,15 @@ func (c *Client) get(path string) (*sftpClient.File, error) {
131148
}
132149

133150
func (c *Client) delete(path string) error {
134-
fn := func() error {
151+
fn := func() *sftpError {
135152
var err error
136153
c.lock.RLock()
137154
defer c.lock.RUnlock()
138155
err = c.sftpClient.Remove(path)
139-
return err
156+
if err != nil {
157+
return &sftpError{err: err, reconnectionCount: c.reconnectionCount.Load()}
158+
}
159+
return nil
140160
}
141161

142162
err := withReconnection(c, fn)
@@ -157,7 +177,7 @@ func (c *Client) ping() error {
157177
return nil
158178
}
159179

160-
func withReconnection(c *Client, fn func() error) error {
180+
func withReconnection(c *Client, fn func() *sftpError) error {
161181
err := fn()
162182
if err == nil {
163183
return nil
@@ -167,7 +187,7 @@ func withReconnection(c *Client, fn func() error) error {
167187
return err
168188
}
169189

170-
rErr := doReconnect(c)
190+
rErr := doReconnect(c, err.reconnectionCount)
171191
if rErr != nil {
172192
return errors.Join(err, rErr)
173193
}
@@ -180,33 +200,9 @@ func withReconnection(c *Client, fn func() error) error {
180200
return nil
181201
}
182202

183-
// 1) c.rLock (sync.Mutex) — reconnect serialization:
184-
// - Ensures only one goroutine performs the reconnect sequence at a time
185-
// (ping/check, dial SSH, create SFTP client), preventing a thundering herd
186-
// of concurrent reconnect attempts.
187-
// - Does NOT protect day-to-day client usage; it only coordinates who
188-
// is allowed to perform a reconnect.
189-
//
190-
// 2) c.lock (sync.RWMutex) — data-plane safety and atomic swap:
191-
// - Guards reads/writes of the active client handles (sshClient, sftpClient).
192-
// - Regular operations hold RLock while using the clients.
193-
// - Reconnect performs a short critical section with Lock to atomically swap
194-
// the client pointers; old clients are closed after unlocking to keep the
195-
// critical section small and avoid blocking readers.
196-
//
197-
// Why not a single RWMutex?
198-
// - If we used only c.lock and held it while dialing/handshaking, all I/O would
199-
// be blocked for the entire network operation, increasing latency and risk of
200-
// contention. Worse, reconnects triggered while a caller holds RLock could
201-
// deadlock or starve the writer.
202-
// - Separating concerns allows: (a) fast, minimal swap under c.lock, and
203-
// (b) serialized reconnect work under c.rLock without blocking readers.
204-
func doReconnect(c *Client) error {
205-
c.rLock.Lock()
206-
defer c.rLock.Unlock()
207-
208-
err := c.ping()
209-
if !shouldReconnect(err) {
203+
func doReconnect(c *Client, reconnectionCount uint64) error {
204+
// No need to reconnect as it has been reconnected
205+
if reconnectionCount != c.reconnectionCount.Load() {
210206
return nil
211207
}
212208

@@ -223,10 +219,17 @@ func doReconnect(c *Client) error {
223219

224220
// Swap under short lock; close old clients after unlocking.
225221
c.lock.Lock()
226-
oldSftp := c.sftpClient
227-
oldSSH := c.sshClient
228-
c.sftpClient = newSftpClient
229-
c.sshClient = sshClient
222+
var oldSftp *sftpClient.Client
223+
var oldSSH *ssh.Client
224+
if reconnectionCount == c.reconnectionCount.Load() {
225+
oldSftp = c.sftpClient
226+
oldSSH = c.sshClient
227+
c.sftpClient = newSftpClient
228+
c.sshClient = sshClient
229+
c.reconnectionCount.Add(1)
230+
sshClient = nil
231+
newSftpClient = nil
232+
}
230233
c.lock.Unlock()
231234

232235
if oldSftp != nil {
@@ -236,6 +239,14 @@ func doReconnect(c *Client) error {
236239
_ = oldSSH.Close()
237240
}
238241

242+
if newSftpClient != nil {
243+
_ = newSftpClient.Close()
244+
}
245+
246+
if sshClient != nil {
247+
_ = sshClient.Close()
248+
}
249+
239250
return nil
240251
}
241252

bindings/sftp/docker-compose.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
services:
2+
sftp:
3+
image:
4+
atmoz/sftp
5+
environment:
6+
- SFTP_USERS=foo:pass:1001:1001:upload
7+
volumes:
8+
- ./upload:/home/foo/upload
9+
ports:
10+
- "2222:22"
11+

bindings/sftp/sftp_integration_test.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build integration_test
2+
13
/*
24
Copyright 2025 The Dapr Authors
35
Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,37 +19,51 @@ import (
1719
"encoding/json"
1820
"os"
1921
"testing"
22+
"time"
2023

2124
"github.com/stretchr/testify/assert"
2225
"github.com/stretchr/testify/require"
2326

2427
"github.com/dapr/components-contrib/bindings"
28+
"github.com/dapr/components-contrib/tests/certification/flow"
29+
"github.com/dapr/components-contrib/tests/certification/flow/dockercompose"
2530
sftp "github.com/dapr/components-contrib/tests/utils/sftpproxy"
2631
)
2732

28-
const ProxySftp = "0.0.0.0:2223"
29-
30-
var connectionStringEnvKey = "DAPR_TEST_SFTP_CONNSTRING"
33+
const (
34+
ProxySftp = "0.0.0.0:2223"
35+
ConnectionString = "0.0.0.0:2222"
36+
)
3137

32-
// Run docker from the file location as the upload folder is relative to the test
33-
// cd proxy
34-
// docker run --name sftp -v ./upload:/home/foo/upload -p 2222:22 -d atmoz/sftp foo:pass:1001
35-
// export DAPR_TEST_SFTP_CONNSTRING=localhost:2222
3638
func TestIntegrationCases(t *testing.T) {
37-
connectionString := os.Getenv(connectionStringEnvKey)
38-
if connectionString == "" {
39-
t.Skipf("sftp binding integration skipped. To enable this test, define the connection string using environment variable '%[1]s' (example 'export %[1]s=\"localhost:2222\")'", connectionStringEnvKey)
40-
}
41-
39+
setupSftp(t)
40+
//defer cleanUp()
41+
time.Sleep(1 * time.Second)
4242
t.Run("List operation", testListOperation)
4343
t.Run("Create operation", testCreateOperation)
4444
t.Run("Reconnections", testReconnect)
4545
}
4646

47+
func setupSftp(t *testing.T) func() {
48+
dc := dockercompose.New("sftp", "docker-compose.yaml")
49+
ctx := flow.Context{
50+
T: t,
51+
Context: t.Context(),
52+
Flow: nil,
53+
}
54+
err := dc.Up(ctx)
55+
56+
if err != nil {
57+
t.Fatal(err)
58+
}
59+
60+
return func() { dc.Down(ctx) }
61+
}
62+
4763
func testListOperation(t *testing.T) {
4864
proxy := &sftp.Proxy{
4965
ListenAddr: ProxySftp,
50-
UpstreamAddr: os.Getenv(connectionStringEnvKey),
66+
UpstreamAddr: ConnectionString,
5167
}
5268

5369
defer proxy.Close()
@@ -82,7 +98,7 @@ func testListOperation(t *testing.T) {
8298
func testCreateOperation(t *testing.T) {
8399
proxy := &sftp.Proxy{
84100
ListenAddr: ProxySftp,
85-
UpstreamAddr: os.Getenv(connectionStringEnvKey),
101+
UpstreamAddr: ConnectionString,
86102
}
87103
defer proxy.Close()
88104
go proxy.ListenAndServe()
@@ -114,7 +130,7 @@ func testCreateOperation(t *testing.T) {
114130
require.NoError(t, err)
115131
assert.NotNil(t, r.Data)
116132

117-
file, err := os.Stat("./proxy/upload/test.txt")
133+
file, err := os.Stat("./upload/test.txt")
118134
require.NoError(t, err)
119135
assert.Equal(t, "test.txt", file.Name())
120136
assert.EqualValues(t, 1, proxy.ReconnectionCount.Load())
@@ -123,7 +139,7 @@ func testCreateOperation(t *testing.T) {
123139
func testReconnect(t *testing.T) {
124140
proxy := &sftp.Proxy{
125141
ListenAddr: ProxySftp,
126-
UpstreamAddr: os.Getenv(connectionStringEnvKey),
142+
UpstreamAddr: ConnectionString,
127143
}
128144
defer proxy.Close()
129145
go proxy.ListenAndServe()

0 commit comments

Comments
 (0)