@@ -286,7 +286,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
286286 c .Close (StatusProtocolError , "received invalid close payload" )
287287 return xerrors .Errorf ("received invalid close payload: %w" , err )
288288 }
289- c .writeClose (b , ce , false )
289+ c .writeClose (b , xerrors . Errorf ( "received close frame: %w" , ce ) )
290290 return c .closeErr
291291 default :
292292 panic (fmt .Sprintf ("websocket: unexpected control opcode: %#v" , h ))
@@ -644,38 +644,54 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
644644 case c .setWriteTimeout <- ctx :
645645 }
646646
647- writeErr := func (err error ) error {
648- select {
649- case <- c .closed :
650- return c .closeErr
651- case <- ctx .Done ():
652- err = ctx .Err ()
653- default :
654- }
655-
656- err = xerrors .Errorf ("failed to write %v frame: %w" , h .opcode , err )
657- // We need to release the lock first before closing the connection to ensure
658- // the lock can be acquired inside close to ensure no one can access c.bw.
659- c .releaseLock (c .writeFrameLock )
660- c .close (err )
647+ n , err := c .realWriteFrame (ctx , h , p )
648+ if err != nil {
649+ return n , err
650+ }
661651
662- return err
652+ // We already finished writing, no need to potentially brick the connection if
653+ // the context expires.
654+ select {
655+ case <- c .closed :
656+ return n , c .closeErr
657+ case c .setWriteTimeout <- context .Background ():
663658 }
664659
660+ return n , nil
661+ }
662+
663+ func (c * Conn ) realWriteFrame (ctx context.Context , h header , p []byte ) (n int , err error ){
664+ defer func () {
665+ if err != nil {
666+ select {
667+ case <- c .closed :
668+ err = c .closeErr
669+ case <- ctx .Done ():
670+ err = ctx .Err ()
671+ default :
672+ }
673+
674+ err = xerrors .Errorf ("failed to write %v frame: %w" , h .opcode , err )
675+ // We need to release the lock first before closing the connection to ensure
676+ // the lock can be acquired inside close to ensure no one can access c.bw.
677+ c .releaseLock (c .writeFrameLock )
678+ c .close (err )
679+ }
680+ }()
681+
665682 headerBytes := writeHeader (c .writeHeaderBuf , h )
666683 _ , err = c .bw .Write (headerBytes )
667684 if err != nil {
668- return 0 , writeErr ( err )
685+ return 0 , err
669686 }
670687
671- var n int
672688 if c .client {
673689 var keypos int
674690 for len (p ) > 0 {
675691 if c .bw .Available () == 0 {
676692 err = c .bw .Flush ()
677693 if err != nil {
678- return n , writeErr ( err )
694+ return n , err
679695 }
680696 }
681697
@@ -689,7 +705,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
689705
690706 n2 , err := c .bw .Write (p2 )
691707 if err != nil {
692- return n , writeErr ( err )
708+ return n , err
693709 }
694710
695711 keypos = fastXOR (h .maskKey , keypos , c .writeBuf [i :i + n2 ])
@@ -700,25 +716,17 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
700716 } else {
701717 n , err = c .bw .Write (p )
702718 if err != nil {
703- return n , writeErr ( err )
719+ return n , err
704720 }
705721 }
706722
707- if fin {
723+ if h . fin {
708724 err = c .bw .Flush ()
709725 if err != nil {
710- return n , writeErr ( err )
726+ return n , err
711727 }
712728 }
713729
714- // We already finished writing, no need to potentially brick the connection if
715- // the context expires.
716- select {
717- case <- c .closed :
718- return n , c .closeErr
719- case c .setWriteTimeout <- context .Background ():
720- }
721-
722730 return n , nil
723731}
724732
@@ -767,10 +775,19 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
767775 p , _ = ce .bytes ()
768776 }
769777
770- return c .writeClose (p , ce , true )
778+ err = c .writeClose (p , xerrors .Errorf ("sent close frame: %w" , ce ))
779+ if err != nil {
780+ return err
781+ }
782+
783+ if ! xerrors .Is (c .closeErr , ce ) {
784+ return c .closeErr
785+ }
786+
787+ return nil
771788}
772789
773- func (c * Conn ) writeClose (p []byte , cerr error , us bool ) error {
790+ func (c * Conn ) writeClose (p []byte , cerr error ) error {
774791 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
775792 defer cancel ()
776793
@@ -780,16 +797,7 @@ func (c *Conn) writeClose(p []byte, cerr error, us bool) error {
780797 return err
781798 }
782799
783- if us {
784- cerr = xerrors .Errorf ("sent close frame: %w" , cerr )
785- } else {
786- cerr = xerrors .Errorf ("received close frame: %w" , cerr )
787- }
788-
789800 c .close (cerr )
790- if ! xerrors .Is (c .closeErr , cerr ) {
791- return c .closeErr
792- }
793801
794802 return nil
795803}
0 commit comments