diff --git a/Dockerfile b/Dockerfile index aa7547b..a6514f4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.21-alpine AS build-env +FROM golang:1.23-alpine AS build-env RUN apk update RUN apk add g++ git make iptables-dev libpcap-dev @@ -14,7 +14,7 @@ RUN make build # run container FROM alpine -RUN apk add iptables-dev libpcap-dev +RUN apk add iptables iptables-dev libpcap-dev WORKDIR /opt/glutton COPY --from=build-env /opt/glutton/bin/server /opt/glutton/bin/server diff --git a/Makefile b/Makefile index 5d84330..47eef93 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ run: build docker: docker build -t glutton . - docker run --cap-add=NET_ADMIN -it glutton + docker run --rm --cap-add=NET_ADMIN -it glutton test: go test -v ./... diff --git a/app/server.go b/app/server.go index 41f0002..cf245ea 100644 --- a/app/server.go +++ b/app/server.go @@ -59,7 +59,7 @@ func main() { } if err := gtn.Init(); err != nil { - log.Fatal(err) + log.Fatal("Failed to initialize Glutton:", err) } exitMtx := sync.RWMutex{} @@ -86,6 +86,6 @@ func main() { }() if err := gtn.Start(); err != nil { - log.Fatalf("server start error: %s", err) + log.Fatal("Failed to start Glutton server:", err) } } diff --git a/glutton.go b/glutton.go index c1cc657..7a54247 100644 --- a/glutton.go +++ b/glutton.go @@ -12,6 +12,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "github.com/mushorg/glutton/connection" @@ -146,7 +147,8 @@ func (g *Glutton) Init() error { return nil } -func (g *Glutton) udpListen() { +func (g *Glutton) udpListen(wg *sync.WaitGroup) { + defer wg.Done() buffer := make([]byte, 1024) for { n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(g.Server.udpListener, buffer) @@ -177,40 +179,25 @@ func (g *Glutton) udpListen() { } } -// Start the listener, this blocks for new connections -func (g *Glutton) Start() error { - quit := make(chan struct{}) // stop monitor on shutdown - defer func() { - quit <- struct{}{} - g.Shutdown() - }() - - g.startMonitor(quit) - - if err := setTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), uint32(viper.GetInt("ports.ssh"))); err != nil { - return err - } - - if err := setTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "udp", uint32(g.Server.udpPort), uint32(viper.GetInt("ports.ssh"))); err != nil { - return err - } - - go g.udpListen() - +func (g *Glutton) tcpListen(wg *sync.WaitGroup) { + defer wg.Done() for { select { case <-g.ctx.Done(): - return nil + return default: } + conn, err := g.Server.tcpListener.Accept() if err != nil { - return err + g.Logger.Error("failed to accept connection", producer.ErrAttr(err)) + continue } rule, err := g.applyRulesOnConn(conn) if err != nil { - return fmt.Errorf("failed to apply rules: %w", err) + g.Logger.Error("failed to apply rules", producer.ErrAttr(err)) + continue } if rule == nil { rule = &rules.Rule{Target: "default"} @@ -218,7 +205,8 @@ func (g *Glutton) Start() error { md, err := g.connTable.RegisterConn(conn, rule) if err != nil { - return err + g.Logger.Error("failed to register connection", producer.ErrAttr(err)) + continue } g.Logger.Debug("new connection", slog.String("addr", conn.LocalAddr().String()), slog.String("handler", rule.Target)) @@ -238,6 +226,38 @@ func (g *Glutton) Start() error { } } +// Start the listener, this blocks for new connections +func (g *Glutton) Start() error { + quit := make(chan struct{}) // stop monitor on shutdown + defer func() { + quit <- struct{}{} + g.Shutdown() + }() + + g.startMonitor(quit) + + sshPort := viper.GetUint32("ports.ssh") + if err := setTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), sshPort); err != nil { + return err + } + + if err := setTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "udp", uint32(g.Server.udpPort), sshPort); err != nil { + return err + } + + wg := &sync.WaitGroup{} + + wg.Add(1) + go g.udpListen(wg) + + wg.Add(1) + go g.tcpListen(wg) + + wg.Wait() + + return nil +} + func (g *Glutton) makeID() error { filePath := filepath.Join(viper.GetString("var-dir"), "glutton.id") if err := os.MkdirAll(viper.GetString("var-dir"), 0744); err != nil { diff --git a/server.go b/server.go index cc6f57c..f75a2cd 100644 --- a/server.go +++ b/server.go @@ -47,11 +47,9 @@ func (s *Server) Start() error { func (s *Server) Shutdown() error { var err error if s.tcpListener != nil { - println("closing tcp listener") err = s.tcpListener.Close() } if s.udpListener != nil { - println("closing udp listener") err = s.udpListener.Close() } return err