@@ -38,10 +38,10 @@ type Conn struct {
3838 writeDataLock chan struct {}
3939 writeFrameLock chan struct {}
4040
41- readDataLock chan struct {}
42- readData chan header
43- readDone chan struct {}
44- readLoopDone chan struct {}
41+ readMsgLock chan struct {}
42+ readMsg chan header
43+ readMsgDone chan struct {}
44+ readFrameLock chan struct {}
4545
4646 setReadTimeout chan context.Context
4747 setWriteTimeout chan context.Context
@@ -90,17 +90,15 @@ func (c *Conn) close(err error) {
9090
9191 close (c .closed )
9292
93+ // This ensures every goroutine that interacts
94+ // with the conn closes before it can interact with the connection
95+ c .readFrameLock <- struct {}{}
96+ c .writeFrameLock <- struct {}{}
97+
9398 // See comment in dial.go
9499 if c .client {
95- go func () {
96- <- c .readLoopDone
97- // TODO this does not work if reader errors out so skip for now.
98- // c.readDataLock <- struct{}{}
99- // c.writeFrameLock <- struct{}{}
100- //
101- // returnBufioReader(c.br)
102- // returnBufioWriter(c.bw)
103- }()
100+ returnBufioReader (c .br )
101+ returnBufioWriter (c .bw )
104102 }
105103 })
106104}
@@ -119,10 +117,10 @@ func (c *Conn) init() {
119117 c .writeDataLock = make (chan struct {}, 1 )
120118 c .writeFrameLock = make (chan struct {}, 1 )
121119
122- c .readData = make (chan header )
123- c .readDone = make (chan struct {})
124- c .readDataLock = make (chan struct {}, 1 )
125- c .readLoopDone = make (chan struct {})
120+ c .readMsg = make (chan header )
121+ c .readMsgDone = make (chan struct {})
122+ c .readMsgLock = make (chan struct {}, 1 )
123+ c .readFrameLock = make (chan struct {}, 1 )
126124
127125 c .setReadTimeout = make (chan context.Context )
128126 c .setWriteTimeout = make (chan context.Context )
@@ -141,8 +139,8 @@ func (c *Conn) init() {
141139
142140// We never mask inside here because our mask key is always 0,0,0,0.
143141// See comment on secWebSocketKey.
144- func (c * Conn ) writeFrame (ctx context.Context , h header , p []byte ) error {
145- err : = c .acquireLock (ctx , c .writeFrameLock )
142+ func (c * Conn ) writeFrame (ctx context.Context , h header , p []byte ) ( err error ) {
143+ err = c .acquireLock (ctx , c .writeFrameLock )
146144 if err != nil {
147145 return err
148146 }
@@ -164,27 +162,33 @@ func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error {
164162 }
165163 }()
166164
165+ defer func () {
166+ if err != nil {
167+ // We need to always release the lock first before closing the connection to ensure
168+ // the lock can be acquired inside close.
169+ c .releaseLock (c .writeFrameLock )
170+ c .close (err )
171+ }
172+ }()
173+
167174 h .masked = c .client
168175 h .payloadLength = int64 (len (p ))
169176
170177 b2 := marshalHeader (h )
171178 _ , err = c .bw .Write (b2 )
172179 if err != nil {
173- c .close (xerrors .Errorf ("failed to write to connection: %w" , err ))
174- return c .closeErr
180+ return xerrors .Errorf ("failed to write to connection: %w" , err )
175181 }
176182 _ , err = c .bw .Write (p )
177183 if err != nil {
178- c .close (xerrors .Errorf ("failed to write to connection: %w" , err ))
179- return c .closeErr
184+ return xerrors .Errorf ("failed to write to connection: %w" , err )
180185
181186 }
182187
183188 if h .fin {
184189 err := c .bw .Flush ()
185190 if err != nil {
186- c .close (xerrors .Errorf ("failed to write to connection: %w" , err ))
187- return c .closeErr
191+ return xerrors .Errorf ("failed to write to connection: %w" , err )
188192 }
189193 }
190194
@@ -279,9 +283,9 @@ func (c *Conn) handleControl(h header) {
279283
280284func (c * Conn ) readTillData () (header , error ) {
281285 for {
282- h , err := readHeader ( c . br )
286+ h , err := c . readHeader ( )
283287 if err != nil {
284- return header {}, xerrors . Errorf ( "failed to read header: %w" , err )
288+ return header {}, err
285289 }
286290
287291 if h .rsv1 || h .rsv2 || h .rsv3 {
@@ -312,9 +316,22 @@ func (c *Conn) readTillData() (header, error) {
312316 }
313317}
314318
315- func (c * Conn ) readLoop () {
316- defer close (c .readLoopDone )
319+ func (c * Conn ) readHeader () (header , error ) {
320+ err := c .acquireLock (context .Background (), c .readFrameLock )
321+ if err != nil {
322+ return header {}, err
323+ }
324+ defer c .releaseLock (c .readFrameLock )
317325
326+ h , err := readHeader (c .br )
327+ if err != nil {
328+ return header {}, xerrors .Errorf ("failed to read header: %w" , err )
329+ }
330+
331+ return h , nil
332+ }
333+
334+ func (c * Conn ) readLoop () {
318335 for {
319336 h , err := c .readTillData ()
320337 if err != nil {
@@ -325,13 +342,13 @@ func (c *Conn) readLoop() {
325342 select {
326343 case <- c .closed :
327344 return
328- case c .readData <- h :
345+ case c .readMsg <- h :
329346 }
330347
331348 select {
332349 case <- c .closed :
333350 return
334- case <- c .readDone :
351+ case <- c .readMsgDone :
335352 }
336353 }
337354}
@@ -374,7 +391,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
374391 // Definitely worth seeing what popular browsers do later.
375392 p , err := ce .bytes ()
376393 if err != nil {
377- fmt .Fprintf (os .Stderr , "failed to marshal close frame: %v\n " , err )
394+ fmt .Fprintf (os .Stderr , "websocket: failed to marshal close frame: %v\n " , err )
378395 ce = CloseError {
379396 Code : StatusInternalError ,
380397 }
@@ -415,7 +432,11 @@ func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
415432}
416433
417434func (c * Conn ) releaseLock (lock chan struct {}) {
418- <- lock
435+ // Allow multiple releases.
436+ select {
437+ case <- lock :
438+ default :
439+ }
419440}
420441
421442func (c * Conn ) writeMessage (ctx context.Context , opcode opcode , p []byte ) error {
@@ -572,7 +593,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
572593}
573594
574595func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
575- err = c .acquireLock (ctx , c .readDataLock )
596+ err = c .acquireLock (ctx , c .readMsgLock )
576597 if err != nil {
577598 return 0 , nil , err
578599 }
@@ -582,7 +603,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
582603 return 0 , nil , c .closeErr
583604 case <- ctx .Done ():
584605 return 0 , nil , ctx .Err ()
585- case h := <- c .readData :
606+ case h := <- c .readMsg :
586607 if h .opcode == opContinuation {
587608 ce := CloseError {
588609 Code : StatusProtocolError ,
@@ -631,7 +652,7 @@ func (r *messageReader) read(p []byte) (int, error) {
631652 select {
632653 case <- r .c .closed :
633654 return 0 , r .c .closeErr
634- case h := <- r .c .readData :
655+ case h := <- r .c .readMsg :
635656 if h .opcode != opContinuation {
636657 ce := CloseError {
637658 Code : StatusProtocolError ,
@@ -654,7 +675,12 @@ func (r *messageReader) read(p []byte) (int, error) {
654675 case r .c .setReadTimeout <- r .ctx :
655676 }
656677
678+ err := r .c .acquireLock (r .ctx , r .c .readFrameLock )
679+ if err != nil {
680+ return 0 , err
681+ }
657682 n , err := io .ReadFull (r .c .br , p )
683+ r .c .releaseLock (r .c .readFrameLock )
658684
659685 select {
660686 case <- r .c .closed :
@@ -676,11 +702,11 @@ func (r *messageReader) read(p []byte) (int, error) {
676702 select {
677703 case <- r .c .closed :
678704 return n , r .c .closeErr
679- case r .c .readDone <- struct {}{}:
705+ case r .c .readMsgDone <- struct {}{}:
680706 }
681707 if r .h .fin {
682708 r .eofed = true
683- r .c .releaseLock (r .c .readDataLock )
709+ r .c .releaseLock (r .c .readMsgLock )
684710 return n , io .EOF
685711 }
686712 r .maskPos = 0
0 commit comments