From 150ecdff86fd11cc36f1a3766df42a24edf69206 Mon Sep 17 00:00:00 2001 From: Maxb Date: Wed, 13 Mar 2024 15:08:56 -0700 Subject: [PATCH] Add http3 CreateTCPStream test This includes a little bit of refactoring the docker container initialization and making a testutils.NetConWrapper for re-use. --- http3/client_test.go | 145 ++++++++++++++++++++++++++++---- internal/testutils/testutils.go | 29 +++++++ 2 files changed, 156 insertions(+), 18 deletions(-) diff --git a/http3/client_test.go b/http3/client_test.go index 0681d21..724ecaa 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -2,9 +2,15 @@ package http3 import ( "context" + "crypto/tls" + "crypto/x509" "fmt" + "io" + "log" "log/slog" "net" + "net/http" + "net/http/httptest" "os" "strings" "testing" @@ -15,13 +21,17 @@ import ( "github.com/stretchr/testify/require" tc "github.com/testcontainers/testcontainers-go/modules/compose" "github.com/testcontainers/testcontainers-go/wait" + "golang.org/x/net/http2" ) const h2oServiceName string = "h2o" -func TestSimpleClientRequest(t *testing.T) { +var logger *slog.Logger +var containerGateway string + +func TestMain(m *testing.M) { level := slog.LevelDebug - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: level, })) slog.SetDefault(logger) @@ -30,19 +40,22 @@ func TestSimpleClientRequest(t *testing.T) { identifier := tc.StackIdentifier("h2o_test") composeFile := fmt.Sprintf("%s/docker-compose.yml", testutils.RootDir()) compose, err := tc.NewDockerComposeWith(tc.WithStackFiles(composeFile), identifier) - require.NoError(t, err, "NewDockerComposeAPIWith()") + if err != nil { + log.Fatalf("error in NewDockerComposeAPIWith: %v", err) + } - t.Cleanup(func() { - require.NoError(t, - compose.Down( - context.Background(), - tc.RemoveOrphans(true), - tc.RemoveImagesLocal, - ), "compose.Down()") - }) + defer func() { + if err := compose.Down( + context.Background(), + tc.RemoveOrphans(true), + tc.RemoveImagesLocal, + ); err != nil { + log.Fatalf("error in compose.Down: %v", err) + } + }() ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) + defer cancel() stack := compose.WaitForService(h2oServiceName, // The h2o conf provides a /status endpoint listening on @@ -53,26 +66,34 @@ func TestSimpleClientRequest(t *testing.T) { WithStartupTimeout(10*time.Second), ) - err = stack.Up(ctx, tc.Wait(true)) - - require.NoError(t, err, "compose.Up()") + if err := stack.Up(ctx, tc.Wait(true)); err != nil { + log.Fatalf("error in compose.Up(): %v", err) + } container, err := stack.ServiceContainer(ctx, h2oServiceName) - require.NoError(t, err, "fetch ServiceContainer") + if err != nil { + log.Fatalf("error in stack.ServiceContainer: %v", err) + } logger.Info("compose up", "services", stack.Services(), "container", container) // Kind of awkward network info parsing here. // We need the container's gateway IP because that _should_ be the address the host can ListenUDP on where the container can access it. containerIPs, err := container.ContainerIPs(ctx) - require.NoError(t, err, "container.ContainerIPs") + if err != nil { + log.Fatalf("error in container.ContainerIPs: %v", err) + } containerIP := containerIPs[0] containerIPSplit := strings.Split(containerIP, ".") containerNet := strings.Join(containerIPSplit[:len(containerIPSplit)-1], ".") - containerGateway := fmt.Sprintf("%v.1", containerNet) + containerGateway = fmt.Sprintf("%v.1", containerNet) + + m.Run() +} +func TestCreateUDPStream(t *testing.T) { expectedRequest := "test udp request data\n" expectedResponse := "test udp response data\n" @@ -133,3 +154,91 @@ func TestSimpleClientRequest(t *testing.T) { assert.Equal(t, expectedResponse, string(buf[0:n]), "Should receive correct UDP response") logger.Info("got response", "buf", buf) } + +func TestCreateTCPStream(t *testing.T) { + // Start target HTTP/S server + expectedResponse := "test http response data" + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "", r.Header.Get("Proxy-Authorization"), "Should not pass the auth token to the end target server") + assert.Equal(t, http.MethodGet, r.Method, "Request to the end target server should be a GET") + fmt.Fprintf(w, expectedResponse) + })) + ts.EnableHTTP2 = true + // We want to listen on 0.0.0.0 because the proxy container will be on a different non-localhost network. + // In order to do that we have this kind of awkward hack borrowed from: + // https://stackoverflow.com/a/42218765/1787596 + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + log.Fatal(err) + } + + // Swap out the default test server listener with our custom one listening on 0.0.0.0 + ts.Listener.Close() + ts.Listener = l + ts.StartTLS() + defer ts.Close() + + log.Printf("Test server listening on: %v", ts.URL) + + urlSplit := strings.Split(ts.URL, ":") + port := urlSplit[len(urlSplit)-1] + + certDataFile := fmt.Sprintf("%s/testdata/h2o/server.crt", testutils.RootDir()) + certData, err := os.ReadFile(certDataFile) + require.NoError(t, err, "Reading certData") + + config := ClientConfig{ + ProxyAddr: "localhost:8444", + // The h2o server we're using doesn't require an actual token so this can be anything + AuthToken: "fake-token", + Logger: logger, + CertData: certData, + Insecure: true, + } + + c, err := NewClient(config) + require.NoError(t, err, "NewClient") + defer c.Close() + + // host.docker.internal is a docker specific host mapping for the h2o container that resolves to our localhost + dockerHostURL := fmt.Sprintf("%v:%v", containerGateway, port) + conn, err := c.CreateTCPStream(dockerHostURL) + require.NoError(t, err, "CreateTCPStream") + defer conn.Close() + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%v", dockerHostURL), nil) + require.NoError(t, err, "http.NewRequest") + + certpool := x509.NewCertPool() + certpool.AddCert(ts.Certificate()) + tlsDialWrapper := func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tlsConf := &tls.Config{ + RootCAs: certpool, + // It seems as though the httptest TLS server uses this arbitrarily as it's ServerName 🤷 + ServerName: "example.com", + NextProtos: ts.TLS.NextProtos, + } + tlsClient := tls.Client(&testutils.NetConnWrapper{ReadWriteCloser: conn}, tlsConf) + err = tlsClient.Handshake() + return tlsClient, err + } + + transport := &http2.Transport{ + DialTLS: tlsDialWrapper, + } + + httpClient := http.Client{ + Transport: transport, + } + response, err := httpClient.Do(req) + require.NoError(t, err, "httpClient.Do") + + defer response.Body.Close() + data, err := io.ReadAll(response.Body) + require.NoError(t, err, "io.ReadAll response body") + + log.Printf("got response: %v", response) + + assert.Equal(t, 200, response.StatusCode, "Should receive 200 response") + assert.Equal(t, expectedResponse, string(data), "Should receive expected body") +} diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index cf092fb..5240d6b 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -1,10 +1,13 @@ package testutils import ( + "io" "log" + "net" "path" "path/filepath" "runtime" + "time" ) func RootDir() string { @@ -13,3 +16,29 @@ func RootDir() string { d := path.Join(path.Dir(path.Dir(b))) return filepath.Dir(d) } + +// We want an interface that can implement net.Conn so we need to add these methods +// But we do not expect them to be called during our tests +type NetConnWrapper struct { + io.ReadWriteCloser +} + +func (r *NetConnWrapper) LocalAddr() net.Addr { + return nil +} + +func (r *NetConnWrapper) RemoteAddr() net.Addr { + return nil +} + +func (r *NetConnWrapper) SetDeadline(t time.Time) error { + return nil +} + +func (r *NetConnWrapper) SetReadDeadline(t time.Time) error { + return nil +} + +func (r *NetConnWrapper) SetWriteDeadline(t time.Time) error { + return nil +}