Skip to content

Commit 3e8a00d

Browse files
committed
estimate conn
1 parent 90a3cdd commit 3e8a00d

9 files changed

Lines changed: 209 additions & 30 deletions

File tree

pkg/manager/memory/memory.go

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ const (
3232
)
3333

3434
type UsageSnapshot struct {
35-
Used uint64
36-
Limit uint64
37-
Usage float64
38-
UpdateTime time.Time
39-
Valid bool
35+
Used uint64
36+
Limit uint64
37+
Usage float64
38+
UpdateTime time.Time
39+
EstimatedConnBufferMem int64
40+
Valid bool
4041
}
4142

4243
// MemManager is a manager for memory usage.
@@ -55,6 +56,7 @@ type MemManager struct {
5556
snapshotExpire time.Duration // used for test
5657
memoryLimit uint64
5758
latestUsage atomic.Value
59+
connBufferMem atomic.Int64
5860
}
5961

6062
func NewMemManager(lg *zap.Logger, cfgGetter config.ConfigGetter) *MemManager {
@@ -141,11 +143,12 @@ func (m *MemManager) refreshUsage() (UsageSnapshot, error) {
141143
return UsageSnapshot{}, err
142144
}
143145
snapshot := UsageSnapshot{
144-
Used: used,
145-
Limit: m.memoryLimit,
146-
Usage: float64(used) / float64(m.memoryLimit),
147-
UpdateTime: time.Now(),
148-
Valid: true,
146+
Used: used,
147+
Limit: m.memoryLimit,
148+
Usage: float64(used) / float64(m.memoryLimit),
149+
UpdateTime: time.Now(),
150+
EstimatedConnBufferMem: m.connBufferMem.Load(),
151+
Valid: true,
149152
}
150153
m.latestUsage.Store(snapshot)
151154
return snapshot, nil
@@ -156,6 +159,45 @@ func (m *MemManager) LatestUsage() UsageSnapshot {
156159
return snapshot
157160
}
158161

162+
func (m *MemManager) UpdateConnBufferMemory(delta int64) {
163+
if m == nil || delta == 0 {
164+
return
165+
}
166+
for {
167+
current := m.connBufferMem.Load()
168+
next := current + delta
169+
if next < 0 {
170+
next = 0
171+
}
172+
if m.connBufferMem.CompareAndSwap(current, next) {
173+
return
174+
}
175+
}
176+
}
177+
178+
func (m *MemManager) adjustUsageByConnBuffer(snapshot UsageSnapshot) UsageSnapshot {
179+
current := m.connBufferMem.Load()
180+
delta := current - snapshot.EstimatedConnBufferMem
181+
snapshot.EstimatedConnBufferMem = current
182+
if delta == 0 {
183+
return snapshot
184+
}
185+
if delta > 0 {
186+
snapshot.Used += uint64(delta)
187+
} else {
188+
released := uint64(-delta)
189+
if released >= snapshot.Used {
190+
snapshot.Used = 0
191+
} else {
192+
snapshot.Used -= released
193+
}
194+
}
195+
if snapshot.Limit > 0 {
196+
snapshot.Usage = float64(snapshot.Used) / float64(snapshot.Limit)
197+
}
198+
return snapshot
199+
}
200+
159201
func (m *MemManager) ShouldRejectNewConn() (bool, UsageSnapshot, float64) {
160202
if m == nil || m.cfgGetter == nil {
161203
return false, UsageSnapshot{}, 0
@@ -172,6 +214,7 @@ func (m *MemManager) ShouldRejectNewConn() (bool, UsageSnapshot, float64) {
172214
if !snapshot.Valid || time.Since(snapshot.UpdateTime) > m.snapshotExpire {
173215
return false, snapshot, threshold
174216
}
217+
snapshot = m.adjustUsageByConnBuffer(snapshot)
175218
return snapshot.Usage >= threshold, snapshot, threshold
176219
}
177220

pkg/manager/memory/memory_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,45 @@ func TestShouldRejectNewConn(t *testing.T) {
123123
require.False(t, reject)
124124
require.Equal(t, 0.9, threshold)
125125
}
126+
127+
func TestShouldRejectNewConnTracksConnBufferMemory(t *testing.T) {
128+
oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal
129+
defer func() {
130+
memory.MemUsed = oldMemUsed
131+
memory.MemTotal = oldMemTotal
132+
}()
133+
134+
cfg := config.NewConfig()
135+
cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9
136+
cfgGetter := mockCfgGetter{cfg: cfg}
137+
memory.MemUsed = func() (uint64, error) {
138+
return 890, nil
139+
}
140+
memory.MemTotal = func() (uint64, error) {
141+
return 1000, nil
142+
}
143+
m := NewMemManager(zap.NewNop(), &cfgGetter)
144+
m.checkInterval = 50 * time.Millisecond
145+
m.snapshotExpire = time.Second
146+
m.Start(context.Background())
147+
defer m.Close()
148+
149+
require.Eventually(t, func() bool {
150+
reject, snapshot, threshold := m.ShouldRejectNewConn()
151+
return !reject && snapshot.Valid && threshold == 0.9 && snapshot.Used == 890
152+
}, time.Second, 10*time.Millisecond)
153+
154+
m.UpdateConnBufferMemory(20)
155+
reject, snapshot, threshold := m.ShouldRejectNewConn()
156+
require.True(t, reject)
157+
require.Equal(t, 0.9, threshold)
158+
require.Equal(t, uint64(910), snapshot.Used)
159+
require.InDelta(t, 0.91, snapshot.Usage, 0.0001)
160+
161+
m.UpdateConnBufferMemory(-20)
162+
reject, snapshot, threshold = m.ShouldRejectNewConn()
163+
require.False(t, reject)
164+
require.Equal(t, 0.9, threshold)
165+
require.Equal(t, uint64(890), snapshot.Used)
166+
require.InDelta(t, 0.89, snapshot.Usage, 0.0001)
167+
}

pkg/proxy/backend/backend_conn_mgr.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ type BCConfig struct {
9797
DialTimeout time.Duration
9898
ConnectTimeout time.Duration
9999
ConnBufferSize int
100+
ConnBufferTracker pnet.ConnBufferMemoryTracker
100101
ProxyProtocol bool
101102
RequireBackendTLS bool
102103
}
@@ -327,7 +328,11 @@ func (mgr *BackendConnManager) getBackendIO(ctx context.Context, cctx ConnContex
327328
// NOTE: should use DNS name as much as possible
328329
// Usually certs are signed with domain instead of IP addrs
329330
// And `RemoteAddr()` will return IP addr
330-
backendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr(addr, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn)))
331+
backendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize,
332+
pnet.WithRemoteAddr(addr, cn.RemoteAddr()),
333+
pnet.WithWrapError(ErrBackendConn),
334+
pnet.WithConnBufferMemoryTracker(mgr.config.ConnBufferTracker),
335+
))
331336
mgr.backendIO.Store(&backendIO)
332337
mgr.curBackend = backend
333338
mgr.setKeepAlive()
@@ -657,7 +662,11 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
657662
mgr.handshakeHandler.OnHandshake(mgr, (*backendInst).Addr(), rs.err, SrcBackendNetwork)
658663
return
659664
}
660-
newBackendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr((*backendInst).Addr(), cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn)))
665+
newBackendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize,
666+
pnet.WithRemoteAddr((*backendInst).Addr(), cn.RemoteAddr()),
667+
pnet.WithWrapError(ErrBackendConn),
668+
pnet.WithConnBufferMemoryTracker(mgr.config.ConnBufferTracker),
669+
))
661670

