Skip to content

Commit c3594d7

Browse files
committed
chore: Use only one lock
Spin up docker Signed-off-by: Javier Aliaga <javier@diagrid.io>
1 parent b203628 commit c3594d7

File tree

6 files changed

+254
-68
lines changed

6 files changed

+254
-68
lines changed

bindings/sftp/client.go

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,28 @@ import (
1818
"fmt"
1919
"os"
2020
"sync"
21+
"sync/atomic"
2122

2223
sftpClient "github.com/pkg/sftp"
2324
"golang.org/x/crypto/ssh"
2425
)
2526

2627
type Client struct {
27-
sshClient *ssh.Client
28-
sftpClient *sftpClient.Client
29-
address string
30-
config *ssh.ClientConfig
31-
lock sync.RWMutex
32-
rLock sync.Mutex
28+
sshClient *ssh.Client
29+
sftpClient *sftpClient.Client
30+
address string
31+
config *ssh.ClientConfig
32+
lock sync.RWMutex
33+
reconnectionCount atomic.Uint64
34+
}
35+
36+
// sftpError pairs an error returned by an SFTP operation with the value of
// the client's reconnection counter that was read when the error was created.
// withReconnection hands that value to doReconnect, which compares it with
// the client's current counter.
type sftpError struct {
	// err is the underlying failure reported by the SFTP operation.
	err error
	// reconnectionCount is the counter value captured alongside err.
	reconnectionCount uint64
}

// Error returns the message of the wrapped error.
func (s sftpError) Error() string {
	return s.err.Error()
}

// Unwrap exposes the wrapped error so that errors.Is and errors.As can see
// through sftpError to the underlying cause.
func (s sftpError) Unwrap() error {
	return s.err
}
3444

