From b1337696cb92b285dce176cf22db570b2682b7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan-Luis=20de=20Sousa-Valadas=20Casta=C3=B1o?= Date: Thu, 19 Dec 2024 12:06:14 +0100 Subject: [PATCH] Multiple refactors on cplb.tcpproxy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1- Simplify tcpproxy by reomivng useless interfaces: The original tcpproxy allowed different types of routes and needed to do a bunch of interfacing for it to work. Since we only implement one kind of route and one kind of target, we remove all the interfaces and merge both structs into a unique struct. 2- Remove proxy.AddRoute: We only used it once and setRoutes can cover that use case 3- Lock tcpproxy.Proxy when modifying routes to make it thread safe. Prior to this we relied on the proxy being called only from one goroutine. Now it can be called concurrently, not that we expect to do that but a lock gives us extra safety. 4- Panic if tcpproxy.SetRoutes gets and empty route list. We now check this in cplb_unix.go. 5- Remove the route feeding goroutine for round robin, since we added a lock to make proxy.SetRoutes threadsafe we don't need that anymore and it can be made much simpler by adding a lock. Signed-off-by: Juan-Luis de Sousa-Valadas CastaƱo --- inttest/cplb-userspace/cplbuserspace_test.go | 7 +- pkg/component/controller/cplb/cplb_linux.go | 24 ++- .../controller/cplb/tcpproxy/tcpproxy.go | 176 ++++++------------ .../controller/cplb/tcpproxy/tcpproxy_test.go | 18 +- 4 files changed, 86 insertions(+), 139 deletions(-) diff --git a/inttest/cplb-userspace/cplbuserspace_test.go b/inttest/cplb-userspace/cplbuserspace_test.go index 1b2da563d107..349f6f659362 100644 --- a/inttest/cplb-userspace/cplbuserspace_test.go +++ b/inttest/cplb-userspace/cplbuserspace_test.go @@ -154,15 +154,16 @@ func (s *CPLBUserSpaceSuite) TestK0sGetsUp() { s.T().Log("Testing that the load balancer is actually balancing the load") // Other stuff may be querying the controller, running the HTTPS request 15 times // should be more than we need. + attempt := 0 signatures := make(map[string]int) url := url.URL{Scheme: "https", Host: net.JoinHostPort(lb, strconv.Itoa(6443))} - for range 15 { + for len(signatures) < 3 { signature, err := getServerCertSignature(ctx, url.String()) s.Require().NoError(err) signatures[signature] = 1 + attempt++ + s.Require().LessOrEqual(attempt, 15, "Failed to get a signature from all controllers") } - - s.Require().Len(signatures, 3, "Expected 3 different signatures, got %d", len(signatures)) } // getLBAddress returns the IP address of the controller 0 and it adds 100 to diff --git a/pkg/component/controller/cplb/cplb_linux.go b/pkg/component/controller/cplb/cplb_linux.go index f17a21e79e78..0496ffa08c48 100644 --- a/pkg/component/controller/cplb/cplb_linux.go +++ b/pkg/component/controller/cplb/cplb_linux.go @@ -87,7 +87,7 @@ func (k *Keepalived) Init(_ context.Context) error { } // Start generates the keepalived configuration and starts the keepalived process -func (k *Keepalived) Start(_ context.Context) error { +func (k *Keepalived) Start(ctx context.Context) error { if k.Config == nil || (len(k.Config.VRRPInstances) == 0 && len(k.Config.VirtualServers) == 0) { k.log.Warn("No VRRP instances or virtual servers defined, skipping keepalived start") return nil @@ -154,8 +154,7 @@ func (k *Keepalived) Start(_ context.Context) error { if len(k.Config.VirtualServers) > 0 { k.watchReconcilerUpdatesKeepalived() } else { - - if err := k.watchReconcilerUpdatesReverseProxy(); err != nil { + if err := k.watchReconcilerUpdatesReverseProxy(ctx); err != nil { k.log.WithError(err).Error("failed to watch reconciler updates") } } @@ -347,18 +346,23 @@ func (k *Keepalived) generateKeepalivedTemplate() error { return nil } -func (k *Keepalived) watchReconcilerUpdatesReverseProxy() error { +func (k *Keepalived) watchReconcilerUpdatesReverseProxy(ctx context.Context) error { k.proxy = tcpproxy.Proxy{} // We don't know how long until we get the first update, so initially we // forward everything to localhost - k.proxy.AddRoute(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), tcpproxy.To(fmt.Sprintf("127.0.0.1:%d", k.APIPort))) + k.proxy.SetRoutes(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), []tcpproxy.Route{tcpproxy.To(fmt.Sprintf("127.0.0.1:%d", k.APIPort))}) if err := k.proxy.Start(); err != nil { return fmt.Errorf("failed to start proxy: %w", err) } - fmt.Println("Waiting for updateCh") - <-k.updateCh + k.log.Info("Waiting for the first cplb-reconciler update") + + select { + case <-ctx.Done(): + return errors.New("context cancelled while starting the reverse proxy") + case <-k.updateCh: + } k.setProxyRoutes() // Do not create the iptables rules until we have the first update and the @@ -374,11 +378,15 @@ func (k *Keepalived) watchReconcilerUpdatesReverseProxy() error { } func (k *Keepalived) setProxyRoutes() { - routes := []tcpproxy.Target{} + routes := []tcpproxy.Route{} for _, addr := range k.reconciler.GetIPs() { routes = append(routes, tcpproxy.To(fmt.Sprintf("%s:%d", addr, k.APIPort))) } + if len(routes) == 0 { + k.log.Error("No API servers available, leave previous configuration") + return + } k.proxy.SetRoutes(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), routes) } diff --git a/pkg/component/controller/cplb/tcpproxy/tcpproxy.go b/pkg/component/controller/cplb/tcpproxy/tcpproxy.go index 23b4b7fc52bd..2bdbffe50490 100644 --- a/pkg/component/controller/cplb/tcpproxy/tcpproxy.go +++ b/pkg/component/controller/cplb/tcpproxy/tcpproxy.go @@ -27,9 +27,11 @@ import ( "errors" "fmt" "io" - "log" "net" + "sync" "time" + + "github.com/sirupsen/logrus" ) // Proxy is a proxy. Its zero value is a valid proxy that does @@ -38,16 +40,18 @@ import ( // The order that routes are added in matters; each is matched in the order // registered. type Proxy struct { + mux sync.RWMutex configs map[string]*config // ip:port => config lns []net.Listener donec chan struct{} // closed before err err error // any error from listening - routesChan chan route + connNumber int // connection number counter, used for round robin // ListenFunc optionally specifies an alternate listen // function. If nil, net.Dial is used. - // The provided net is always "tcp". + // The provided net is always "tcp". This is to match + // the signature of net.Listen. ListenFunc func(net, laddr string) (net.Listener, error) } @@ -56,22 +60,7 @@ type Matcher func(ctx context.Context, hostname string) bool // config contains the proxying state for one listener. type config struct { - routes []route -} - -// A route matches a connection to a target. -type route interface { - // match examines the initial bytes of a connection, looking for a - // match. If a match is found, match returns a non-nil Target to - // which the stream should be proxied. match returns nil if the - // connection doesn't match. - // - // match must not consume bytes from the given bufio.Reader, it - // can only Peek. - // - // If an sni or host header was parsed successfully, that will be - // returned as the second parameter. - match(*bufio.Reader) (Target, string) + routes []Route } func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { @@ -91,28 +80,7 @@ func (p *Proxy) configFor(ipPort string) *config { return p.configs[ipPort] } -func (p *Proxy) addRoute(ipPort string, r route) { - cfg := p.configFor(ipPort) - cfg.routes = append(cfg.routes, r) -} - -// AddRoute appends an always-matching route to the ipPort listener, -// directing any connection to dest. -// -// This is generally used as either the only rule (for simple TCP -// proxies), or as the final fallback rule for an ipPort. -// -// The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddRoute(ipPort string, dest Target) { - p.addRoute(ipPort, fixedTarget{dest}) -} - -func (p *Proxy) setRoutes(ipPort string, targets []Target) { - var routes []route - for _, target := range targets { - routes = append(routes, fixedTarget{target}) - } - +func (p *Proxy) setRoutes(ipPort string, routes []Route) { cfg := p.configFor(ipPort) cfg.routes = routes } @@ -122,19 +90,15 @@ func (p *Proxy) setRoutes(ipPort string, targets []Target) { // It's possible that the old routes are still used once after this // function is called. If an empty slice is passed, the routes are // preserved in order to avoid an infinite loop. -func (p *Proxy) SetRoutes(ipPort string, targets []Target) { +func (p *Proxy) SetRoutes(ipPort string, targets []Route) { + p.mux.Lock() + defer p.mux.Unlock() if len(targets) == 0 { - return + panic("SetRoutes with empty targets") } p.setRoutes(ipPort, targets) } -type fixedTarget struct { - t Target -} - -func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" } - // Run is calls Start, and then Wait. // // It blocks until there's an error. The return value is always @@ -183,7 +147,6 @@ func (p *Proxy) Start() error { return err } p.lns = append(p.lns, ln) - p.routesChan = make(chan route) go p.serveListener(errc, ln, config) } go p.awaitFirstError(errc) @@ -196,48 +159,35 @@ func (p *Proxy) awaitFirstError(errc <-chan error) { } func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { - go p.roundRobin(cfg) for { c, err := ln.Accept() if err != nil { ret <- err return } - go p.serveConn(c) + go p.serveConn(c, cfg) } } // serveConn runs in its own goroutine and matches c against routes. // It returns whether it matched purely for testing. -func (p *Proxy) serveConn(c net.Conn) bool { +func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { br := bufio.NewReader(c) - for route := range p.routesChan { - if target, hostName := route.match(br); target != nil { - if n := br.Buffered(); n > 0 { - peeked, _ := br.Peek(br.Buffered()) - c = &Conn{ - HostName: hostName, - Peeked: peeked, - Conn: c, - } - } - target.HandleConn(c) - return true - } - } - // TODO: hook for this? - log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) - c.Close() - return false -} -// roundRobin writes to a channel the next route to use. -func (p *Proxy) roundRobin(cfg *config) { - for { - for _, route := range cfg.routes { - p.routesChan <- route + p.mux.RLock() + p.connNumber++ + route := cfg.routes[p.connNumber%(len(cfg.routes))] + p.mux.RUnlock() + + if n := br.Buffered(); n > 0 { + peeked, _ := br.Peek(br.Buffered()) + c = &Conn{ + Peeked: peeked, + Conn: c, } } + route.HandleConn(c) + return true } // Conn is an incoming connection that has had some bytes read from it @@ -276,29 +226,17 @@ func (c *Conn) Read(p []byte) (n int, err error) { return c.Conn.Read(p) } -// Target is what an incoming matched connection is sent to. -type Target interface { - // HandleConn is called when an incoming connection is - // matched. After the call to HandleConn, the tcpproxy - // package never touches the conn again. Implementations are - // responsible for closing the connection when needed. - // - // The concrete type of conn will be of type *Conn if any - // bytes have been consumed for the purposes of route - // matching. - HandleConn(net.Conn) -} - // To is shorthand way of writing &tcpproxy.DialProxy{Addr: addr}. -func To(addr string) *DialProxy { - return &DialProxy{Addr: addr} +func To(addr string) Route { + return Route{Addr: addr} } -// DialProxy implements Target by dialing a new connection to Addr +// Route is what an incoming connection is sent to. +// It handles them by dialing a new connection to Addr // and then proxying data back and forth. // -// The To func is a shorthand way of creating a DialProxy. -type DialProxy struct { +// The To func is a shorthand way of creating a Route. +type Route struct { // Addr is the TCP address to proxy to. Addr string @@ -366,29 +304,29 @@ func closeWrite(c net.Conn) { } // HandleConn implements the Target interface. -func (dp *DialProxy) HandleConn(src net.Conn) { +func (r *Route) HandleConn(src net.Conn) { ctx := context.Background() var cancel context.CancelFunc - if dp.DialTimeout >= 0 { - ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout()) + if r.DialTimeout >= 0 { + ctx, cancel = context.WithTimeout(ctx, r.dialTimeout()) } - dst, err := dp.dialContext()(ctx, "tcp", dp.Addr) + dst, err := r.dialContext()(ctx, "tcp", r.Addr) if cancel != nil { cancel() } if err != nil { - dp.onDialError()(src, err) + r.onDialError()(src, err) return } defer goCloseConn(dst) - if err = dp.sendProxyHeader(dst, src); err != nil { - dp.onDialError()(src, err) + if err = r.sendProxyHeader(dst, src); err != nil { + r.onDialError()(src, err) return } defer goCloseConn(src) - if ka := dp.keepAlivePeriod(); ka > 0 { + if ka := r.keepAlivePeriod(); ka > 0 { for _, c := range []net.Conn{src, dst} { if c, ok := tcpConn(c); ok { _ = c.SetKeepAlive(true) @@ -404,8 +342,8 @@ func (dp *DialProxy) HandleConn(src net.Conn) { <-errc } -func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { - switch dp.ProxyProtocolVersion { +func (r *Route) sendProxyHeader(w io.Writer, src net.Conn) error { + switch r.ProxyProtocolVersion { case 0: return nil case 1: @@ -429,7 +367,7 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { _, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, srcAddr.Port, dstAddr.Port) return err default: - return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion) + return fmt.Errorf("PROXY protocol version %d not supported", r.ProxyProtocolVersion) } } @@ -458,35 +396,35 @@ func proxyCopy(errc chan<- error, dst, src net.Conn) { errc <- err } -func (dp *DialProxy) keepAlivePeriod() time.Duration { - if dp.KeepAlivePeriod != 0 { - return dp.KeepAlivePeriod +func (r *Route) keepAlivePeriod() time.Duration { + if r.KeepAlivePeriod != 0 { + return r.KeepAlivePeriod } return time.Minute } -func (dp *DialProxy) dialTimeout() time.Duration { - if dp.DialTimeout > 0 { - return dp.DialTimeout +func (r *Route) dialTimeout() time.Duration { + if r.DialTimeout > 0 { + return r.DialTimeout } return 10 * time.Second } var defaultDialer = new(net.Dialer) -func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) { - if dp.DialContext != nil { - return dp.DialContext +func (r *Route) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) { + if r.DialContext != nil { + return r.DialContext } return defaultDialer.DialContext } -func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) { - if dp.OnDialError != nil { - return dp.OnDialError +func (r *Route) onDialError() func(src net.Conn, dstDialErr error) { + if r.OnDialError != nil { + return r.OnDialError } return func(src net.Conn, dstDialErr error) { - log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr) + logrus.WithFields(logrus.Fields{"component": "tcpproxy"}).Errorf("for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), r.Addr, dstDialErr) src.Close() } } diff --git a/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go b/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go index 777a5c95e4ed..ff9f4be6b557 100644 --- a/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go +++ b/pkg/component/controller/cplb/tcpproxy/tcpproxy_test.go @@ -75,7 +75,7 @@ func TestBufferedClose(t *testing.T) { defer back.Close() p := testProxy(t, front) - p.AddRoute(testFrontAddr, To(back.Addr().String())) + p.SetRoutes(testFrontAddr, []Route{To(back.Addr().String())}) if err := p.Start(); err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func TestProxyAlwaysMatch(t *testing.T) { defer back.Close() p := testProxy(t, front) - p.AddRoute(testFrontAddr, To(back.Addr().String())) + p.setRoutes(testFrontAddr, []Route{To(back.Addr().String())}) if err := p.Start(); err != nil { t.Fatal(err) } @@ -149,10 +149,10 @@ func TestProxyPROXYOut(t *testing.T) { defer back.Close() p := testProxy(t, front) - p.AddRoute(testFrontAddr, &DialProxy{ + p.SetRoutes(testFrontAddr, []Route{{ Addr: back.Addr().String(), ProxyProtocolVersion: 1, - }) + }}) if err := p.Start(); err != nil { t.Fatal(err) } @@ -185,7 +185,7 @@ func TestSetRoutes(t *testing.T) { var p Proxy ipPort := ":8080" - p.AddRoute(ipPort, To("127.0.0.2:8080")) + p.setRoutes(ipPort, []Route{To("127.0.0.2:8080")}) cfg := p.configFor(ipPort) expectedAddrsList := [][]string{ @@ -203,21 +203,21 @@ func TestSetRoutes(t *testing.T) { } } -func stringsToTargets(s []string) []Target { - targets := make([]Target, len(s)) +func stringsToTargets(s []string) []Route { + targets := make([]Route, len(s)) for i, v := range s { targets[i] = To(v) } return targets } -func equalRoutes(routes []route, expectedAddrs []string) bool { +func equalRoutes(routes []Route, expectedAddrs []string) bool { if len(routes) != len(expectedAddrs) { return false } for i := range routes { - addr := routes[i].(fixedTarget).t.(*DialProxy).Addr + addr := routes[i].Addr if addr != expectedAddrs[i] { return false }