662671
if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil {
663672
rs.err = mgr.initSessionStates(newBackendIO, sessionStates)

pkg/proxy/backend/cmd_processor_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,8 @@ func TestNetworkError(t *testing.T) {
978978
}
979979
backendErrChecker := func(t *testing.T, ts *testSuite) {
980980
require.True(t, pnet.IsDisconnectError(ts.mp.err))
981-
require.True(t, pnet.IsDisconnectError(ts.mb.err))
981+
// The backend mock may finish writing the error packet before the proxy actively closes the backend side.
982+
require.True(t, ts.mb.err == nil || pnet.IsDisconnectError(ts.mb.err))
982983
}
983984
proxyErrChecker := func(t *testing.T, ts *testSuite) {
984985
require.True(t, pnet.IsDisconnectError(ts.mp.err))

pkg/proxy/client/client_conn.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *t
3333
if bcConfig.ProxyProtocol {
3434
opts = append(opts, pnet.WithProxy)
3535
}
36+
if bcConfig.ConnBufferTracker != nil {
37+
opts = append(opts, pnet.WithConnBufferMemoryTracker(bcConfig.ConnBufferTracker))
38+
}
3639
pkt := pnet.NewPacketIO(conn, logger, bcConfig.ConnBufferSize, opts...)
3740
return &ClientConnection{
3841
logger: logger,

pkg/proxy/net/packetio.go

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ const (
5252
DefaultConnBufferSize = 32 * 1024
5353
)
5454

55+
func normalizeConnBufferSize(bufferSize int) int {
56+
if bufferSize == 0 {
57+
return DefaultConnBufferSize
58+
}
59+
return bufferSize
60+
}
61+
62+
func estimateConnBufferMemory(bufferSize int) int64 {
63+
return int64(normalizeConnBufferSize(bufferSize) * 2)
64+
}
65+
5566
type rwStatus int
5667

5768
const (
@@ -116,9 +127,7 @@ func getPooledWriter(conn net.Conn, size int) *bufio.Writer {
116127
}
117128

118129
func newBasicReadWriter(conn net.Conn, bufferSize int) *basicReadWriter {
119-
if bufferSize == 0 {
120-
bufferSize = DefaultConnBufferSize
121-
}
130+
bufferSize = normalizeConnBufferSize(bufferSize)
122131
return &basicReadWriter{
123132
Conn: conn,
124133
ReadWriter: bufio.NewReadWriter(getPooledReader(conn, bufferSize), getPooledWriter(conn, bufferSize)),
@@ -274,26 +283,33 @@ type PacketIO interface {
274283

275284
// PacketIO is a helper to read and write sql and proxy protocol.
276285
type packetIO struct {
277-
lastKeepAlive config.KeepAlive
278-
rawConn net.Conn
279-
readWriter packetReadWriter
280-
limitReader io.LimitedReader // reuse memory to reduce allocation
281-
logger *zap.Logger
282-
remoteAddr net.Addr
283-
wrap error
284-
header [4]byte // reuse memory to reduce allocation
285-
readPacketLimit int
286-
inPackets uint64
287-
outPackets uint64
286+
lastKeepAlive config.KeepAlive
287+
rawConn net.Conn
288+
readWriter packetReadWriter
289+
limitReader io.LimitedReader // reuse memory to reduce allocation
290+
logger *zap.Logger
291+
remoteAddr net.Addr
292+
wrap error
293+
header [4]byte // reuse memory to reduce allocation
294+
readPacketLimit int
295+
inPackets uint64
296+
outPackets uint64
297+
connBufferEstimate int64
298+
connBufferTracker ConnBufferMemoryTracker
299+
connBufferTracked bool
300+
releaseConnBuffer sync.Once
288301
}
289302

290303
func NewPacketIO(conn net.Conn, lg *zap.Logger, bufferSize int, opts ...PacketIOption) *packetIO {
304+
bufferSize = normalizeConnBufferSize(bufferSize)
291305
p := &packetIO{
292-
rawConn: conn,
293-
logger: lg,
294-
readWriter: newBasicReadWriter(conn, bufferSize),
306+
rawConn: conn,
307+
logger: lg,
308+
readWriter: newBasicReadWriter(conn, bufferSize),
309+
connBufferEstimate: estimateConnBufferMemory(bufferSize),
295310
}
296311
p.ApplyOpts(opts...)
312+
p.trackConnBufferMemory()
297313
return p
298314
}
299315

@@ -303,6 +319,22 @@ func (p *packetIO) ApplyOpts(opts ...PacketIOption) {
303319
}
304320
}
305321

322+
func (p *packetIO) trackConnBufferMemory() {
323+
if p.connBufferTracked || p.connBufferTracker == nil || p.connBufferEstimate == 0 {
324+
return
325+
}
326+
p.connBufferTracker.UpdateConnBufferMemory(p.connBufferEstimate)
327+
p.connBufferTracked = true
328+
}
329+
330+
func (p *packetIO) releaseConnBufferMemory() {
331+
p.releaseConnBuffer.Do(func() {
332+
if p.connBufferTracked && p.connBufferTracker != nil && p.connBufferEstimate != 0 {
333+
p.connBufferTracker.UpdateConnBufferMemory(-p.connBufferEstimate)
334+
}
335+
})
336+
}
337+
306338
func (p *packetIO) wrapErr(err error) error {
307339
return errors.Wrap(err, p.wrap)
308340
}
@@ -553,6 +585,7 @@ func (p *packetIO) GracefulClose() error {
553585
}
554586

555587
func (p *packetIO) Close() error {
588+
defer p.releaseConnBufferMemory()
556589
var errs []error
557590
/*
558591
TODO: flush when we want to smoothly exit

pkg/proxy/net/packetio_options.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ import (
1111

1212
type PacketIOption = func(*packetIO)
1313

14+
type ConnBufferMemoryTracker interface {
15+
UpdateConnBufferMemory(delta int64)
16+
}
17+
1418
func WithProxy(pi *packetIO) {
1519
pi.EnableProxyServer()
1620
}
@@ -29,6 +33,13 @@ func WithReadPacketLimit(limit int) func(pi *packetIO) {
2933
}
3034
}
3135

36+
func WithConnBufferMemoryTracker(tracker ConnBufferMemoryTracker) func(pi *packetIO) {
37+
return func(pi *packetIO) {
38+
pi.connBufferTracker = tracker
39+
pi.trackConnBufferMemory()
40+
}
41+
}
42+
3243
// WithRemoteAddr
3344
var _ proxyprotocol.AddressWrapper = &originAddr{}
3445

pkg/proxy/net/packetio_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package net
66
import (
77
"encoding/binary"
88
"net"
9+
"sync/atomic"
910
"testing"
1011
"time"
1112

@@ -19,6 +20,14 @@ import (
1920
"github.com/stretchr/testify/require"
2021
)
2122

23+
type mockConnBufferTracker struct {
24+
bytes atomic.Int64
25+
}
26+
27+
func (t *mockConnBufferTracker) UpdateConnBufferMemory(delta int64) {
28+
t.bytes.Add(delta)
29+
}
30+
2231
func testPipeConn(t *testing.T, a func(*testing.T, *packetIO), b func(*testing.T, *packetIO), loop int) {
2332
lg, _ := logger.CreateLoggerForTest(t)
2433
testkit.TestPipeConn(t,
@@ -863,3 +872,25 @@ func TestPoolSizeMismatch(t *testing.T) {
863872
_ = p1.Close()
864873
_ = p2.Close()
865874
}
875+
876+
func TestPacketIOConnBufferTracking(t *testing.T) {
877+
lg, _ := logger.CreateLoggerForTest(t)
878+
tracker := &mockConnBufferTracker{}
879+
customSize := DefaultConnBufferSize * 2
880+
expected := int64(customSize * 2)
881+
882+
c1, c2 := net.Pipe()
883+
packetIO := NewPacketIO(c1, lg, customSize, WithConnBufferMemoryTracker(tracker))
884+
require.Equal(t, expected, tracker.bytes.Load())
885+
886+
require.NoError(t, packetIO.GracefulClose())
887+
require.Equal(t, expected, tracker.bytes.Load())
888+
889+
require.NoError(t, packetIO.Close())
890+
require.Zero(t, tracker.bytes.Load())
891+
892+
require.NoError(t, packetIO.Close())
893+
require.Zero(t, tracker.bytes.Load())
894+
895+
require.NoError(t, c2.Close())
896+
}

pkg/proxy/proxy.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
174174
return
175175
}
176176

177+
var connBufferTracker pnet.ConnBufferMemoryTracker
178+
if s.memUsage != nil {
179+
connBufferTracker, _ = s.memUsage.(pnet.ConnBufferMemoryTracker)
180+
}
181+
177182
tcpKeepAlive, logger, connID, clientConn := func() (bool, *zap.Logger, uint64, *client.ClientConnection) {
178183
s.mu.Lock()
179184
defer s.mu.Unlock()
@@ -197,6 +202,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
197202
HealthyKeepAlive: s.mu.healthyKeepAlive,
198203
UnhealthyKeepAlive: s.mu.unhealthyKeepAlive,
199204
ConnBufferSize: s.mu.connBufferSize,
205+
ConnBufferTracker: connBufferTracker,
200206
FromPublicEndpoints: s.fromPublicEndpoint,
201207
DialContext: func(ctx context.Context, backendInst router.BackendInst, addr string) (net.Conn, error) {
202208
if s.dialer != nil {

0 commit comments

Comments
 (0)