@@ -345,8 +345,8 @@ type clientStream struct {
345
345
readErr error // sticky read error; owned by transportResponseBody.Read
346
346
347
347
reqBody io.ReadCloser
348
- reqBodyContentLength int64 // -1 means unknown
349
- reqBodyClosed bool // body has been closed; guarded by cc.mu
348
+ reqBodyContentLength int64 // -1 means unknown
349
+ reqBodyClosed chan struct {} // guarded by cc.mu; non-nil on Close, closed when done
350
350
351
351
// owned by writeRequest:
352
352
sentEndStream bool // sent an END_STREAM flag to the peer
@@ -376,46 +376,48 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error
376
376
}
377
377
378
378
func (cs * clientStream ) abortStream (err error ) {
379
- var reqBody io.ReadCloser
380
- defer func () {
381
- if reqBody != nil {
382
- reqBody .Close ()
383
- }
384
- }()
385
379
cs .cc .mu .Lock ()
386
380
defer cs .cc .mu .Unlock ()
387
- reqBody = cs .abortStreamLocked (err )
381
+ cs .abortStreamLocked (err )
388
382
}
389
383
390
- func (cs * clientStream ) abortStreamLocked (err error ) io. ReadCloser {
384
+ func (cs * clientStream ) abortStreamLocked (err error ) {
391
385
cs .abortOnce .Do (func () {
392
386
cs .abortErr = err
393
387
close (cs .abort )
394
388
})
395
- var reqBody io.ReadCloser
396
- if cs .reqBody != nil && ! cs .reqBodyClosed {
397
- cs .reqBodyClosed = true
398
- reqBody = cs .reqBody
389
+ if cs .reqBody != nil {
390
+ cs .closeReqBodyLocked ()
399
391
}
400
392
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
401
393
if cs .cc .cond != nil {
402
394
// Wake up writeRequestBody if it is waiting on flow control.
403
395
cs .cc .cond .Broadcast ()
404
396
}
405
- return reqBody
406
397
}
407
398
408
399
func (cs * clientStream ) abortRequestBodyWrite () {
409
400
cc := cs .cc
410
401
cc .mu .Lock ()
411
402
defer cc .mu .Unlock ()
412
- if cs .reqBody != nil && ! cs .reqBodyClosed {
413
- cs .reqBody .Close ()
414
- cs .reqBodyClosed = true
403
+ if cs .reqBody != nil && cs .reqBodyClosed == nil {
404
+ cs .closeReqBodyLocked ()
415
405
cc .cond .Broadcast ()
416
406
}
417
407
}
418
408
409
+ func (cs * clientStream ) closeReqBodyLocked () {
410
+ if cs .reqBodyClosed != nil {
411
+ return
412
+ }
413
+ cs .reqBodyClosed = make (chan struct {})
414
+ reqBodyClosed := cs .reqBodyClosed
415
+ go func () {
416
+ cs .reqBody .Close ()
417
+ close (reqBodyClosed )
418
+ }()
419
+ }
420
+
419
421
type stickyErrWriter struct {
420
422
conn net.Conn
421
423
timeout time.Duration
@@ -771,12 +773,6 @@ func (cc *ClientConn) SetDoNotReuse() {
771
773
}
772
774
773
775
func (cc * ClientConn ) setGoAway (f * GoAwayFrame ) {
774
- var reqBodiesToClose []io.ReadCloser
775
- defer func () {
776
- for _ , reqBody := range reqBodiesToClose {
777
- reqBody .Close ()
778
- }
779
- }()
780
776
cc .mu .Lock ()
781
777
defer cc .mu .Unlock ()
782
778
@@ -793,10 +789,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
793
789
last := f .LastStreamID
794
790
for streamID , cs := range cc .streams {
795
791
if streamID > last {
796
- reqBody := cs .abortStreamLocked (errClientConnGotGoAway )
797
- if reqBody != nil {
798
- reqBodiesToClose = append (reqBodiesToClose , reqBody )
799
- }
792
+ cs .abortStreamLocked (errClientConnGotGoAway )
800
793
}
801
794
}
802
795
}
@@ -1049,19 +1042,11 @@ func (cc *ClientConn) sendGoAway() error {
1049
1042
func (cc * ClientConn ) closeForError (err error ) {
1050
1043
cc .mu .Lock ()
1051
1044
cc .closed = true
1052
-
1053
- var reqBodiesToClose []io.ReadCloser
1054
1045
for _ , cs := range cc .streams {
1055
- reqBody := cs .abortStreamLocked (err )
1056
- if reqBody != nil {
1057
- reqBodiesToClose = append (reqBodiesToClose , reqBody )
1058
- }
1046
+ cs .abortStreamLocked (err )
1059
1047
}
1060
1048
cc .cond .Broadcast ()
1061
1049
cc .mu .Unlock ()
1062
- for _ , reqBody := range reqBodiesToClose {
1063
- reqBody .Close ()
1064
- }
1065
1050
cc .closeConn ()
1066
1051
}
1067
1052
@@ -1458,11 +1443,19 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
1458
1443
// and in multiple cases: server replies <=299 and >299
1459
1444
// while still writing request body
1460
1445
cc .mu .Lock ()
1446
+ mustCloseBody := false
1447
+ if cs .reqBody != nil && cs .reqBodyClosed == nil {
1448
+ mustCloseBody = true
1449
+ cs .reqBodyClosed = make (chan struct {})
1450
+ }
1461
1451
bodyClosed := cs .reqBodyClosed
1462
- cs .reqBodyClosed = true
1463
1452
cc .mu .Unlock ()
1464
- if ! bodyClosed && cs . reqBody != nil {
1453
+ if mustCloseBody {
1465
1454
cs .reqBody .Close ()
1455
+ close (bodyClosed )
1456
+ }
1457
+ if bodyClosed != nil {
1458
+ <- bodyClosed
1466
1459
}
1467
1460
1468
1461
if err != nil && cs .sentEndStream {
@@ -1642,7 +1635,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
1642
1635
}
1643
1636
if err != nil {
1644
1637
cc .mu .Lock ()
1645
- bodyClosed := cs .reqBodyClosed
1638
+ bodyClosed := cs .reqBodyClosed != nil
1646
1639
cc .mu .Unlock ()
1647
1640
switch {
1648
1641
case bodyClosed :
@@ -1737,7 +1730,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
1737
1730
if cc .closed {
1738
1731
return 0 , errClientConnClosed
1739
1732
}
1740
- if cs .reqBodyClosed {
1733
+ if cs .reqBodyClosed != nil {
1741
1734
return 0 , errStopReqBodyWrite
1742
1735
}
1743
1736
select {
@@ -2110,24 +2103,17 @@ func (rl *clientConnReadLoop) cleanup() {
2110
2103
}
2111
2104
cc .closed = true
2112
2105
2113
- var reqBodiesToClose []io.ReadCloser
2114
2106
for _ , cs := range cc .streams {
2115
2107
select {
2116
2108
case <- cs .peerClosed :
2117
2109
// The server closed the stream before closing the conn,
2118
2110
// so no need to interrupt it.
2119
2111
default :
2120
- reqBody := cs .abortStreamLocked (err )
2121
- if reqBody != nil {
2122
- reqBodiesToClose = append (reqBodiesToClose , reqBody )
2123
- }
2112
+ cs .abortStreamLocked (err )
2124
2113
}
2125
2114
}
2126
2115
cc .cond .Broadcast ()
2127
2116
cc .mu .Unlock ()
2128
- for _ , reqBody := range reqBodiesToClose {
2129
- reqBody .Close ()
2130
- }
2131
2117
}
2132
2118
2133
2119
// countReadFrameError calls Transport.CountError with a string
0 commit comments