3545
func newClient(address string, config *ssh.ClientConfig) (*Client, error) {
@@ -66,12 +76,15 @@ func (c *Client) Close() error {
6676
func (c *Client) list(path string) ([]os.FileInfo, error) {
6777
var fi []os.FileInfo
6878

69-
fn := func() error {
79+
fn := func() *sftpError {
7080
var err error
7181
c.lock.RLock()
7282
defer c.lock.RUnlock()
7383
fi, err = c.sftpClient.ReadDir(path)
74-
return err
84+
if err != nil {
85+
return &sftpError{err: err, reconnectionCount: c.reconnectionCount.Load()}
86+
}
87+
return nil
7588
}
7689

7790
err := withReconnection(c, fn)
@@ -87,17 +100,17 @@ func (c *Client) create(path string) (*sftpClient.File, string, error) {
87100

88101
var file *sftpClient.File
89102

90-
createFn := func() error {
103+
createFn := func() *sftpError {
91104
c.lock.RLock()
92105
defer c.lock.RUnlock()
93106
cErr := c.sftpClient.MkdirAll(dir)
94107
if cErr != nil {
95-
return fmt.Errorf("sftp binding error: error create dir %s: %w", dir, cErr)
108+
return &sftpError{err: fmt.Errorf("sftp binding error: error create dir %s: %w", dir, cErr), reconnectionCount: c.reconnectionCount.Load()}
96109
}
97110

98111
file, cErr = c.sftpClient.Create(path)
99112
if cErr != nil {
100-
return fmt.Errorf("sftp binding error: error create file %s: %w", path, cErr)
113+
return &sftpError{err: fmt.Errorf("sftp binding error: error create file %s: %w", path, cErr), reconnectionCount: c.reconnectionCount.Load()}
101114
}
102115

103116
return nil
@@ -114,12 +127,16 @@ func (c *Client) create(path string) (*sftpClient.File, string, error) {
114127
func (c *Client) get(path string) (*sftpClient.File, error) {
115128
var f *sftpClient.File
116129

117-
fn := func() error {
130+
fn := func() *sftpError {
118131
var err error
119132
c.lock.RLock()
133+
120134
defer c.lock.RUnlock()
121135
f, err = c.sftpClient.Open(path)
122-
return err
136+
if err != nil {
137+
return &sftpError{err: err, reconnectionCount: c.reconnectionCount.Load()}
138+
}
139+
return nil
123140
}
124141

125142
err := withReconnection(c, fn)
@@ -131,12 +148,15 @@ func (c *Client) get(path string) (*sftpClient.File, error) {
131148
}
132149

133150
func (c *Client) delete(path string) error {
134-
fn := func() error {
151+
fn := func() *sftpError {
135152
var err error
136153
c.lock.RLock()
137154
defer c.lock.RUnlock()
138155
err = c.sftpClient.Remove(path)
139-
return err
156+
if err != nil {
157+
return &sftpError{err: err, reconnectionCount: c.reconnectionCount.Load()}
158+
}
159+
return nil
140160
}
141161

142162
err := withReconnection(c, fn)
@@ -157,7 +177,7 @@ func (c *Client) ping() error {
157177
return nil
158178
}
159179

160-
func withReconnection(c *Client, fn func() error) error {
180+
func withReconnection(c *Client, fn func() *sftpError) error {
161181
err := fn()
162182
if err == nil {
163183
return nil
@@ -167,7 +187,7 @@ func withReconnection(c *Client, fn func() error) error {
167187
return err
168188
}
169189

170-
rErr := doReconnect(c)
190+
rErr := doReconnect(c, err.reconnectionCount)
171191
if rErr != nil {
172192
return errors.Join(err, rErr)
173193
}
@@ -180,33 +200,14 @@ func withReconnection(c *Client, fn func() error) error {
180200
return nil
181201
}
182202

183-
// 1) c.rLock (sync.Mutex) — reconnect serialization:
184-
// - Ensures only one goroutine performs the reconnect sequence at a time
185-
// (ping/check, dial SSH, create SFTP client), preventing a thundering herd
186-
// of concurrent reconnect attempts.
187-
// - Does NOT protect day-to-day client usage; it only coordinates who
188-
// is allowed to perform a reconnect.
189-
//
190-
// 2) c.lock (sync.RWMutex) — data-plane safety and atomic swap:
191-
// - Guards reads/writes of the active client handles (sshClient, sftpClient).
192-
// - Regular operations hold RLock while using the clients.
193-
// - Reconnect performs a short critical section with Lock to atomically swap
194-
// the client pointers; old clients are closed after unlocking to keep the
195-
// critical section small and avoid blocking readers.
196-
//
197-
// Why not a single RWMutex?
198-
// - If we used only c.lock and held it while dialing/handshaking, all I/O would
199-
// be blocked for the entire network operation, increasing latency and risk of
200-
// contention. Worse, reconnects triggered while a caller holds RLock could
201-
// deadlock or starve the writer.
202-
// - Separating concerns allows: (a) fast, minimal swap under c.lock, and
203-
// (b) serialized reconnect work under c.rLock without blocking readers.
204-
func doReconnect(c *Client) error {
205-
c.rLock.Lock()
206-
defer c.rLock.Unlock()
203+
func doReconnect(c *Client, reconnectionCount uint64) error {
204+
// The counter has moved on since this error was recorded, so another caller has already reconnected; nothing to do.
205+
if reconnectionCount != c.reconnectionCount.Load() {
206+
return nil
207+
}
207208

208209
err := c.ping()
209-
if !shouldReconnect(err) {
210+
if err == nil {
210211
return nil
211212
}
212213

@@ -222,11 +223,19 @@ func doReconnect(c *Client) error {
222223
}
223224

224225
// Swap under short lock; close old clients after unlocking.
226+
// If the counter changed while dialing, the swap is skipped and the newly created clients are closed below instead.
225227
c.lock.Lock()
226-
oldSftp := c.sftpClient
227-
oldSSH := c.sshClient
228-
c.sftpClient = newSftpClient
229-
c.sshClient = sshClient
228+
var oldSftp *sftpClient.Client
229+
var oldSSH *ssh.Client
230+
if reconnectionCount == c.reconnectionCount.Load() {
231+
oldSftp = c.sftpClient
232+
oldSSH = c.sshClient
233+
c.sftpClient = newSftpClient
234+
c.sshClient = sshClient
235+
c.reconnectionCount.Add(1)
236+
sshClient = nil
237+
newSftpClient = nil
238+
}
230239
c.lock.Unlock()
231240

232241
if oldSftp != nil {
@@ -236,13 +245,21 @@ func doReconnect(c *Client) error {
236245
_ = oldSSH.Close()
237246
}
238247

248+
if newSftpClient != nil {
249+
_ = newSftpClient.Close()
250+
}
251+
252+
if sshClient != nil {
253+
_ = sshClient.Close()
254+
}
255+
239256
return nil
240257
}
241258

242259
func newSSHClient(address string, config *ssh.ClientConfig) (*ssh.Client, error) {
243260
sshClient, err := ssh.Dial("tcp", address, config)
244261
if err != nil {
245-
return nil, fmt.Errorf("sftp binding error: error create ssh client: %w", err)
262+
return nil, fmt.Errorf("sftp binding error: error dialing ssh server: %w", err)
246263
}
247264
return sshClient, nil
248265
}

bindings/sftp/docker-compose.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
services:
2+
sftp:
3+
image:
4+
atmoz/sftp
5+
environment:
6+
- SFTP_USERS=foo:pass:1001:1001:upload
7+
volumes:
8+
- ./upload:/home/foo/upload
9+
ports:
10+
- "2222:22"
11+

bindings/sftp/sftp_integration_test.go

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build integration_test
2+
13
/*
24
Copyright 2025 The Dapr Authors
35
Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,40 +16,57 @@ limitations under the License.
1416
package sftp
1517

1618
import (
19+
"context"
1720
"encoding/json"
21+
"math/rand"
1822
"os"
23+
"sync/atomic"
1924
"testing"
25+
"time"
2026

2127
"github.com/stretchr/testify/assert"
2228
"github.com/stretchr/testify/require"
2329

2430
"github.com/dapr/components-contrib/bindings"
31+
"github.com/dapr/components-contrib/tests/certification/flow"
32+
"github.com/dapr/components-contrib/tests/certification/flow/dockercompose"
2533
sftp "github.com/dapr/components-contrib/tests/utils/sftpproxy"
2634
)
2735

28-
const ProxySftp = "0.0.0.0:2223"
29-
30-
var connectionStringEnvKey = "DAPR_TEST_SFTP_CONNSTRING"
36+
const (
37+
ProxySftp = "0.0.0.0:2223"
38+
ConnectionString = "0.0.0.0:2222"
39+
)
3140

32-
// Run docker from the file location as the upload folder is relative to the test
33-
// cd proxy
34-
// docker run --name sftp -v ./upload:/home/foo/upload -p 2222:22 -d atmoz/sftp foo:pass:1001
35-
// export DAPR_TEST_SFTP_CONNSTRING=localhost:2222
3641
func TestIntegrationCases(t *testing.T) {
37-
connectionString := os.Getenv(connectionStringEnvKey)
38-
if connectionString == "" {
39-
t.Skipf("sftp binding integration skipped. To enable this test, define the connection string using environment variable '%[1]s' (example 'export %[1]s=\"localhost:2222\")'", connectionStringEnvKey)
40-
}
41-
42+
cleanUp := setupSftp(t)
43+
defer cleanUp()
44+
time.Sleep(1 * time.Second)
4245
t.Run("List operation", testListOperation)
4346
t.Run("Create operation", testCreateOperation)
4447
t.Run("Reconnections", testReconnect)
4548
}
4649

50+
func setupSftp(t *testing.T) func() {
51+
dc := dockercompose.New("sftp", "docker-compose.yaml")
52+
ctx := flow.Context{
53+
T: t,
54+
Context: t.Context(),
55+
Flow: nil,
56+
}
57+
err := dc.Up(ctx)
58+
59+
if err != nil {
60+
t.Fatal(err)
61+
}
62+
63+
return func() { dc.Down(ctx) }
64+
}
65+
4766
func testListOperation(t *testing.T) {
4867
proxy := &sftp.Proxy{
4968
ListenAddr: ProxySftp,
50-
UpstreamAddr: os.Getenv(connectionStringEnvKey),
69+
UpstreamAddr: ConnectionString,
5170
}
5271

5372
defer proxy.Close()
@@ -82,7 +101,7 @@ func testListOperation(t *testing.T) {
82101
func testCreateOperation(t *testing.T) {
83102
proxy := &sftp.Proxy{
84103
ListenAddr: ProxySftp,
85-
UpstreamAddr: os.Getenv(connectionStringEnvKey),
104+
UpstreamAddr: ConnectionString,
86105
}
87106
defer proxy.Close()
88107
go proxy.ListenAndServe()
@@ -114,7 +133,7 @@ func testCreateOperation(t *testing.T) {
114133
require.NoError(t, err)
115134
assert.NotNil(t, r.Data)
116135

117-
file, err := os.Stat("./proxy/upload/test.txt")
136+
file, err := os.Stat("./upload/test.txt")
118137
require.NoError(t, err)
119138
assert.Equal(t, "test.txt", file.Name())
120139
assert.EqualValues(t, 1, proxy.ReconnectionCount.Load())
@@ -123,7 +142,7 @@ func testCreateOperation(t *testing.T) {
123142
func testReconnect(t *testing.T) {
124143
proxy := &sftp.Proxy{
125144
ListenAddr: ProxySftp,
126-
UpstreamAddr: os.Getenv(connectionStringEnvKey),
145+
UpstreamAddr: ConnectionString,
127146
}
128147
defer proxy.Close()
129148
go proxy.ListenAndServe()
@@ -218,4 +237,60 @@ func testReconnect(t *testing.T) {
218237

219238
assert.EqualValues(t, numReconnects+1, proxy.ReconnectionCount.Load())
220239
})
240+
241+
t.Run("Parallel ops - reconnection", func(t *testing.T) {
242+
numReconnects := proxy.ReconnectionCount.Load()
243+
ctx, cancelFn := context.WithCancel(t.Context())
244+
opCount := atomic.Int32{}
245+
opFailed := atomic.Int32{}
246+
for range 10 {
247+
go func(ctx context.Context) {
248+
for {
249+
select {
250+
case <-ctx.Done():
251+
return
252+
case <-time.After(time.Duration(500*rand.Float32()) * time.Millisecond):
253+
opCount.Add(1)
254+
r, err := c.Invoke(t.Context(), &bindings.InvokeRequest{Operation: bindings.ListOperation})
255+
if err != nil {
256+
opFailed.Add(1)
257+
break
258+
}
259+
260+
assert.NotNil(t, r.Data)
261+
}
262+
}
263+
}(ctx)
264+
}
265+
266+
go func(ctx context.Context) {
267+
for {
268+
select {
269+
case <-ctx.Done():
270+
return
271+
case <-time.After(100 * time.Millisecond):
272+
_ = proxy.KillServerConn()
273+
}
274+
}
275+
276+
}(ctx)
277+
278+
time.Sleep(time.Second * 5)
279+
cancelFn()
280+
281+
totalOps := opCount.Load()
282+
failedOps := opFailed.Load()
283+
284+
// Calculate 5% tolerance
285+
tolerance := float64(totalOps) * 0.05
286+
287+
// Assert that failed operations stay within the 5% tolerance computed above
288+
assert.InDelta(t, 0, failedOps, tolerance,
289+
"Expected less than 1%% of operations to fail. Total: %d, Failed: %d (%.2f%%)",
290+
totalOps, failedOps, (float64(failedOps)/float64(totalOps))*100)
291+
292+
expectedReconnects := numReconnects + 10
293+
currentReconnects := proxy.ReconnectionCount.Load()
294+
assert.InDelta(t, expectedReconnects, currentReconnects, 2.0, "Expected %d reconnections, got %d", expectedReconnects, currentReconnects)
295+
})
221296
}

0 commit comments

Comments
 (0)